extensions/openai: load extension settings via settings.yaml (#3953)
This commit is contained in:
parent
cc8eda298a
commit
347aed4254
6 changed files with 48 additions and 16 deletions
|
@ -5,15 +5,25 @@ from extensions.openai.errors import ServiceUnavailableError
|
|||
from extensions.openai.utils import debug_msg, float_list_to_base64
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||
embeddings_model = None
|
||||
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
||||
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
|
||||
if embeddings_device.lower() == 'auto':
|
||||
embeddings_device = None
|
||||
embeddings_params_initialized = False
|
||||
# using 'lazy loading' to avoid circular import
|
||||
# so this function will be executed only once
|
||||
def initialize_embedding_params():
|
||||
global embeddings_params_initialized
|
||||
if not embeddings_params_initialized:
|
||||
global st_model, embeddings_model, embeddings_device
|
||||
from extensions.openai.script import params
|
||||
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
|
||||
embeddings_model = None
|
||||
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
|
||||
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
|
||||
if embeddings_device.lower() == 'auto':
|
||||
embeddings_device = None
|
||||
embeddings_params_initialized = True
|
||||
|
||||
|
||||
def load_embedding_model(model: str) -> SentenceTransformer:
|
||||
initialize_embedding_params()
|
||||
global embeddings_device, embeddings_model
|
||||
try:
|
||||
embeddings_model = 'loading...' # flag
|
||||
|
@ -29,6 +39,7 @@ def load_embedding_model(model: str) -> SentenceTransformer:
|
|||
|
||||
|
||||
def get_embeddings_model() -> SentenceTransformer:
|
||||
initialize_embedding_params()
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
|
@ -36,6 +47,7 @@ def get_embeddings_model() -> SentenceTransformer:
|
|||
|
||||
|
||||
def get_embeddings_model_name() -> str:
|
||||
initialize_embedding_params()
|
||||
global st_model
|
||||
return st_model
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue