Lint the openai extension

This commit is contained in:
oobabooga 2023-09-15 20:11:16 -07:00
parent 760510db52
commit 8f97e87cac
12 changed files with 79 additions and 69 deletions

View file

@ -1,18 +1,15 @@
import time
import yaml
import tiktoken
import torch
import torch.nn.functional as F
from math import log, exp
from transformers import LogitsProcessor, LogitsProcessorList
import yaml
from extensions.openai.defaults import clamp, default, get_default_req_params
from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg, end_line
from modules import shared
from modules.text_generation import encode, decode, generate_reply
from extensions.openai.defaults import get_default_req_params, default, clamp
from extensions.openai.utils import end_line, debug_msg
from extensions.openai.errors import *
from modules.text_generation import decode, encode, generate_reply
from transformers import LogitsProcessor, LogitsProcessorList
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
@ -21,7 +18,7 @@ class LogitsBiasProcessor(LogitsProcessor):
self.logit_bias = logit_bias
if self.logit_bias:
self.keys = list([int(key) for key in self.logit_bias.keys()])
values = [ self.logit_bias[str(key)] for key in self.keys ]
values = [self.logit_bias[str(key)] for key in self.keys]
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
debug_msg(f"{self})")
@ -36,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
def __repr__(self):
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
class LogprobProcessor(LogitsProcessor):
def __init__(self, logprobs=None):
self.logprobs = logprobs
@ -44,9 +42,9 @@ class LogprobProcessor(LogitsProcessor):
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logprobs is not None: # 0-5
log_e_probabilities = F.log_softmax(logits, dim=1)
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs+1)
top_tokens = [ decode(tok) for tok in top_indices[0] ]
top_probs = [ float(x) for x in top_values[0] ]
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [decode(tok) for tok in top_indices[0]]
top_probs = [float(x) for x in top_values[0]]
self.token_alternatives = dict(zip(top_tokens, top_probs))
debug_msg(repr(self))
return logits
@ -56,14 +54,15 @@ class LogprobProcessor(LogitsProcessor):
def convert_logprobs_to_tiktoken(model, logprobs):
# more problems than it's worth.
# try:
# encoder = tiktoken.encoding_for_model(model)
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
# except KeyError:
# # assume native tokens if we can't find the tokenizer
# return logprobs
# more problems than it's worth.
# try:
# encoder = tiktoken.encoding_for_model(model)
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
# except KeyError:
# # assume native tokens if we can't find the tokenizer
# return logprobs
return logprobs
@ -115,7 +114,7 @@ def marshal_common_params(body):
new_logit_bias = {}
for logit, bias in logit_bias.items():
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
continue
new_logit_bias[str(int(x))] = bias
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
@ -146,7 +145,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
if not 'messages' in body:
if 'messages' not in body:
raise InvalidRequestError(message="messages is required", param='messages')
messages = body['messages']
@ -159,7 +158,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
'prompt': 'Assistant:',
}
if not 'stopping_strings' in req_params:
if 'stopping_strings' not in req_params:
req_params['stopping_strings'] = []
# Instruct models can be much better
@ -169,7 +168,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
template = instruct['turn_template']
system_message_template = "{message}"
system_message_default = instruct.get('context', '') # can be missing
system_message_default = instruct.get('context', '') # can be missing
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
@ -216,7 +215,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
raise InvalidRequestError(message="messages: missing role", param='messages')
if 'content' not in m:
raise InvalidRequestError(message="messages: missing content", param='messages')
role = m['role']
content = m['content']
# name = m.get('name', None)
@ -439,7 +438,7 @@ def completions(body: dict, is_legacy: bool = False):
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if not prompt_str in body:
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt_arg = body[prompt_str]
@ -455,7 +454,7 @@ def completions(body: dict, is_legacy: bool = False):
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
@ -538,10 +537,12 @@ def stream_completions(body: dict, is_legacy: bool = False):
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
prompt_str = 'context' if is_legacy else 'prompt'
if not prompt_str in body:
if prompt_str not in body:
raise InvalidRequestError("Missing required input", param=prompt_str)
prompt = body[prompt_str]
req_params = marshal_common_params(body)
requested_model = req_params.pop('requested_model')
if isinstance(prompt, list):
if prompt and isinstance(prompt[0], int):
try:
@ -553,15 +554,13 @@ def stream_completions(body: dict, is_legacy: bool = False):
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
# common params
req_params = marshal_common_params(body)
req_params['stream'] = True
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
req_params['max_new_tokens'] = max_tokens
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
stopping_strings = req_params.pop('stopping_strings', [])
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['echo'] = default(body, 'echo', req_params['echo'])
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])