extensions/openai: Major docs update, fix #2852 (critical bug), minor improvements (#2849)

This commit is contained in:
matatonic 2023-06-24 21:50:04 -04:00 committed by GitHub
parent ebfcfa41f2
commit b45baeea41
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 143 additions and 67 deletions

View file

@ -54,7 +54,7 @@ default_req_params = {
'mirostat_eta': 0.1,
'ban_eos_token': False,
'skip_special_tokens': True,
'custom_stopping_strings': ['\n###'],
'custom_stopping_strings': '',
}
# Optional, install the module and download the model to enable
@ -254,7 +254,7 @@ class Handler(BaseHTTPRequestHandler):
return
is_legacy = '/generate' in self.path
is_chat = 'chat' in self.path
is_chat_request = 'chat' in self.path
resp_list = 'data' if is_legacy else 'choices'
# XXX model is ignored for now
@ -262,23 +262,23 @@ class Handler(BaseHTTPRequestHandler):
model = shared.model_name
created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time)
# Request Parameters
# Try to use openai defaults or map them to something with the same intent
req_params = default_req_params.copy()
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
stopping_strings = []
if 'stop' in body:
if isinstance(body['stop'], str):
req_params['custom_stopping_strings'].extend([body['stop']])
stopping_strings.extend([body['stop']])
elif isinstance(body['stop'], list):
req_params['custom_stopping_strings'].extend(body['stop'])
stopping_strings.extend(body['stop'])
truncation_length = default(shared.settings, 'truncation_length', 2048)
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it.
default_max_tokens = truncation_length if is_chat_request else 16 # completions default, chat default is 'inf' so we need to cap it.
max_tokens_str = 'length' if is_legacy else 'max_tokens'
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
@ -295,9 +295,11 @@ class Handler(BaseHTTPRequestHandler):
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
is_streaming = req_params['stream']
self.send_response(200)
self.send_access_control_headers()
if req_params['stream']:
if is_streaming:
self.send_header('Content-Type', 'text/event-stream')
self.send_header('Cache-Control', 'no-cache')
# self.send_header('Connection', 'keep-alive')
@ -311,7 +313,7 @@ class Handler(BaseHTTPRequestHandler):
stream_object_type = ''
object_type = ''
if is_chat:
if is_chat_request:
# Chat Completions
stream_object_type = 'chat.completions.chunk'
object_type = 'chat.completions'
@ -347,20 +349,22 @@ class Handler(BaseHTTPRequestHandler):
'prompt': bot_prompt,
}
if instruct['user']: # WizardLM and some others have no user prompt.
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
if 'Alpaca' in shared.settings['instruction_template']:
stopping_strings.extend(['\n###'])
elif instruct['user']: # WizardLM and some others have no user prompt.
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
if debug:
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
except Exception as e:
req_params['custom_stopping_strings'].extend(['\nuser:'])
stopping_strings.extend(['\nuser:'])
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template for model.")
else:
req_params['custom_stopping_strings'].extend(['\nuser:'])
stopping_strings.extend(['\nuser:'])
print("Warning: Loaded default instruction-following template for model.")
system_msgs = []
@ -391,7 +395,7 @@ class Handler(BaseHTTPRequestHandler):
system_msg = system_msg + '\n'
system_token_count = len(encode(system_msg)[0])
remaining_tokens = req_params['truncation_length'] - system_token_count
remaining_tokens = truncation_length - system_token_count
chat_msg = ''
while chat_msgs:
@ -424,20 +428,19 @@ class Handler(BaseHTTPRequestHandler):
return
token_count = len(encode(prompt)[0])
if token_count >= req_params['truncation_length']:
if token_count >= truncation_length:
new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
prompt = prompt[-new_len:]
new_token_count = len(encode(prompt)[0])
print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
token_count = new_token_count
if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
if truncation_length - token_count < req_params['max_new_tokens']:
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}")
req_params['max_new_tokens'] = truncation_length - token_count
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
if req_params['stream']:
shared.args.chat = True
if is_streaming:
# begin streaming
chunk = {
"id": cmpl_id,
@ -463,11 +466,11 @@ class Handler(BaseHTTPRequestHandler):
# generate reply #######################################
if debug:
print({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = ''
seen_content = ''
longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
for a in generator:
answer = a
@ -476,7 +479,7 @@ class Handler(BaseHTTPRequestHandler):
len_seen = len(seen_content)
search_start = max(len_seen - longest_stop_len, 0)
for string in req_params['custom_stopping_strings']:
for string in stopping_strings:
idx = answer.find(string, search_start)
if idx != -1:
answer = answer[:idx] # clip it.
@ -489,7 +492,7 @@ class Handler(BaseHTTPRequestHandler):
# is completed, buffer and generate more, don't send it
buffer_and_continue = False
for string in req_params['custom_stopping_strings']:
for string in stopping_strings:
for j in range(len(string) - 1, 0, -1):
if answer[-j:] == string[:j]:
buffer_and_continue = True
@ -501,7 +504,7 @@ class Handler(BaseHTTPRequestHandler):
if buffer_and_continue:
continue
if req_params['stream']:
if is_streaming:
# Streaming
new_content = answer[len_seen:]
@ -534,7 +537,7 @@ class Handler(BaseHTTPRequestHandler):
self.wfile.write(response.encode('utf-8'))
completion_token_count += len(encode(new_content)[0])
if req_params['stream']:
if is_streaming:
chunk = {
"id": cmpl_id,
"object": stream_object_type,
@ -575,7 +578,7 @@ class Handler(BaseHTTPRequestHandler):
completion_token_count = len(encode(answer)[0])
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length']:
if token_count + completion_token_count >= truncation_length:
stop_reason = "length"
resp = {
@ -594,7 +597,7 @@ class Handler(BaseHTTPRequestHandler):
}
}
if is_chat:
if is_chat_request:
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
else:
resp[resp_list][0]["text"] = answer
@ -620,7 +623,7 @@ class Handler(BaseHTTPRequestHandler):
# Request parameters
req_params = default_req_params.copy()
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
stopping_strings = []
# Alpaca is verbose so a good default prompt
default_template = (
@ -632,26 +635,29 @@ class Handler(BaseHTTPRequestHandler):
instruction_template = default_template
# Use the special instruction/input/response template for anything trained like Alpaca
if shared.settings['instruction_template'] and not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
if shared.settings['instruction_template']:
if 'Alpaca' in shared.settings['instruction_template']:
stopping_strings.extend(['\n###'])
else:
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
template = template\
.replace('<|user|>', instruct.get('user', ''))\
.replace('<|bot|>', instruct.get('bot', ''))\
.replace('<|user-message|>', '{instruction}\n{input}')
template = instruct['turn_template']
template = template\
.replace('<|user|>', instruct.get('user', ''))\
.replace('<|bot|>', instruct.get('bot', ''))\
.replace('<|user-message|>', '{instruction}\n{input}')
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']:
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user'] ])
except Exception as e:
instruction_template = default_template
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']:
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
except Exception as e:
instruction_template = default_template
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
else:
stopping_strings.extend(['\n###'])
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
@ -671,9 +677,9 @@ class Handler(BaseHTTPRequestHandler):
if debug:
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
generator = generate_reply(edit_task, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
answer = ''
seen_content = ''
for a in generator:
@ -683,7 +689,7 @@ class Handler(BaseHTTPRequestHandler):
len_seen = len(seen_content)
search_start = max(len_seen - longest_stop_len, 0)
for string in req_params['custom_stopping_strings']:
for string in stopping_strings:
idx = answer.find(string, search_start)
if idx != -1:
answer = answer[:idx] # clip it.