This commit is contained in:
parent ebfcfa41f2
commit b45baeea41
2 changed files with 143 additions and 67 deletions
@@ -54,7 +54,7 @@ default_req_params = {
     'mirostat_eta': 0.1,
     'ban_eos_token': False,
     'skip_special_tokens': True,
-    'custom_stopping_strings': ['\n###'],
+    'custom_stopping_strings': '',
 }

 # Optional, install the module and download the model to enable
@@ -254,7 +254,7 @@ class Handler(BaseHTTPRequestHandler):
             return

         is_legacy = '/generate' in self.path
-        is_chat = 'chat' in self.path
+        is_chat_request = 'chat' in self.path
         resp_list = 'data' if is_legacy else 'choices'

         # XXX model is ignored for now
@@ -262,23 +262,23 @@ class Handler(BaseHTTPRequestHandler):
         model = shared.model_name
         created_time = int(time.time())

-        cmpl_id = "chatcmpl-%d" % (created_time) if is_chat else "conv-%d" % (created_time)
+        cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time)

         # Request Parameters
         # Try to use openai defaults or map them to something with the same intent
         req_params = default_req_params.copy()
-        req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
+        stopping_strings = []

         if 'stop' in body:
             if isinstance(body['stop'], str):
-                req_params['custom_stopping_strings'].extend([body['stop']])
+                stopping_strings.extend([body['stop']])
             elif isinstance(body['stop'], list):
-                req_params['custom_stopping_strings'].extend(body['stop'])
+                stopping_strings.extend(body['stop'])

         truncation_length = default(shared.settings, 'truncation_length', 2048)
         truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)

-        default_max_tokens = truncation_length if is_chat else 16 # completions default, chat default is 'inf' so we need to cap it.
+        default_max_tokens = truncation_length if is_chat_request else 16 # completions default, chat default is 'inf' so we need to cap it.

         max_tokens_str = 'length' if is_legacy else 'max_tokens'
         max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
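Note: the `default` and `clamp` helpers called above are defined elsewhere in this file and are not part of the diff. A minimal sketch of behavior consistent with these call sites (the bodies are an assumption, not the extension's actual code):

def default(dic, key, fallback):
    # Return dic[key] when present and not None, otherwise the fallback.
    val = dic.get(key, fallback)
    return fallback if val is None else val

def clamp(value, minvalue, maxvalue):
    # Pin value into the inclusive range [minvalue, maxvalue].
    return max(minvalue, min(value, maxvalue))

Read this way, a client-supplied truncation_length is honored only up to the server's configured limit, never above it.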
@@ -295,9 +295,11 @@ class Handler(BaseHTTPRequestHandler):
         req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
         req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])

+        is_streaming = req_params['stream']
+
         self.send_response(200)
         self.send_access_control_headers()
-        if req_params['stream']:
+        if is_streaming:
             self.send_header('Content-Type', 'text/event-stream')
             self.send_header('Cache-Control', 'no-cache')
             # self.send_header('Connection', 'keep-alive')
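Note: with `is_streaming` set, the handler emits OpenAI-style server-sent events: each `chunk` dict built below is serialized to JSON and written as a `data:` line terminated by a blank line. A sketch of that framing (the helper name is illustrative, not from the diff):

import json

def write_sse_event(wfile, chunk):
    # One SSE event per chunk: "data: <json>" followed by a blank line.
    wfile.write(f"data: {json.dumps(chunk)}\n\n".encode('utf-8'))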
@@ -311,7 +313,7 @@ class Handler(BaseHTTPRequestHandler):
         stream_object_type = ''
         object_type = ''

-        if is_chat:
+        if is_chat_request:
             # Chat Completions
             stream_object_type = 'chat.completions.chunk'
             object_type = 'chat.completions'
@@ -347,20 +349,22 @@ class Handler(BaseHTTPRequestHandler):
                     'prompt': bot_prompt,
                 }

-                if instruct['user']: # WizardLM and some others have no user prompt.
-                    req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
+                if 'Alpaca' in shared.settings['instruction_template']:
+                    stopping_strings.extend(['\n###'])
+                elif instruct['user']: # WizardLM and some others have no user prompt.
+                    stopping_strings.extend(['\n' + instruct['user'], instruct['user']])

                 if debug:
                     print(f"Loaded instruction role format: {shared.settings['instruction_template']}")

             except Exception as e:
-                req_params['custom_stopping_strings'].extend(['\nuser:'])
+                stopping_strings.extend(['\nuser:'])

                 print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
                 print("Warning: Loaded default instruction-following template for model.")

         else:
-            req_params['custom_stopping_strings'].extend(['\nuser:'])
+            stopping_strings.extend(['\nuser:'])
             print("Warning: Loaded default instruction-following template for model.")

         system_msgs = []
@@ -391,7 +395,7 @@ class Handler(BaseHTTPRequestHandler):
                 system_msg = system_msg + '\n'

             system_token_count = len(encode(system_msg)[0])
-            remaining_tokens = req_params['truncation_length'] - system_token_count
+            remaining_tokens = truncation_length - system_token_count
             chat_msg = ''

             while chat_msgs:
@@ -424,20 +428,19 @@ class Handler(BaseHTTPRequestHandler):
                 return

         token_count = len(encode(prompt)[0])
-        if token_count >= req_params['truncation_length']:
+        if token_count >= truncation_length:
             new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
             prompt = prompt[-new_len:]
             new_token_count = len(encode(prompt)[0])
             print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
             token_count = new_token_count

-        if req_params['truncation_length'] - token_count < req_params['max_new_tokens']:
-            print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {req_params['truncation_length'] - token_count}")
-            req_params['max_new_tokens'] = req_params['truncation_length'] - token_count
+        if truncation_length - token_count < req_params['max_new_tokens']:
+            print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}")
+            req_params['max_new_tokens'] = truncation_length - token_count
             print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")

-        if req_params['stream']:
-            shared.args.chat = True
+        if is_streaming:
             # begin streaming
             chunk = {
                 "id": cmpl_id,
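Note: the truncation above is character-proportional: it assumes characters map to tokens at a roughly uniform rate and keeps the tail of the prompt. A worked example with illustrative numbers:

prompt_chars = 12000      # len(prompt)
token_count = 3000        # tokens in the untruncated prompt
truncation_length = 2048  # context limit

new_len = int(prompt_chars * truncation_length / token_count)  # 8192
# prompt[-8192:] keeps the most recent 8192 characters, which should
# re-encode to about 2048 tokens; the code then re-encodes to get the
# real count rather than trusting the estimate.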
@@ -463,11 +466,11 @@ class Handler(BaseHTTPRequestHandler):
         # generate reply #######################################
         if debug:
             print({'prompt': prompt, 'req_params': req_params})
-        generator = generate_reply(prompt, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
+        generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

         answer = ''
         seen_content = ''
-        longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
+        longest_stop_len = max([len(x) for x in stopping_strings] + [0])

         for a in generator:
             answer = a
@@ -476,7 +479,7 @@ class Handler(BaseHTTPRequestHandler):
             len_seen = len(seen_content)
             search_start = max(len_seen - longest_stop_len, 0)

-            for string in req_params['custom_stopping_strings']:
+            for string in stopping_strings:
                 idx = answer.find(string, search_start)
                 if idx != -1:
                     answer = answer[:idx] # clip it.
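Note: `search_start` backs up by `longest_stop_len` so a stop string straddling the boundary between already-seen and newly generated text is still caught, without rescanning the whole answer on every iteration. Illustrative values:

seen_content = 'Hello us'        # text already processed
answer = 'Hello user: hi there'  # latest cumulative output
stopping_strings = ['user:']
longest_stop_len = 5

search_start = max(len(seen_content) - longest_stop_len, 0)  # 3
idx = answer.find('user:', search_start)                     # 6
answer = answer[:idx]                                        # 'Hello '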
@@ -489,7 +492,7 @@ class Handler(BaseHTTPRequestHandler):
             # is completed, buffer and generate more, don't send it
             buffer_and_continue = False

-            for string in req_params['custom_stopping_strings']:
+            for string in stopping_strings:
                 for j in range(len(string) - 1, 0, -1):
                     if answer[-j:] == string[:j]:
                         buffer_and_continue = True
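Note: this nested loop asks whether the tail of `answer` could be the beginning of a stop string; if so, the text is buffered instead of streamed, since the next tokens may complete the stop sequence. The same check, condensed, with illustrative values:

stopping_strings = ['\nuser:']
answer = 'Sure, here you go.\nuse'

buffer_and_continue = False
for string in stopping_strings:
    # longest candidate prefix first: j runs from len(string)-1 down to 1
    for j in range(len(string) - 1, 0, -1):
        if answer[-j:] == string[:j]:
            # the tail '\nuse' matches the first 4 chars of '\nuser:', so
            # hold this chunk back until generation disambiguates it
            buffer_and_continue = True
            break
    if buffer_and_continue:
        break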
@@ -501,7 +504,7 @@ class Handler(BaseHTTPRequestHandler):
             if buffer_and_continue:
                 continue

-            if req_params['stream']:
+            if is_streaming:
                 # Streaming
                 new_content = answer[len_seen:]

@@ -534,7 +537,7 @@ class Handler(BaseHTTPRequestHandler):
                 self.wfile.write(response.encode('utf-8'))
                 completion_token_count += len(encode(new_content)[0])

-            if req_params['stream']:
+            if is_streaming:
                 chunk = {
                     "id": cmpl_id,
                     "object": stream_object_type,
@@ -575,7 +578,7 @@ class Handler(BaseHTTPRequestHandler):

         completion_token_count = len(encode(answer)[0])
         stop_reason = "stop"
-        if token_count + completion_token_count >= req_params['truncation_length']:
+        if token_count + completion_token_count >= truncation_length:
             stop_reason = "length"

         resp = {
@@ -594,7 +597,7 @@ class Handler(BaseHTTPRequestHandler):
             }
         }

-        if is_chat:
+        if is_chat_request:
             resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
         else:
             resp[resp_list][0]["text"] = answer
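Note: the two response shapes selected here follow the OpenAI API: chat completions carry the text in choices[0].message, plain completions in choices[0].text. Abbreviated sketches, not verbatim dumps of resp (a legacy /generate request uses "data" in place of "choices"):

# is_chat_request:
{"id": "chatcmpl-1700000000", "object": "chat.completions",
 "choices": [{"message": {"role": "assistant", "content": "..."}}]}

# otherwise:
{"id": "conv-1700000000",
 "choices": [{"text": "..."}]}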
@@ -620,7 +623,7 @@ class Handler(BaseHTTPRequestHandler):

         # Request parameters
         req_params = default_req_params.copy()
-        req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
+        stopping_strings = []

         # Alpaca is verbose so a good default prompt
         default_template = (
@@ -632,26 +635,29 @@ class Handler(BaseHTTPRequestHandler):
         instruction_template = default_template

         # Use the special instruction/input/response template for anything trained like Alpaca
-        if shared.settings['instruction_template'] and not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
-            try:
-                instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
+        if shared.settings['instruction_template']:
+            if 'Alpaca' in shared.settings['instruction_template']:
+                stopping_strings.extend(['\n###'])
+            else:
+                try:
+                    instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))

-                template = instruct['turn_template']
-                template = template\
-                    .replace('<|user|>', instruct.get('user', ''))\
-                    .replace('<|bot|>', instruct.get('bot', ''))\
-                    .replace('<|user-message|>', '{instruction}\n{input}')
+                    template = instruct['turn_template']
+                    template = template\
+                        .replace('<|user|>', instruct.get('user', ''))\
+                        .replace('<|bot|>', instruct.get('bot', ''))\
+                        .replace('<|user-message|>', '{instruction}\n{input}')

-                instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
-                if instruct['user']:
-                    req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user'] ])
+                    instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
+                    if instruct['user']:
+                        stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])

-            except Exception as e:
-                instruction_template = default_template
-                print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
-                print("Warning: Loaded default instruction-following template (Alpaca) for model.")
+                except Exception as e:
+                    instruction_template = default_template
+                    print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
+                    print("Warning: Loaded default instruction-following template (Alpaca) for model.")
+        else:
+            stopping_strings.extend(['\n###'])
+            print("Warning: Loaded default instruction-following template (Alpaca) for model.")

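Note: to make the turn_template rewriting concrete: a template file supplies user, bot, context, and a turn_template with placeholders, and the code splices the instruction text in where the user message goes, cutting the template off just before the bot's reply. A sketch with made-up YAML values (not from any shipped template file):

instruct = {
    'user': 'USER:',
    'bot': 'ASSISTANT:',
    'context': 'A chat between a user and an assistant.\n',
    'turn_template': '<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n',
}

template = instruct['turn_template']\
    .replace('<|user|>', instruct.get('user', ''))\
    .replace('<|bot|>', instruct.get('bot', ''))\
    .replace('<|user-message|>', '{instruction}\n{input}')

instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
# -> 'A chat between a user and an assistant.\nUSER: {instruction}\n{input}\nASSISTANT:'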
@@ -671,9 +677,9 @@ class Handler(BaseHTTPRequestHandler):
         if debug:
             print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})

-        generator = generate_reply(edit_task, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
+        generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)

-        longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
+        longest_stop_len = max([len(x) for x in stopping_strings] + [0])
         answer = ''
         seen_content = ''
         for a in generator:
@@ -683,7 +689,7 @@ class Handler(BaseHTTPRequestHandler):
             len_seen = len(seen_content)
             search_start = max(len_seen - longest_stop_len, 0)

-            for string in req_params['custom_stopping_strings']:
+            for string in stopping_strings:
                 idx = answer.find(string, search_start)
                 if idx != -1:
                     answer = answer[:idx] # clip it.