Make OpenAI API the default API (#4430)

This commit is contained in:
oobabooga 2023-11-06 02:38:29 -03:00 committed by GitHub
parent 84d957ba62
commit ec17a5d2b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 769 additions and 1432 deletions

View file

@ -6,9 +6,13 @@ from extensions.openai.utils import debug_msg, float_list_to_base64
from sentence_transformers import SentenceTransformer
embeddings_params_initialized = False
# using 'lazy loading' to avoid circular import
# so this function will be executed only once
def initialize_embedding_params():
'''
using 'lazy loading' to avoid circular import
so this function will be executed only once
'''
global embeddings_params_initialized
if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
@ -26,7 +30,7 @@ def load_embedding_model(model: str) -> SentenceTransformer:
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
print(f"\Try embedding model: {model} on {embeddings_device}")
print(f"Try embedding model: {model} on {embeddings_device}")
# see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
embeddings_model = SentenceTransformer(model, device=embeddings_device)
# ... embeddings_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
@ -54,7 +58,7 @@ def get_embeddings(input: list) -> np.ndarray:
model = get_embeddings_model()
debug_msg(f"embedding model : {model}")
embedding = model.encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False)
debug_msg(f"embedding result : {embedding}") # might be too long even for debug, use at your own will
debug_msg(f"embedding result : {embedding}")  # might be too long even for debug, use at your own will
return embedding