Make OpenAI API the default API (#4430)
This commit is contained in:
parent
84d957ba62
commit
ec17a5d2b7
22 changed files with 769 additions and 1432 deletions
|
@ -1,18 +1,23 @@
|
|||
import copy
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import tiktoken
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import yaml
|
||||
from extensions.openai.defaults import clamp, default, get_default_req_params
|
||||
from extensions.openai.errors import InvalidRequestError
|
||||
from extensions.openai.utils import debug_msg, end_line
|
||||
from extensions.openai.utils import debug_msg
|
||||
from modules import shared
|
||||
from modules.chat import (
|
||||
generate_chat_prompt,
|
||||
generate_chat_reply,
|
||||
load_character_memoized
|
||||
)
|
||||
from modules.presets import load_preset_memoized
|
||||
from modules.text_generation import decode, encode, generate_reply
|
||||
from transformers import LogitsProcessor, LogitsProcessorList
|
||||
|
||||
|
||||
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
||||
class LogitsBiasProcessor(LogitsProcessor):
|
||||
def __init__(self, logit_bias={}):
|
||||
self.logit_bias = logit_bias
|
||||
|
@ -28,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
|
|||
logits[0, self.keys] += self.values
|
||||
debug_msg(" --> ", logits[0, self.keys])
|
||||
debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))
|
||||
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -47,6 +53,7 @@ class LogprobProcessor(LogitsProcessor):
|
|||
top_probs = [float(x) for x in top_values[0]]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||
debug_msg(repr(self))
|
||||
|
||||
return logits
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -66,43 +73,28 @@ def convert_logprobs_to_tiktoken(model, logprobs):
|
|||
return logprobs
|
||||
|
||||
|
||||
def marshal_common_params(body):
|
||||
# Request Parameters
|
||||
# Try to use openai defaults or map them to something with the same intent
|
||||
def process_parameters(body, is_legacy=False):
|
||||
generate_params = body
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
generate_params['max_new_tokens'] = body.pop(max_tokens_str)
|
||||
if generate_params['truncation_length'] == 0:
|
||||
if shared.args.loader and shared.args.loader.lower().startswith('exllama'):
|
||||
generate_params['truncation_length'] = shared.args.max_seq_len
|
||||
elif shared.args.loader and shared.args.loader in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
|
||||
generate_params['truncation_length'] = shared.args.n_ctx
|
||||
else:
|
||||
generate_params['truncation_length'] = shared.settings['truncation_length']
|
||||
|
||||
req_params = get_default_req_params()
|
||||
|
||||
# Common request parameters
|
||||
req_params['truncation_length'] = shared.settings['truncation_length']
|
||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||
|
||||
# OpenAI API Parameters
|
||||
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
|
||||
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||
|
||||
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.01, 1.99) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.01, 1.0)
|
||||
n = default(body, 'n', 1)
|
||||
if n != 1:
|
||||
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
||||
if body['preset'] is not None:
|
||||
preset = load_preset_memoized(body['preset'])
|
||||
generate_params.update(preset)
|
||||
|
||||
generate_params['custom_stopping_strings'] = []
|
||||
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||
if isinstance(body['stop'], str):
|
||||
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
||||
generate_params['custom_stopping_strings'] = [body['stop']]
|
||||
elif isinstance(body['stop'], list):
|
||||
req_params['stopping_strings'] = body['stop']
|
||||
|
||||
# presence_penalty - ignored
|
||||
# frequency_penalty - ignored
|
||||
|
||||
# pass through unofficial params
|
||||
req_params['repetition_penalty'] = default(body, 'repetition_penalty', req_params['repetition_penalty'])
|
||||
req_params['encoder_repetition_penalty'] = default(body, 'encoder_repetition_penalty', req_params['encoder_repetition_penalty'])
|
||||
|
||||
# user - ignored
|
||||
generate_params['custom_stopping_strings'] = body['stop']
|
||||
|
||||
logits_processor = []
|
||||
logit_bias = body.get('logit_bias', None)
|
||||
|
@ -110,12 +102,13 @@ def marshal_common_params(body):
|
|||
# XXX convert tokens from tiktoken based on requested model
|
||||
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
||||
encoder = tiktoken.encoding_for_model(generate_params['model'])
|
||||
new_logit_bias = {}
|
||||
for logit, bias in logit_bias.items():
|
||||
for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
|
||||
if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
|
||||
continue
|
||||
|
||||
new_logit_bias[str(int(x))] = bias
|
||||
debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
|
||||
logit_bias = new_logit_bias
|
||||
|
@ -126,238 +119,129 @@ def marshal_common_params(body):
|
|||
|
||||
logprobs = None # coming to chat eventually
|
||||
if 'logprobs' in body:
|
||||
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
req_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||
logits_processor.extend([req_params['logprob_proc']])
|
||||
logprobs = body.get('logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
generate_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||
logits_processor.extend([generate_params['logprob_proc']])
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
if logits_processor: # requires logits_processor support
|
||||
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||
generate_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||
|
||||
return req_params
|
||||
return generate_params
|
||||
|
||||
|
||||
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
# functions
|
||||
if body.get('functions', []): # chat only
|
||||
def convert_history(history):
|
||||
'''
|
||||
Chat histories in this program are in the format [message, reply].
|
||||
This function converts OpenAI histories to that format.
|
||||
'''
|
||||
chat_dialogue = []
|
||||
current_message = ""
|
||||
current_reply = ""
|
||||
user_input = ""
|
||||
|
||||
for entry in history:
|
||||
content = entry["content"]
|
||||
role = entry["role"]
|
||||
|
||||
if role == "user":
|
||||
user_input = content
|
||||
if current_message:
|
||||
chat_dialogue.append([current_message, ''])
|
||||
current_message = ""
|
||||
current_message = content
|
||||
elif role == "assistant":
|
||||
current_reply = content
|
||||
if current_message:
|
||||
chat_dialogue.append([current_message, current_reply])
|
||||
current_message = ""
|
||||
current_reply = ""
|
||||
else:
|
||||
chat_dialogue.append(['', current_reply])
|
||||
|
||||
# if current_message:
|
||||
# chat_dialogue.append([current_message, ''])
|
||||
|
||||
return user_input, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}
|
||||
|
||||
|
||||
def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -> dict:
|
||||
if body.get('functions', []):
|
||||
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||
|
||||
if body.get('function_call', ''):
|
||||
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||
|
||||
if 'messages' not in body:
|
||||
raise InvalidRequestError(message="messages is required", param='messages')
|
||||
|
||||
messages = body['messages']
|
||||
|
||||
role_formats = {
|
||||
'user': 'User: {message}\n',
|
||||
'assistant': 'Assistant: {message}\n',
|
||||
'system': '{message}',
|
||||
'context': 'You are a helpful assistant. Answer as concisely as possible.\nUser: I want your assistance.\nAssistant: Sure! What can I do for you?',
|
||||
'prompt': 'Assistant:',
|
||||
}
|
||||
|
||||
if 'stopping_strings' not in req_params:
|
||||
req_params['stopping_strings'] = []
|
||||
|
||||
# Instruct models can be much better
|
||||
if shared.settings['instruction_template']:
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"instruction-templates/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct.get('context', '') # can be missing
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
role_formats = {
|
||||
'user': user_message_template,
|
||||
'assistant': bot_message_template,
|
||||
'system': system_message_template,
|
||||
'context': system_message_default,
|
||||
'prompt': bot_prompt,
|
||||
}
|
||||
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
req_params['stopping_strings'].extend(['\n###'])
|
||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
|
||||
except Exception as e:
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
|
||||
print(f"Exception: When loading instruction-templates/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
else:
|
||||
req_params['stopping_strings'].extend(['\nUser:', 'User:']) # XXX User: prompt here also
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
system_msgs = []
|
||||
chat_msgs = []
|
||||
|
||||
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
||||
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
||||
context_msg = end_line(context_msg)
|
||||
|
||||
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
||||
if 'prompt' in body:
|
||||
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
||||
|
||||
for m in messages:
|
||||
if 'role' not in m:
|
||||
raise InvalidRequestError(message="messages: missing role", param='messages')
|
||||
elif m['role'] == 'function':
|
||||
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||
if 'content' not in m:
|
||||
raise InvalidRequestError(message="messages: missing content", param='messages')
|
||||
|
||||
role = m['role']
|
||||
content = m['content']
|
||||
# name = m.get('name', None)
|
||||
# function_call = m.get('function_call', None) # user name or function name with output in content
|
||||
msg = role_formats[role].format(message=content)
|
||||
if role == 'system':
|
||||
system_msgs.extend([msg])
|
||||
elif role == 'function':
|
||||
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||
else:
|
||||
chat_msgs.extend([msg])
|
||||
|
||||
system_msg = '\n'.join(system_msgs)
|
||||
system_msg = end_line(system_msg)
|
||||
|
||||
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
if token_count >= req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
||||
raise InvalidRequestError(message=err_msg, param='messages')
|
||||
|
||||
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||
print(f"Warning: ${err_msg}")
|
||||
# raise InvalidRequestError(message=err_msg, params='max_tokens')
|
||||
|
||||
return prompt, token_count
|
||||
|
||||
|
||||
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
# Chat Completions
|
||||
object_type = 'chat.completions'
|
||||
object_type = 'chat.completions' if not stream else 'chat.completions.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = False
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
# generation parameters
|
||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||
continue_ = body['continue_']
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
if max_tokens_str in body:
|
||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
else:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
# Instruction template
|
||||
instruction_template = body['instruction_template'] or shared.settings['instruction_template']
|
||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
|
||||
name1_instruct = body['name1_instruct'] or name1_instruct
|
||||
name2_instruct = body['name2_instruct'] or name2_instruct
|
||||
context_instruct = body['context_instruct'] or context_instruct
|
||||
turn_template = body['turn_template'] or turn_template
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
||||
# Chat character
|
||||
character = body['character'] or shared.settings['character']
|
||||
name1 = body['name1'] or shared.settings['name1']
|
||||
name1, name2, _, greeting, context, _ = load_character_memoized(character, name1, '', instruct=False)
|
||||
name2 = body['name2'] or name2
|
||||
context = body['context'] or context
|
||||
greeting = body['greeting'] or greeting
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
# History
|
||||
user_input, history = convert_history(messages)
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
generate_params.update({
|
||||
'mode': body['mode'],
|
||||
'name1': name1,
|
||||
'name2': name2,
|
||||
'context': context,
|
||||
'greeting': greeting,
|
||||
'name1_instruct': name1_instruct,
|
||||
'name2_instruct': name2_instruct,
|
||||
'context_instruct': context_instruct,
|
||||
'turn_template': turn_template,
|
||||
'chat-instruct_command': body['chat_instruct_command'],
|
||||
'history': history,
|
||||
'stream': stream
|
||||
})
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
if max_tokens in [None, 0]:
|
||||
generate_params['max_new_tokens'] = 200
|
||||
generate_params['auto_max_new_tokens'] = True
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name, # TODO: add Lora info?
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"message": {"role": "assistant", "content": answer}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# generator
|
||||
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
|
||||
# Chat Completions
|
||||
stream_object_type = 'chat.completions.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = True
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
if max_tokens_str in body:
|
||||
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
else:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||
|
||||
# format the prompt from messages
|
||||
prompt, token_count = messages_to_prompt(body, req_params, max_tokens) # updates req_params['stopping_strings']
|
||||
|
||||
# set real max, avoid deeper errors
|
||||
if req_params['max_new_tokens'] + token_count >= req_params['truncation_length']:
|
||||
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
|
||||
requested_model = generate_params.pop('model')
|
||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||
|
||||
def chat_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": stream_object_type,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
|
@ -376,262 +260,262 @@ def stream_chat_completions(body: dict, is_legacy: bool = False):
|
|||
# chunk[resp_list][0]["logprobs"] = None
|
||||
return chunk
|
||||
|
||||
yield chat_streaming_chunk('')
|
||||
if stream:
|
||||
yield chat_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
prompt = generate_chat_prompt(user_input, generate_params)
|
||||
token_count = len(encode(prompt)[0])
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
generator = generate_chat_reply(
|
||||
user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
answer = a['internal'][-1][1]
|
||||
if stream:
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
seen_content = answer
|
||||
|
||||
seen_content = answer
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct token_count, strip leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
yield chunk
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= req_params['max_new_tokens']:
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
|
||||
stop_reason = "length"
|
||||
|
||||
chunk = chat_streaming_chunk('')
|
||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||
chunk['usage'] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
if stream:
|
||||
chunk = chat_streaming_chunk('')
|
||||
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||
chunk['usage'] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
yield chunk
|
||||
else:
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": stop_reason,
|
||||
"message": {"role": "assistant", "content": answer}
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
yield resp
|
||||
|
||||
|
||||
def completions(body: dict, is_legacy: bool = False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
object_type = 'text_completion'
|
||||
def completions_common(body: dict, is_legacy: bool = False, stream=False):
|
||||
object_type = 'text_completion.chunk' if stream else 'text_completion'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
if prompt_str not in body:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
prompt_arg = body[prompt_str]
|
||||
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
||||
prompt_arg = [prompt_arg]
|
||||
|
||||
# common params
|
||||
req_params = marshal_common_params(body)
|
||||
req_params['stream'] = False
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
generate_params = process_parameters(body, is_legacy=is_legacy)
|
||||
max_tokens = generate_params['max_new_tokens']
|
||||
generate_params['stream'] = stream
|
||||
requested_model = generate_params.pop('model')
|
||||
logprob_proc = generate_params.pop('logprob_proc', None)
|
||||
# generate_params['suffix'] = body.get('suffix', generate_params['suffix'])
|
||||
generate_params['echo'] = body.get('echo', generate_params['echo'])
|
||||
|
||||
resp_list_data = []
|
||||
total_completion_token_count = 0
|
||||
total_prompt_token_count = 0
|
||||
if not stream:
|
||||
prompt_arg = body[prompt_str]
|
||||
if isinstance(prompt_arg, str) or (isinstance(prompt_arg, list) and isinstance(prompt_arg[0], int)):
|
||||
prompt_arg = [prompt_arg]
|
||||
|
||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||
if isinstance(prompt[0], int):
|
||||
# token lists
|
||||
if requested_model == shared.model_name:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
resp_list_data = []
|
||||
total_completion_token_count = 0
|
||||
total_prompt_token_count = 0
|
||||
|
||||
for idx, prompt in enumerate(prompt_arg, start=0):
|
||||
if isinstance(prompt[0], int):
|
||||
# token lists
|
||||
if requested_model == shared.model_name:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
total_prompt_token_count += token_count
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
answer = ''
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
respi = {
|
||||
"index": idx,
|
||||
"finish_reason": stop_reason,
|
||||
"text": answer,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}
|
||||
|
||||
resp_list_data.extend([respi])
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: resp_list_data,
|
||||
"usage": {
|
||||
"prompt_tokens": total_prompt_token_count,
|
||||
"completion_tokens": total_completion_token_count,
|
||||
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
yield resp
|
||||
else:
|
||||
prompt = body[prompt_str]
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
total_prompt_token_count += token_count
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}],
|
||||
}
|
||||
|
||||
return chunk
|
||||
|
||||
yield text_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
debug_msg({'prompt': prompt, 'generate_params': generate_params})
|
||||
generator = generate_reply(prompt, generate_params, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct count, we strip the leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
total_completion_token_count += completion_token_count
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
||||
respi = {
|
||||
"index": idx,
|
||||
"finish_reason": stop_reason,
|
||||
"text": answer,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
chunk = text_streaming_chunk('')
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
chunk["usage"] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
resp_list_data.extend([respi])
|
||||
|
||||
resp = {
|
||||
"id": cmpl_id,
|
||||
"object": object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name, # TODO: add Lora info?
|
||||
resp_list: resp_list_data,
|
||||
"usage": {
|
||||
"prompt_tokens": total_prompt_token_count,
|
||||
"completion_tokens": total_completion_token_count,
|
||||
"total_tokens": total_prompt_token_count + total_completion_token_count
|
||||
}
|
||||
}
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# generator
|
||||
def stream_completions(body: dict, is_legacy: bool = False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
# object_type = 'text_completion'
|
||||
stream_object_type = 'text_completion.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
prompt_str = 'context' if is_legacy else 'prompt'
|
||||
if prompt_str not in body:
|
||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||
|
||||
prompt = body[prompt_str]
|
||||
req_params = marshal_common_params(body)
|
||||
requested_model = req_params.pop('requested_model')
|
||||
if isinstance(prompt, list):
|
||||
if prompt and isinstance(prompt[0], int):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(requested_model)
|
||||
prompt = encoder.decode(prompt)
|
||||
except KeyError:
|
||||
prompt = decode(prompt)[0]
|
||||
else:
|
||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||
|
||||
# common params
|
||||
req_params['stream'] = True
|
||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
def text_streaming_chunk(content):
|
||||
# begin streaming
|
||||
chunk = {
|
||||
"id": cmpl_id,
|
||||
"object": stream_object_type,
|
||||
"created": created_time,
|
||||
"model": shared.model_name,
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": None,
|
||||
"text": content,
|
||||
"logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
|
||||
}],
|
||||
}
|
||||
|
||||
return chunk
|
||||
|
||||
yield text_streaming_chunk('')
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
completion_token_count = 0
|
||||
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
len_seen = len(seen_content)
|
||||
new_content = answer[len_seen:]
|
||||
|
||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||
continue
|
||||
|
||||
seen_content = answer
|
||||
|
||||
# strip extra leading space off new generated content
|
||||
if len_seen == 0 and new_content[0] == ' ':
|
||||
new_content = new_content[1:]
|
||||
|
||||
chunk = text_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
# to get the correct count, we strip the leading space if present
|
||||
if answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
||||
completion_token_count = len(encode(answer)[0])
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
generator = chat_completions_common(body, is_legacy, stream=False)
|
||||
return deque(generator, maxlen=1).pop()
|
||||
|
||||
chunk = text_streaming_chunk('')
|
||||
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||
chunk["usage"] = {
|
||||
"prompt_tokens": token_count,
|
||||
"completion_tokens": completion_token_count,
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
|
||||
yield chunk
|
||||
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
for resp in chat_completions_common(body, is_legacy, stream=True):
|
||||
yield resp
|
||||
|
||||
|
||||
def completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
generator = completions_common(body, is_legacy, stream=False)
|
||||
return deque(generator, maxlen=1).pop()
|
||||
|
||||
|
||||
def stream_completions(body: dict, is_legacy: bool = False):
|
||||
for resp in completions_common(body, is_legacy, stream=True):
|
||||
yield resp
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue