extensions/openai: Fixes for: embeddings, tokens, better errors. +Docs update, +Images, +logit_bias/logprobs, +more. (#3122)

This commit is contained in:
matatonic 2023-07-24 10:28:12 -04:00 committed by GitHub
parent 1141987a0d
commit 90a4ab631c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 215 additions and 143 deletions

View file

@@ -1,7 +1,7 @@
import time
import numpy as np
from numpy.linalg import norm
from extensions.openai.embeddings import get_embeddings_model
from extensions.openai.embeddings import get_embeddings
# When True, moderations() reports every category as False with a score of 0.0.
moderations_disabled = False # return 0/false
@@ -11,21 +11,21 @@ categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hat
# A category is flagged when its score is strictly greater than this value.
flag_threshold = 0.5
def get_category_embeddings() -> dict:
    """Return the mapping of moderation category name to its embedding.

    The mapping is built on first use and stored in the module-level
    ``category_embeddings`` global, so later calls reuse it instead of
    embedding the category names again.

    Returns:
        A dict keyed by each entry of ``categories``; each value is that
        category's embedding as produced by
        ``get_embeddings(categories).tolist()``.
    """
    # Only category_embeddings is rebound here; categories is merely read.
    global category_embeddings
    if category_embeddings is None:
        # One batched call for all category names; .tolist() turns the
        # result into plain Python lists before it is cached.
        embeddings = get_embeddings(categories).tolist()
        category_embeddings = dict(zip(categories, embeddings))

    return category_embeddings
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between vectors *a* and *b*.

    Computed as ``dot(a, b) / (norm(a) * norm(b))``.

    Args:
        a: First vector.
        b: Second vector, same length as *a*.

    Returns:
        The cosine similarity. No guard is applied for zero-length
        vectors: if either norm is 0 the division is 0/0 and NumPy
        yields ``nan``.
    """
    return np.dot(a, b) / (norm(a) * norm(b))
# seems most openai like with all-mpnet-base-v2
def mod_score(a: np.ndarray, b: np.ndarray) -> float:
    """Return the moderation score for two embeddings: ``2.0 * dot(a, b)``.

    Unlike a cosine similarity, the dot product is not divided by the
    vector norms.

    Args:
        a: First embedding vector.
        b: Second embedding vector, same length as *a*.

    Returns:
        Twice the dot product of *a* and *b*.
    """
    # NOTE(review): the 2.0 factor looks like an empirical scaling chosen
    # for a specific embedding model — confirm before changing it.
    return 2.0 * np.dot(a, b)
@@ -37,8 +37,7 @@ def moderations(input):
"results": [],
}
embeddings_model = get_embeddings_model()
if not embeddings_model or moderations_disabled:
if moderations_disabled:
results['results'] = [{
'categories': dict([(C, False) for C in categories]),
'category_scores': dict([(C, 0.0) for C in categories]),
@@ -53,7 +52,7 @@ def moderations(input):
input = [input]
for in_str in input:
for ine in embeddings_model.encode([in_str]).tolist():
for ine in get_embeddings([in_str]):
category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
flagged = any(category_flags.values())