extensions/openai: docs update, model loader, minor fixes (#2557)

matatonic 2023-06-17 18:15:24 -04:00 committed by GitHub
parent 2220b78e7a
commit 1e97aaac95
2 changed files with 154 additions and 69 deletions


@@ -4,11 +4,13 @@ import os
import time
import requests
import yaml
import numpy as np
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from modules.utils import get_available_models
import numpy as np
from modules.models import load_model, unload_model
from modules.models_settings import (get_model_settings_from_yamls,
update_model_parameters)
from modules import shared
from modules.text_generation import encode, generate_reply
@@ -37,8 +39,8 @@ default_req_params = {
'add_bos_token': True,
'do_sample': True,
'typical_p': 1.0,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'epsilon_cutoff': 0.0, # In units of 1e-4
'eta_cutoff': 0.0, # In units of 1e-4
'tfs': 1.0,
'top_a': 0.0,
'min_length': 0,
@@ -142,41 +144,83 @@ class Handler(BaseHTTPRequestHandler):
self.wfile.write("OK".encode('utf-8'))
def do_GET(self):
if self.path.startswith('/v1/models'):
self.send_response(200)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
self.end_headers()
# TODO: LoRAs?
# This API should list capabilities, limits and pricing...
current_model_list = [ shared.model_name ] # The real chat/completions model
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
current_model_list = [ shared.model_name ] # The real chat/completions model, maybe "None"
embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
pseudo_model_list = [ # so many clients expect these, so include some here as dummies
'gpt-3.5-turbo', # /v1/chat/completions
'text-curie-001', # /v1/completions, 2k context
'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
]
available_model_list = get_available_models()
all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
is_legacy = 'engines' in self.path
is_list = self.path in ['/v1/engines', '/v1/models']
response = ''
if self.path == '/v1/models':
response = json.dumps({
resp = ''
if is_legacy and not is_list: # load model
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
resp = {
"id": model_name,
"object": "engine",
"owner": "self",
"ready": True,
}
if model_name not in pseudo_model_list + embeddings_model_list + current_model_list: # Real model only
# No args. Maybe it works anyway!
# TODO: hack some heuristics into args for better results
shared.model_name = model_name
unload_model()
model_settings = get_model_settings_from_yamls(shared.model_name)
shared.settings.update(model_settings)
update_model_parameters(model_settings, initial=True)
if shared.settings['mode'] != 'instruct':
shared.settings['instruction_template'] = None
shared.model, shared.tokenizer = load_model(shared.model_name)
if not shared.model: # load failed.
shared.model_name = "None"
resp['id'] = "None"
resp['ready'] = False
elif is_list:
# TODO: LoRAs?
available_model_list = get_available_models()
all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
models = {}
if is_legacy:
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
if not shared.model:
models[0]['ready'] = False
else:
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
resp = {
"object": "list",
"data": models,
})
}
else:
the_model_name = self.path[len('/v1/models/'):]
response = json.dumps({
resp = {
"id": the_model_name,
"object": "model",
"owned_by": "user",
"permission": []
})
}
self.send_response(200)
self.send_access_control_headers()
self.send_header('Content-Type', 'application/json')
self.end_headers()
response = json.dumps(resp)
self.wfile.write(response.encode('utf-8'))
elif '/billing/usage' in self.path:
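
The reworked GET handler above doubles as a model loader: requesting a legacy engine id that is not in the pseudo, embedding, or current model lists unloads the current model and loads the named one, with 'ready' reporting whether the load succeeded. A minimal client sketch, assuming the extension listens at http://127.0.0.1:5001 (host, port, and the model name below are placeholders):

    import requests

    BASE = 'http://127.0.0.1:5001'  # assumption: adjust to wherever the extension listens

    # Either listing works; /v1/engines returns legacy "engine" objects instead.
    for m in requests.get(f'{BASE}/v1/models').json()['data']:
        print(m['id'])

    # GET /v1/engines/<model> asks the server to swap models; 'ready' is False
    # in the response if the load failed.
    resp = requests.get(f'{BASE}/v1/engines/my-local-model').json()
    print(resp['id'], resp['ready'])
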
@@ -283,35 +327,41 @@ class Handler(BaseHTTPRequestHandler):
}
# Instruct models can be much better
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
if shared.settings['instruction_template']:
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
template = instruct['turn_template']
system_message_template = "{message}"
system_message_default = instruct['context']
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
role_formats = {
'user': user_message_template,
'assistant': bot_message_template,
'system': system_message_template,
'context': system_message_default,
'prompt': bot_prompt,
}
template = instruct['turn_template']
system_message_template = "{message}"
system_message_default = instruct['context']
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
role_formats = {
'user': user_message_template,
'assistant': bot_message_template,
'system': system_message_template,
'context': system_message_default,
'prompt': bot_prompt,
}
if instruct['user']: # WizardLM and some others have no user prompt.
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
if instruct['user']: # WizardLM and some others have no user prompt.
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
if debug:
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
except:
if debug:
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
except Exception as e:
req_params['custom_stopping_strings'].extend(['\nuser:'])
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template for model.")
else:
req_params['custom_stopping_strings'].extend(['\nuser:'])
if debug:
print("Loaded default role format.")
print("Warning: Loaded default instruction-following template for model.")
system_msgs = []
chat_msgs = []
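
For reference, the template parsing above splits turn_template at the <|bot|> token (so far present in every shipped instruction template) into per-role formats. The same transformation as a standalone sketch, with an illustrative Vicuna-style template rather than a real one from characters/instruction-following/:

    # Illustrative template; real ones are loaded from the instruction-following YAMLs.
    instruct = {
        'user': 'USER:',
        'bot': 'ASSISTANT:',
        'turn_template': '<|user|> <|user-message|>\n<|bot|> <|bot-message|>\n',
    }
    template = instruct['turn_template']
    bot_start = template.find('<|bot|>')
    user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
    bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
    bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')

    print(user_message_template.format(message='Hi there'))  # USER: Hi there
    print(bot_prompt)                                        # ASSISTANT:
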
@@ -370,7 +420,8 @@ class Handler(BaseHTTPRequestHandler):
prompt = body['prompt'] # XXX this can be different types
if isinstance(prompt, list):
prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls?
self.openai_error("API Batched generation not yet supported.")
return
token_count = len(encode(prompt)[0])
if token_count >= req_params['truncation_length']:
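
Note the behavior change above: a list-valued prompt now returns an API error instead of being silently joined into a single string. Client-side illustration (endpoint base is an assumption, as before):

    import requests

    # A plain string prompt works; a batched (list) prompt is rejected until
    # batched generation is supported.
    r = requests.post('http://127.0.0.1:5001/v1/completions',
                      json={'prompt': ['one prompt', 'another prompt'], 'max_tokens': 8})
    print(r.status_code, r.json())
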
@@ -412,7 +463,7 @@ class Handler(BaseHTTPRequestHandler):
# generate reply #######################################
if debug:
print({'prompt': prompt, 'req_params': req_params})
generator = generate_reply(prompt, req_params, is_chat=False)
generator = generate_reply(prompt, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
answer = ''
seen_content = ''
@@ -569,6 +620,7 @@ class Handler(BaseHTTPRequestHandler):
# Request parameters
req_params = default_req_params.copy()
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
# Alpaca is verbose, so it makes a good default prompt
default_template = (
@@ -578,10 +630,9 @@ class Handler(BaseHTTPRequestHandler):
)
instruction_template = default_template
req_params['custom_stopping_strings'] = [ '\n###' ]
# Use the special instruction/input/response template for anything trained like Alpaca
if not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
if shared.settings['instruction_template'] and not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
try:
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
@@ -593,9 +644,16 @@ class Handler(BaseHTTPRequestHandler):
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']:
req_params['custom_stopping_strings'] = [ '\n' + instruct['user'], instruct['user'] ]
except:
pass
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user'] ])
except Exception as e:
instruction_template = default_template
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
else:
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
edit_task = instruction_template.format(instruction=instruction, input=input)
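
The edits endpoint then formats the selected template with the request's instruction and input. A toy stand-in for illustration (the real default_template is the Alpaca prompt defined just above, elided here by the hunk boundary):

    # Hypothetical stand-in; the extension's actual default is the Alpaca prompt.
    instruction_template = (
        'Instruction: {instruction}\n'
        'Input: {input}\n'
        'Response:'
    )
    edit_task = instruction_template.format(
        instruction='Fix the spelling mistakes.',
        input='What day of the wek is it?',
    )
    print(edit_task)
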
@@ -613,12 +671,28 @@ class Handler(BaseHTTPRequestHandler):
if debug:
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
generator = generate_reply(edit_task, req_params, is_chat=False)
generator = generate_reply(edit_task, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
answer = ''
seen_content = ''
for a in generator:
answer = a
stop_string_found = False
len_seen = len(seen_content)
search_start = max(len_seen - longest_stop_len, 0)
for string in req_params['custom_stopping_strings']:
idx = answer.find(string, search_start)
if idx != -1:
answer = answer[:idx] # clip it.
stop_string_found = True
if stop_string_found:
break
# some replies have an extra leading space to fit the instruction template; just clip it off from the reply.
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
answer = answer[1:]
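
The clipping loop above only re-scans the tail of the streamed answer: search_start backs up by the longest stop string from the previously seen content, so a stop sequence split across two streamed chunks is still caught. The same logic as a standalone sketch:

    def clip_stops(chunks, stopping_strings):
        # Mirrors the streaming clip above: on each new chunk, only the unseen
        # tail (minus the longest stop string) needs to be re-searched.
        longest_stop_len = max([len(s) for s in stopping_strings] + [0])
        seen_content = ''
        answer = ''
        for a in chunks:
            answer = a
            search_start = max(len(seen_content) - longest_stop_len, 0)
            for string in stopping_strings:
                idx = answer.find(string, search_start)
                if idx != -1:
                    return answer[:idx]  # clip at the stop string
            seen_content = answer
        return answer

    print(clip_stops(['Hel', 'Hello\n#', 'Hello\n###\nUser:'], ['\n###']))  # Hello
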