Make OpenAI API the default API (#4430)

This commit is contained in:
oobabooga 2023-11-06 02:38:29 -03:00 committed by GitHub
parent 84d957ba62
commit ec17a5d2b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 769 additions and 1432 deletions

View file

@ -81,7 +81,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
# Find the maximum prompt size
max_length = get_max_prompt_length(state)
all_substrings = {
'chat': get_turn_substrings(state, instruct=False),
'chat': get_turn_substrings(state, instruct=False) if state['mode'] in ['chat', 'chat-instruct'] else None,
'instruct': get_turn_substrings(state, instruct=True)
}
@ -237,7 +237,10 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False, loading_mess
for j, reply in enumerate(generate_reply(prompt, state, stopping_strings=stopping_strings, is_chat=True)):
# Extract the reply
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = reply
if state['mode'] in ['chat', 'chat-instruct']:
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
visible_reply = html.escape(visible_reply)
if shared.stop_everything:

View file

@ -71,11 +71,12 @@ def load_model(model_name, loader=None):
'AutoAWQ': AutoAWQ_loader,
}
metadata = get_model_metadata(model_name)
if loader is None:
if shared.args.loader is not None:
loader = shared.args.loader
else:
loader = get_model_metadata(model_name)['loader']
loader = metadata['loader']
if loader is None:
logger.error('The path to the model does not exist. Exiting.')
return None, None
@ -95,6 +96,7 @@ def load_model(model_name, loader=None):
if any((shared.args.xformers, shared.args.sdp_attention)):
llama_attn_hijack.hijack_llama_attention()
shared.settings.update({k: v for k, v in metadata.items() if k in shared.settings})
logger.info(f"Loaded the model in {(time.time()-t0):.2f} seconds.")
return model, tokenizer

View file

@ -6,33 +6,32 @@ import yaml
def default_preset():
return {
'do_sample': True,
'temperature': 1,
'temperature_last': False,
'top_p': 1,
'min_p': 0,
'top_k': 0,
'typical_p': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1,
'presence_penalty': 0,
'frequency_penalty': 0,
'repetition_penalty_range': 0,
'typical_p': 1,
'tfs': 1,
'top_a': 0,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'guidance_scale': 1,
'penalty_alpha': 0,
'mirostat_mode': 0,
'mirostat_tau': 5,
'mirostat_eta': 0.1,
'do_sample': True,
'encoder_repetition_penalty': 1,
'no_repeat_ngram_size': 0,
'min_length': 0,
'guidance_scale': 1,
'mirostat_mode': 0,
'mirostat_tau': 5.0,
'mirostat_eta': 0.1,
'penalty_alpha': 0,
'num_beams': 1,
'length_penalty': 1,
'early_stopping': False,
'custom_token_bans': '',
}

View file

@ -39,21 +39,21 @@ settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
'max_new_tokens_max': 4096,
'seed': -1,
'negative_prompt': '',
'seed': -1,
'truncation_length': 2048,
'truncation_length_min': 0,
'truncation_length_max': 32768,
'custom_stopping_strings': '',
'auto_max_new_tokens': False,
'max_tokens_second': 0,
'ban_eos_token': False,
'custom_stopping_strings': '',
'custom_token_bans': '',
'auto_max_new_tokens': False,
'ban_eos_token': False,
'add_bos_token': True,
'skip_special_tokens': True,
'stream': True,
'name1': 'You',
'character': 'Assistant',
'name1': 'You',
'instruction_template': 'Alpaca',
'chat-instruct_command': 'Continue the chat dialogue below. Write a single reply for the character "<|character|>".\n\n<|prompt|>',
'autoload_model': False,
@ -167,8 +167,8 @@ parser.add_argument('--ssl-certfile', type=str, help='The path to the SSL certif
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
parser.add_argument('--public-api-id', type=str, help='Tunnel ID for named Cloudflare Tunnel. Use together with public-api option.', default=None)
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
parser.add_argument('--api-port', type=int, default=5000, help='The listening port for the API.')
parser.add_argument('--api-key', type=str, default='', help='API authentication key.')
# Multimodal
parser.add_argument('--multimodal-pipeline', type=str, default=None, help='The multimodal pipeline to use. Examples: llava-7b, llava-13b.')
@ -178,6 +178,8 @@ parser.add_argument('--notebook', action='store_true', help='DEPRECATED')
parser.add_argument('--chat', action='store_true', help='DEPRECATED')
parser.add_argument('--no-stream', action='store_true', help='DEPRECATED')
parser.add_argument('--mul_mat_q', action='store_true', help='DEPRECATED')
parser.add_argument('--api-blocking-port', type=int, default=5000, help='DEPRECATED')
parser.add_argument('--api-streaming-port', type=int, default=5005, help='DEPRECATED')
args = parser.parse_args()
args_defaults = parser.parse_args([])
@ -233,10 +235,13 @@ def fix_loader_name(name):
return 'AutoAWQ'
def add_extension(name):
def add_extension(name, last=False):
if args.extensions is None:
args.extensions = [name]
elif 'api' not in args.extensions:
elif last:
args.extensions = [x for x in args.extensions if x != name]
args.extensions.append(name)
elif name not in args.extensions:
args.extensions.append(name)
@ -246,14 +251,15 @@ def is_chat():
args.loader = fix_loader_name(args.loader)
# Activate the API extension
if args.api or args.public_api:
add_extension('api')
# Activate the multimodal extension
if args.multimodal_pipeline is not None:
add_extension('multimodal')
# Activate the API extension
if args.api:
# add_extension('openai', last=True)
add_extension('api', last=True)
# Load model-specific settings
with Path(f'{args.model_dir}/config.yaml') as p:
if p.exists():

View file

@ -56,7 +56,10 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
# Find the stopping strings
all_stop_strings = []
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
for st in (stopping_strings, state['custom_stopping_strings']):
if type(st) is str:
st = ast.literal_eval(f"[{st}]")
if type(st) is list and len(st) > 0:
all_stop_strings += st

View file

@ -215,9 +215,6 @@ def load_model_wrapper(selected_model, loader, autoload=False):
if 'instruction_template' in settings:
output += '\n\nIt seems to be an instruction-following model with template "{}". In the chat tab, instruct or chat-instruct modes should be used.'.format(settings['instruction_template'])
# Applying the changes to the global shared settings (in-memory)
shared.settings.update({k: v for k, v in settings.items() if k in shared.settings})
yield output
else:
yield f"Failed to load `{selected_model}`."