oobabooga 2023-07-12 11:33:25 -07:00
parent 9b55d3a9f9
commit e202190c4f
24 changed files with 146 additions and 125 deletions


@@ -36,12 +36,12 @@ class LogprobProcessor(LogitsProcessor):
super().__init__()
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
if self.logprobs is not None: # 0-5
if self.logprobs is not None: # 0-5
log_e_probabilities = F.log_softmax(logits, dim=1)
# XXX hack. should find the selected token and include the prob of that
# ... but we just +1 here instead because we don't know it yet.
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [ decode(tok) for tok in top_indices[0] ]
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [decode(tok) for tok in top_indices[0]]
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
return logits
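For readers unfamiliar with Hugging Face logits processors, the top-logprobs capture above boils down to a log-softmax plus top-k. A minimal, self-contained sketch of the same idea; `decode_fn` stands in for the extension's `decode()` helper and is hypothetical here:

import torch
import torch.nn.functional as F

def capture_top_logprobs(logits: torch.FloatTensor, k: int, decode_fn):
    # Natural-log probabilities over the vocabulary dimension.
    log_probs = F.log_softmax(logits, dim=1)
    # k + 1 mirrors the "+1" hack above: one extra slot for the selected token.
    values, indices = torch.topk(log_probs, k=k + 1)
    tokens = [decode_fn(tok) for tok in indices[0]]
    return dict(zip(tokens, values[0].tolist()))

# Toy 1 x 5 "vocabulary": token 0 is the most likely alternative.
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])
print(capture_top_logprobs(logits, k=2, decode_fn=lambda t: f"tok{int(t)}"))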
@@ -50,9 +50,9 @@ def convert_logprobs_to_tiktoken(model, logprobs):
try:
encoder = tiktoken.encoding_for_model(model)
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
return dict([ (encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items() ])
return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
except KeyError:
# assume native tokens if we can't find the tokenizer
# assume native tokens if we can't find the tokenizer
return logprobs
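As a usage note, the conversion above re-tokenizes each native token string with tiktoken so OpenAI clients see tokens from the model they asked for. A rough sketch of the call; the model name and logprob values are illustrative:

import tiktoken

# Illustrative inputs: native-tokenizer token strings with their logprobs.
native_logprobs = {' Hello': -0.11, ' Hi': -2.43}

encoder = tiktoken.encoding_for_model('gpt-3.5-turbo')
converted = {encoder.decode([encoder.encode(tok)[0]]): p for tok, p in native_logprobs.items()}
print(converted)  # keys are now tiktoken token strings (first sub-token if a string splits)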
@@ -71,9 +71,9 @@ def marshal_common_params(body):
# OpenAI API Parameters
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
req_params['requested_model'] = body.get('model', shared.model_name)
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
n = default(body, 'n', 1)
if n != 1:
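The `default`/`clamp` helpers used above are small utilities from the extension; a plausible sketch of their behaviour (the exact definitions may differ), showing how an OpenAI-style temperature of 0.0 gets nudged to 0.001:

def default(body: dict, key: str, fallback):
    # Take the request value when present and not None, otherwise the preset default.
    value = body.get(key, fallback)
    return fallback if value is None else value

def clamp(value, low, high):
    return max(low, min(high, value))

body = {'temperature': 0.0, 'top_p': 1.0}
print(clamp(default(body, 'temperature', 1.0), 0.001, 1.999))  # 0.001
print(clamp(default(body, 'top_p', 1.0), 0.001, 1.0))          # 1.0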
@@ -81,7 +81,7 @@ def marshal_common_params(body):
if 'stop' in body: # str or array, max len 4 (ignored)
if isinstance(body['stop'], str):
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
elif isinstance(body['stop'], list):
req_params['stopping_strings'] = body['stop']
@@ -91,7 +91,7 @@ def marshal_common_params(body):
logits_processor = []
logit_bias = body.get('logit_bias', None)
if logit_bias: # {str: float, ...}
if logit_bias: # {str: float, ...}
# XXX convert tokens from tiktoken based on requested model
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
try:
@@ -103,19 +103,19 @@ def marshal_common_params(body):
print(logit_bias, '->', new_logit_bias)
logit_bias = new_logit_bias
except KeyError:
pass # assume native tokens if we can't find the tokenizer
pass # assume native tokens if we can't find the tokenizer
logits_processor = [LogitsBiasProcessor(logit_bias)]
logprobs = None # coming to chat eventually
logprobs = None # coming to chat eventually
if 'logprobs' in body:
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
req_params['logprob_proc'] = LogprobProcessor(logprobs)
logits_processor.extend([req_params['logprob_proc']])
else:
logprobs = None
if logits_processor: # requires logits_processor support
if logits_processor: # requires logits_processor support
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
return req_params
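Both LogitsBiasProcessor and LogprobProcessor are standard Hugging Face LogitsProcessor subclasses, so collecting them into a LogitsProcessorList lets the backend apply them on every generation step. A generic, self-contained illustration; the bias values and vocabulary size are made up:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class AdditiveBias(LogitsProcessor):
    def __init__(self, bias: dict):
        self.bias = bias  # {token_id: additive bias on the raw logits}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        for token_id, value in self.bias.items():
            scores[:, int(token_id)] += value
        return scores

processors = LogitsProcessorList([AdditiveBias({3: 100.0})])
scores = processors(torch.tensor([[1, 2]]), torch.zeros(1, 8))
print(scores)  # token 3 now carries a +100 bias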
@@ -123,14 +123,14 @@ def marshal_common_params(body):
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
# functions
if body.get('functions', []): # chat only
if body.get('functions', []): # chat only
raise InvalidRequestError(message="functions is not supported.", param='functions')
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
if not 'messages' in body:
raise InvalidRequestError(message="messages is required", param='messages')
messages = body['messages']
role_formats = {
@@ -152,11 +152,11 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
template = instruct['turn_template']
system_message_template = "{message}"
system_message_default = instruct['context']
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
role_formats = {
'user': user_message_template,
'assistant': bot_message_template,
@@ -167,7 +167,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
if 'Alpaca' in shared.settings['instruction_template']:
req_params['stopping_strings'].extend(['\n###'])
elif instruct['user']: # WizardLM and some others have no user prompt.
elif instruct['user']: # WizardLM and some others have no user prompt.
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
@@ -220,16 +220,16 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
print(f"Warning: ${err_msg}")
#raise InvalidRequestError(message=err_msg)
# raise InvalidRequestError(message=err_msg)
return prompt, token_count
def chat_completions(body: dict, is_legacy: bool=False) -> dict:
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
# Chat Completions
object_type = 'chat.completions'
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# common params
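The completion id is just a prefix plus the current Unix time in nanoseconds, while created_time is the same clock in seconds; for example:

import time

created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
print(created_time)  # e.g. 1689187405
print(cmpl_id)       # e.g. chatcmpl-1689187405123456768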
@@ -237,7 +237,7 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
req_params['stream'] = False
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# chat default max_tokens is 'inf', but also flexible
max_tokens = 0
@@ -254,7 +254,7 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
logprob_proc = req_params.pop('logprob_proc', None)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
@@ -286,9 +286,9 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
"total_tokens": token_count + completion_token_count
}
}
if logprob_proc: # not official for chat yet
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# resp[resp_list][0]["logprobs"] = None
@@ -296,12 +296,12 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
# generator
def stream_chat_completions(body: dict, is_legacy: bool=False):
def stream_chat_completions(body: dict, is_legacy: bool = False):
# Chat Completions
stream_object_type = 'chat.completions.chunk'
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# common params
@@ -309,7 +309,7 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
req_params['stream'] = True
requested_model = req_params.pop('requested_model')
logprob_proc = req_params.pop('logprob_proc', None)
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
# chat default max_tokens is 'inf', but also flexible
max_tokens = 0
@@ -339,10 +339,10 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
}],
}
if logprob_proc: # not official for chat yet
if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
#else:
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
# else:
# chunk[resp_list][0]["logprobs"] = None
return chunk
@@ -350,7 +350,7 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
# generate reply #######################################
debug_msg({'prompt': prompt, 'req_params': req_params})
stopping_strings = req_params.pop('stopping_strings', [])
logprob_proc = req_params.pop('logprob_proc', None)
@@ -377,9 +377,8 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
completion_token_count += len(encode(new_content)[0])
chunk = chat_streaming_chunk(new_content)
yield chunk
yield chunk
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
@@ -396,12 +395,12 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
yield chunk
def completions(body: dict, is_legacy: bool=False):
def completions(body: dict, is_legacy: bool = False):
# Legacy
# Text Completions
object_type = 'text_completion'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
@@ -433,7 +432,7 @@ def completions(body: dict, is_legacy: bool=False):
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
#print(f"Warning: ${err_msg}")
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
req_params['echo'] = default(body, 'echo', req_params['echo'])
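A worked instance of the context-length check above with illustrative numbers: a 1,800-token prompt plus max_tokens of 400 overruns a 2,048-token context, so the request is rejected before generation starts.

truncation_length = 2048          # model context length (illustrative)
token_count, max_tokens = 1800, 400

if token_count + max_tokens > truncation_length:
    print(f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) "
          f"cannot exceed the model's context length ({truncation_length}).")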
@@ -478,21 +477,21 @@ def completions(body: dict, is_legacy: bool=False):
if logprob_proc:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
resp[resp_list][0]["logprobs"] = None
resp[resp_list][0]["logprobs"] = None
return resp
# generator
def stream_completions(body: dict, is_legacy: bool=False):
def stream_completions(body: dict, is_legacy: bool = False):
# Legacy
# Text Completions
#object_type = 'text_completion'
# object_type = 'text_completion'
stream_object_type = 'text_completion.chunk'
created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
@@ -524,7 +523,7 @@ def stream_completions(body: dict, is_legacy: bool=False):
if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
#print(f"Warning: ${err_msg}")
# print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
req_params['echo'] = default(body, 'echo', req_params['echo'])
@@ -545,9 +544,9 @@ def stream_completions(body: dict, is_legacy: bool=False):
}
if logprob_proc:
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
else:
chunk[resp_list][0]["logprobs"] = None
chunk[resp_list][0]["logprobs"] = None
return chunk
@@ -583,7 +582,6 @@ def stream_completions(body: dict, is_legacy: bool=False):
completion_token_count += len(encode(new_content)[0])
yield chunk
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"