Automatically set wbits/groupsize/instruct based on model name (#1167)
This commit is contained in:
parent
9d66957207
commit
8e31f2bad4
7 changed files with 377 additions and 286 deletions
|
@ -79,7 +79,7 @@ def get_stopping_strings(state):
|
|||
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||
stopping_strings += state['custom_stopping_strings']
|
||||
stopping_strings += eval(f"[{state['custom_stopping_strings']}]")
|
||||
return stopping_strings
|
||||
|
||||
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
model = None
|
||||
tokenizer = None
|
||||
|
@ -42,6 +45,7 @@ settings = {
|
|||
'truncation_length_min': 0,
|
||||
'truncation_length_max': 4096,
|
||||
'mode': 'cai-chat',
|
||||
'instruction_template': 'None',
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_prompt_size_min': 0,
|
||||
'chat_prompt_size_max': 2048,
|
||||
|
@ -159,3 +163,21 @@ if args.cai_chat:
|
|||
|
||||
def is_chat():
    """Report whether the web UI was launched in chat mode (the --chat flag)."""
    return getattr(args, 'chat')
|
||||
|
||||
|
||||
# Loading model-specific settings (default).
# NOTE: the original used `with Path(...) as p:` — Path objects are not real
# context managers (that support was deprecated and removed in Python 3.13),
# and `open(p, 'r').read()` leaked the file handle. Plain assignment plus
# Path.read_text() fixes both without changing behavior.
p = Path(f'{args.model_dir}/config.yaml')
if p.exists():
    model_config = yaml.safe_load(p.read_text())
else:
    model_config = {}

# Applying user-defined model settings on top of the defaults. Merge key by
# key so a user entry overrides individual fields of the default entry
# instead of replacing the whole per-model dict.
p = Path(f'{args.model_dir}/config-user.yaml')
if p.exists():
    user_config = yaml.safe_load(p.read_text())
    for k in user_config:
        if k in model_config:
            model_config[k].update(user_config[k])
        else:
            model_config[k] = user_config[k]
|
||||
|
|
|
@ -192,7 +192,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
|
||||
# Handling the stopping strings
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
for st in [stopping_strings, state['custom_stopping_strings']]:
|
||||
for st in [stopping_strings, eval(f"[{state['custom_stopping_strings']}]")]:
|
||||
if type(st) is list and len(st) > 0:
|
||||
sentinel_token_ids = [encode(string, add_special_tokens=False) for string in st]
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
@ -16,10 +17,18 @@ with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
|||
chat_js = f.read()
|
||||
|
||||
|
||||
def list_model_elements():
    """Return the names of the gradio input elements that control model
    loading: the fixed set of loader options followed by one memory slider
    name per visible CUDA device."""
    loader_options = ['cpu_memory', 'auto_devices', 'disk', 'cpu', 'bf16', 'load_in_8bit', 'wbits', 'groupsize', 'model_type', 'pre_layer']
    gpu_sliders = [f'gpu_memory_{i}' for i in range(torch.cuda.device_count())]
    return loader_options + gpu_sliders
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
    """Return the names of all interface input elements.

    The generation parameters come first; when ``chat`` is True the
    chat-specific elements are appended, and the model-loading elements
    from list_model_elements() always come last.
    """
    # The scraped diff left both the pre- and post-commit version of the
    # chat-elements line in place; only the post-commit line (which adds
    # 'instruction_template') is kept, otherwise the chat elements would be
    # appended twice.
    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings']
    if chat:
        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']

    elements += list_model_elements()
    return elements
|
||||
|
||||
|
||||
|
@ -27,10 +36,13 @@ def gather_interface_values(*args):
|
|||
output = {}
|
||||
for i, element in enumerate(shared.input_elements):
|
||||
output[element] = args[i]
|
||||
output['custom_stopping_strings'] = eval(f"[{output['custom_stopping_strings']}]")
|
||||
return output
|
||||
|
||||
|
||||
def apply_interface_values(state):
    """Map a saved state dict back onto the interface inputs, returning the
    values in the same order as list_interface_input_elements()."""
    keys = list_interface_input_elements(chat=shared.is_chat())
    return [state[key] for key in keys]
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue