Lint the openai extension

This commit is contained in:
oobabooga 2023-09-15 20:11:16 -07:00
parent 760510db52
commit 8f97e87cac
12 changed files with 79 additions and 69 deletions

View file

@ -1,8 +1,9 @@
import os
from sentence_transformers import SentenceTransformer
import numpy as np
from extensions.openai.utils import float_list_to_base64, debug_msg
from extensions.openai.errors import *
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer
# Embedding model name: read from the OPENEDAI_EMBEDDING_MODEL environment
# variable, falling back to "all-mpnet-base-v2" when the variable is not set.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")

# Module-level slot for the embeddings model; starts out as None.
embeddings_model = None
@ -11,10 +12,11 @@ embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
# A device setting of 'auto' (compared case-insensitively) is replaced by None;
# any other value is kept exactly as configured.
embeddings_device = None if embeddings_device.lower() == 'auto' else embeddings_device
def load_embedding_model(model: str) -> SentenceTransformer:
global embeddings_device, embeddings_model
try:
embeddings_model = 'loading...' # flag
embeddings_model = 'loading...' # flag
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
emb_model = SentenceTransformer(model, device=embeddings_device)
# ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
@ -41,6 +43,7 @@ def get_embeddings_model_name() -> str:
def get_embeddings(input: list) -> np.ndarray:
    """Encode ``input`` with the current embeddings model.

    The object returned by ``get_embeddings_model()`` is asked to encode the
    given list with normalized embeddings, NumPy output enabled and tensor
    output disabled, using the module-level ``embeddings_device``.

    Args:
        input: The list of items to encode.

    Returns:
        Whatever the model's ``encode`` call returns for these options
        (annotated as ``np.ndarray``).
    """
    encoder = get_embeddings_model()
    return encoder.encode(
        input,
        convert_to_numpy=True,
        normalize_embeddings=True,
        convert_to_tensor=False,
        device=embeddings_device,
    )
def embeddings(input: list, encoding_format: str) -> dict:
embeddings = get_embeddings(input)