extensions/openai: docs update, model loader, minor fixes (#2557)
This commit is contained in:
parent
2220b78e7a
commit
1e97aaac95
2 changed files with 154 additions and 69 deletions
|
@ -4,11 +4,13 @@ import os
|
|||
import time
|
||||
import requests
|
||||
import yaml
|
||||
import numpy as np
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
from modules.utils import get_available_models
|
||||
|
||||
import numpy as np
|
||||
from modules.models import load_model, unload_model
|
||||
from modules.models_settings import (get_model_settings_from_yamls,
|
||||
update_model_parameters)
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
@ -37,8 +39,8 @@ default_req_params = {
|
|||
'add_bos_token': True,
|
||||
'do_sample': True,
|
||||
'typical_p': 1.0,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'epsilon_cutoff': 0.0, # In units of 1e-4
|
||||
'eta_cutoff': 0.0, # In units of 1e-4
|
||||
'tfs': 1.0,
|
||||
'top_a': 0.0,
|
||||
'min_length': 0,
|
||||
|
@ -142,41 +144,83 @@ class Handler(BaseHTTPRequestHandler):
|
|||
self.wfile.write("OK".encode('utf-8'))
|
||||
|
||||
def do_GET(self):
|
||||
if self.path.startswith('/v1/models'):
|
||||
self.send_response(200)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
# TODO: Lora's?
|
||||
# This API should list capabilities, limits and pricing...
|
||||
current_model_list = [ shared.model_name ] # The real chat/completions model
|
||||
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
||||
current_model_list = [ shared.model_name ] # The real chat/completions model, maybe "None"
|
||||
embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
|
||||
pseudo_model_list = [ # these are expected by so much, so include some here as a dummy
|
||||
'gpt-3.5-turbo', # /v1/chat/completions
|
||||
'text-curie-001', # /v1/completions, 2k context
|
||||
'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
|
||||
]
|
||||
available_model_list = get_available_models()
|
||||
all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
|
||||
|
||||
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
||||
is_legacy = 'engines' in self.path
|
||||
is_list = self.path in ['/v1/engines', '/v1/models']
|
||||
|
||||
response = ''
|
||||
if self.path == '/v1/models':
|
||||
response = json.dumps({
|
||||
resp = ''
|
||||
|
||||
if is_legacy and not is_list: # load model
|
||||
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||
|
||||
resp = {
|
||||
"id": model_name,
|
||||
"object": "engine",
|
||||
"owner": "self",
|
||||
"ready": True,
|
||||
}
|
||||
if model_name not in pseudo_model_list + embeddings_model_list + current_model_list: # Real model only
|
||||
# No args. Maybe it works anyways!
|
||||
# TODO: hack some heuristics into args for better results
|
||||
|
||||
shared.model_name = model_name
|
||||
unload_model()
|
||||
|
||||
model_settings = get_model_settings_from_yamls(shared.model_name)
|
||||
shared.settings.update(model_settings)
|
||||
update_model_parameters(model_settings, initial=True)
|
||||
|
||||
if shared.settings['mode'] != 'instruct':
|
||||
shared.settings['instruction_template'] = None
|
||||
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
if not shared.model: # load failed.
|
||||
shared.model_name = "None"
|
||||
resp['id'] = "None"
|
||||
resp['ready'] = False
|
||||
|
||||
elif is_list:
|
||||
# TODO: Lora's?
|
||||
available_model_list = get_available_models()
|
||||
all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
|
||||
|
||||
models = {}
|
||||
|
||||
if is_legacy:
|
||||
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
|
||||
if not shared.model:
|
||||
models[0]['ready'] = False
|
||||
else:
|
||||
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
||||
|
||||
resp = {
|
||||
"object": "list",
|
||||
"data": models,
|
||||
})
|
||||
}
|
||||
|
||||
else:
|
||||
the_model_name = self.path[len('/v1/models/'):]
|
||||
response = json.dumps({
|
||||
resp = {
|
||||
"id": the_model_name,
|
||||
"object": "model",
|
||||
"owned_by": "user",
|
||||
"permission": []
|
||||
})
|
||||
}
|
||||
|
||||
self.send_response(200)
|
||||
self.send_access_control_headers()
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
response = json.dumps(resp)
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif '/billing/usage' in self.path:
|
||||
|
@ -283,35 +327,41 @@ class Handler(BaseHTTPRequestHandler):
|
|||
}
|
||||
|
||||
# Instruct models can be much better
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
if shared.settings['instruction_template']:
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct['context']
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
role_formats = {
|
||||
'user': user_message_template,
|
||||
'assistant': bot_message_template,
|
||||
'system': system_message_template,
|
||||
'context': system_message_default,
|
||||
'prompt': bot_prompt,
|
||||
}
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct['context']
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
role_formats = {
|
||||
'user': user_message_template,
|
||||
'assistant': bot_message_template,
|
||||
'system': system_message_template,
|
||||
'context': system_message_default,
|
||||
'prompt': bot_prompt,
|
||||
}
|
||||
|
||||
if instruct['user']: # WizardLM and some others have no user prompt.
|
||||
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||
if instruct['user']: # WizardLM and some others have no user prompt.
|
||||
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
if debug:
|
||||
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
except:
|
||||
if debug:
|
||||
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
|
||||
except Exception as e:
|
||||
req_params['custom_stopping_strings'].extend(['\nuser:'])
|
||||
|
||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
else:
|
||||
req_params['custom_stopping_strings'].extend(['\nuser:'])
|
||||
|
||||
if debug:
|
||||
print("Loaded default role format.")
|
||||
print("Warning: Loaded default instruction-following template for model.")
|
||||
|
||||
system_msgs = []
|
||||
chat_msgs = []
|
||||
|
@ -370,7 +420,8 @@ class Handler(BaseHTTPRequestHandler):
|
|||
prompt = body['prompt'] # XXX this can be different types
|
||||
|
||||
if isinstance(prompt, list):
|
||||
prompt = ''.join(prompt) # XXX this is wrong... need to split out to multiple calls?
|
||||
self.openai_error("API Batched generation not yet supported.")
|
||||
return
|
||||
|
||||
token_count = len(encode(prompt)[0])
|
||||
if token_count >= req_params['truncation_length']:
|
||||
|
@ -412,7 +463,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
# generate reply #######################################
|
||||
if debug:
|
||||
print({'prompt': prompt, 'req_params': req_params})
|
||||
generator = generate_reply(prompt, req_params, is_chat=False)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
|
||||
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
|
@ -569,6 +620,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
# Request parameters
|
||||
req_params = default_req_params.copy()
|
||||
req_params['custom_stopping_strings'] = default_req_params['custom_stopping_strings'].copy()
|
||||
|
||||
# Alpaca is verbose so a good default prompt
|
||||
default_template = (
|
||||
|
@ -578,10 +630,9 @@ class Handler(BaseHTTPRequestHandler):
|
|||
)
|
||||
|
||||
instruction_template = default_template
|
||||
req_params['custom_stopping_strings'] = [ '\n###' ]
|
||||
|
||||
|
||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||
if not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
|
||||
if shared.settings['instruction_template'] and not (shared.settings['instruction_template'] in ['Alpaca', 'Alpaca-Input']):
|
||||
try:
|
||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||
|
||||
|
@ -593,9 +644,16 @@ class Handler(BaseHTTPRequestHandler):
|
|||
|
||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||
if instruct['user']:
|
||||
req_params['custom_stopping_strings'] = [ '\n' + instruct['user'], instruct['user'] ]
|
||||
except:
|
||||
pass
|
||||
req_params['custom_stopping_strings'].extend(['\n' + instruct['user'], instruct['user'] ])
|
||||
|
||||
except Exception as e:
|
||||
instruction_template = default_template
|
||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||
|
||||
else:
|
||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||
|
||||
|
||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||
|
||||
|
@ -613,12 +671,28 @@ class Handler(BaseHTTPRequestHandler):
|
|||
if debug:
|
||||
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||
|
||||
generator = generate_reply(edit_task, req_params, is_chat=False)
|
||||
generator = generate_reply(edit_task, req_params, stopping_strings=req_params['custom_stopping_strings'], is_chat=False)
|
||||
|
||||
longest_stop_len = max([len(x) for x in req_params['custom_stopping_strings']] + [0])
|
||||
answer = ''
|
||||
seen_content = ''
|
||||
for a in generator:
|
||||
answer = a
|
||||
|
||||
stop_string_found = False
|
||||
len_seen = len(seen_content)
|
||||
search_start = max(len_seen - longest_stop_len, 0)
|
||||
|
||||
for string in req_params['custom_stopping_strings']:
|
||||
idx = answer.find(string, search_start)
|
||||
if idx != -1:
|
||||
answer = answer[:idx] # clip it.
|
||||
stop_string_found = True
|
||||
|
||||
if stop_string_found:
|
||||
break
|
||||
|
||||
|
||||
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
||||
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
||||
answer = answer[1:]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue