lint
This commit is contained in:
parent
9b55d3a9f9
commit
e202190c4f
24 changed files with 146 additions and 125 deletions
|
@ -6,26 +6,30 @@ from extensions.openai.errors import *
|
|||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||
embeddings_model = None
|
||||
|
||||
|
||||
def load_embedding_model(model):
|
||||
try:
|
||||
emb_model = SentenceTransformer(model)
|
||||
print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}")
|
||||
except Exception as e:
|
||||
print(f"\nError: Failed to load embedding model: {model}")
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message = repr(e))
|
||||
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||
|
||||
return emb_model
|
||||
|
||||
|
||||
def get_embeddings_model():
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
return embeddings_model
|
||||
|
||||
|
||||
def get_embeddings_model_name():
|
||||
global st_model
|
||||
return st_model
|
||||
|
||||
|
||||
def embeddings(input: list, encoding_format: str):
|
||||
|
||||
embeddings = get_embeddings_model().encode(input).tolist()
|
||||
|
@ -47,4 +51,4 @@ def embeddings(input: list, encoding_format: str):
|
|||
|
||||
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||
|
||||
return response
|
||||
return response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue