extensions/openai: Fixes for: embeddings, tokens, better errors. +Docs update, +Images, +logit_bias/logprobs, +more. (#3122)
This commit is contained in:
parent
1141987a0d
commit
90a4ab631c
10 changed files with 215 additions and 143 deletions
|
@ -3,6 +3,7 @@ import yaml
|
|||
import tiktoken
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from math import log, exp
|
||||
|
||||
from transformers import LogitsProcessor, LogitsProcessorList
|
||||
|
||||
|
@ -18,41 +19,50 @@ from extensions.openai.errors import *
|
|||
class LogitsBiasProcessor(LogitsProcessor):
|
||||
def __init__(self, logit_bias={}):
|
||||
self.logit_bias = logit_bias
|
||||
super().__init__()
|
||||
if self.logit_bias:
|
||||
self.keys = list([int(key) for key in self.logit_bias.keys()])
|
||||
values = [ self.logit_bias[str(key)] for key in self.keys ]
|
||||
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
|
||||
debug_msg(f"{self})")
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.logit_bias:
|
||||
keys = list([int(key) for key in self.logit_bias.keys()])
|
||||
values = list([int(val) for val in self.logit_bias.values()])
|
||||
logits[0, keys] += torch.tensor(values).cuda()
|
||||
|
||||
debug_msg(logits[0, self.keys], " + ", self.values)
|
||||
logits[0, self.keys] += self.values
|
||||
debug_msg(" --> ", logits[0, self.keys])
|
||||
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
|
||||
|
||||
class LogprobProcessor(LogitsProcessor):
|
||||
def __init__(self, logprobs=None):
|
||||
self.logprobs = logprobs
|
||||
self.token_alternatives = {}
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.logprobs is not None: # 0-5
|
||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||
# XXX hack. should find the selected token and include the prob of that
|
||||
# ... but we just +1 here instead because we don't know it yet.
|
||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||
top_tokens = [decode(tok) for tok in top_indices[0]]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
|
||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs+1)
|
||||
top_tokens = [ decode(tok) for tok in top_indices[0] ]
|
||||
top_probs = [ float(x) for x in top_values[0] ]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||
debug_msg(f"{self.__class__.__name__}(logprobs+1={self.logprobs+1}, token_alternatives={self.token_alternatives})")
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"
|
||||
|
||||
|
||||
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(model)
|
||||
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||
return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||
except KeyError:
|
||||
# assume native tokens if we can't find the tokenizer
|
||||
# more problems than it's worth.
|
||||
# try:
|
||||
# encoder = tiktoken.encoding_for_model(model)
|
||||
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||
# except KeyError:
|
||||
# # assume native tokens if we can't find the tokenizer
|
||||
return logprobs
|
||||
|
||||
|
||||
|
@ -73,8 +83,8 @@ def marshal_common_params(body):
|
|||
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||
|
||||
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
|
||||
n = default(body, 'n', 1)
|
||||
if n != 1:
|
||||
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
||||
|
@ -87,6 +97,11 @@ def marshal_common_params(body):
|
|||
|
||||
# presence_penalty - ignored
|
||||
# frequency_penalty - ignored
|
||||
|
||||
# pass through unofficial params
|
||||
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
|
||||
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
|
||||
|
||||
# user - ignored
|
||||
|
||||
logits_processor = []
|
||||
|
@ -98,9 +113,11 @@ def marshal_common_params(body):
|
|||
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
||||
new_logit_bias = {}
|
||||
for logit, bias in logit_bias.items():
|
||||
for x in encode(encoder.decode([int(logit)]))[0]:
|
||||
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
|
||||
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
|
||||
continue
|
||||
new_logit_bias[str(int(x))] = bias
|
||||
print(logit_bias, '->', new_logit_bias)
|
||||
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
|
||||
logit_bias = new_logit_bias
|
||||
except KeyError:
|
||||
pass # assume native tokens if we can't find the tokenizer
|
||||
|
@ -134,11 +151,11 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
|||
messages = body['messages']
|
||||
|
||||
role_formats = {
|
||||
'user': 'user: {message}\n',
|
||||
'assistant': 'assistant: {message}\n',
|
||||
'user': 'User: {message}\n',
|
||||
'assistant': 'Assistant: {message}\n',
|
||||
'system': '{message}',
|
||||
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
||||
'prompt': 'assistant:',
|
||||
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
|
||||
'prompt': 'Assistant:',
|
||||
}
|
||||
|
||||
if not 'stopping_strings' in req_params:
|
||||
|
@ -151,10 +168,10 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
|||
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct['context']
|
||||
system_message_default = instruct.get('context', '') # can be missing
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
role_formats = {
|
||||
|
@ -173,13 +190,13 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
|||
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
|
||||
except Exception as e:
|
||||
req_params['stopping_strings'].extend(['\nuser:'])
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
|
||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
else:
|
||||
req_params['stopping_strings'].extend(['\nuser:'])
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
system_msgs = []
|
||||
|
@ -194,6 +211,11 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
|||
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
||||
|
||||
for m in messages:
|
||||
if 'role' not in m:
|
||||
raise InvalidRequestError(message="messages: missing role", param='messages')
|
||||
if 'content' not in m:
|
||||
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||
|
||||
role = m['role']
|
||||
content = m['content']
|
||||
# name = m.get('name', None)
|
||||
|
@ -215,12 +237,12 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
|||
|
||||
if token_count >= req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
||||
raise InvalidRequestError(message=err_msg)
|
||||
raise InvalidRequestError(message=err_msg, param='messages')
|
||||
|
||||
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||
print(f"Warning: ${err_msg}")
|
||||
# raise InvalidRequestError(message=err_msg)
|
||||
# raise InvalidRequestError(message=err_msg, params='max_tokens')
|
||||
|
||||
return prompt, token_count
|
||||
|
||||
|
@ -251,6 +273,10 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
|||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
|
@ -267,7 +293,7 @@ def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
|||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
resp = {
|
||||
|
@ -323,6 +349,10 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
|||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
|
||||
def chat_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
|
@ -352,7 +382,6 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
|||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
|
@ -375,13 +404,17 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
|||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct token_count, strip leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = chat_streaming_chunk('')
|
||||
|
@ -413,7 +446,7 @@ def completions(body: dict, is_legacy: bool = False):
|
|||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encode(encoder.decode(prompt))[0]
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
|
@ -441,7 +474,6 @@ def completions(body: dict, is_legacy: bool = False):
|
|||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
|
@ -475,7 +507,7 @@ def completions(body: dict, is_legacy: bool = False):
|
|||
}
|
||||
}
|
||||
|
||||
if logprob_proc:
|
||||
if logprob_proc and logprob_proc.token_alternatives:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
|
@ -504,7 +536,7 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
|||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encode(encoder.decode(prompt))[0]
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
|
@ -579,9 +611,13 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
|||
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
yield chunk
|
||||
|
||||
# to get the correct count, we strip the leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue