extensions/openai: load extension settings via settings.yaml (#3953)

This commit is contained in:
Chenxiao Wang 2023-09-18 09:39:29 +08:00 committed by GitHub
parent cc8eda298a
commit 347aed4254
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 48 additions and 16 deletions

View file

@ -5,15 +5,25 @@ from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
embeddings_model = None
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
if embeddings_device.lower() == 'auto':
embeddings_device = None
embeddings_params_initialized = False
# using 'lazy loading' to avoid circular import
# so this function will be executed only once
def initialize_embedding_params():
global embeddings_params_initialized
if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
from extensions.openai.script import params
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", params.get('embedding_model', 'all-mpnet-base-v2'))
embeddings_model = None
# OPENEDAI_EMBEDDING_DEVICE: auto (best or cpu), cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep, hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta, hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", params.get('embedding_device', 'cpu'))
if embeddings_device.lower() == 'auto':
embeddings_device = None
embeddings_params_initialized = True
def load_embedding_model(model: str) -> SentenceTransformer:
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
embeddings_model = 'loading...' # flag
@ -29,6 +39,7 @@ def load_embedding_model(model: str) -> SentenceTransformer:
def get_embeddings_model() -> SentenceTransformer:
initialize_embedding_params()
global embeddings_model, st_model
if st_model and not embeddings_model:
embeddings_model = load_embedding_model(st_model) # lazy load the model
@ -36,6 +47,7 @@ def get_embeddings_model() -> SentenceTransformer:
def get_embeddings_model_name() -> str:
initialize_embedding_params()
global st_model
return st_model