Refactor everything (#3481)
This commit is contained in:
parent
d4b851bdc8
commit
65aa11890f
19 changed files with 1306 additions and 1178 deletions
|
@ -31,8 +31,62 @@ def generate_reply(*args, **kwargs):
|
|||
shared.generation_lock.release()
|
||||
|
||||
|
||||
def get_max_prompt_length(state):
|
||||
return state['truncation_length'] - state['max_new_tokens']
|
||||
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
||||
|
||||
# Find the appropriate generation function
|
||||
generate_func = apply_extensions('custom_generate_reply')
|
||||
if generate_func is None:
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logger.error("No model is loaded! Select one in the Model tab.")
|
||||
yield ''
|
||||
return
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
||||
generate_func = generate_reply_custom
|
||||
else:
|
||||
generate_func = generate_reply_HF
|
||||
|
||||
# Prepare the input
|
||||
original_question = question
|
||||
if not is_chat:
|
||||
state = apply_extensions('state', state)
|
||||
question = apply_extensions('input', question, state)
|
||||
|
||||
# Find the stopping strings
|
||||
all_stop_strings = []
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
all_stop_strings += st
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
last_update = -1
|
||||
reply = ''
|
||||
is_stream = state['stream']
|
||||
if len(all_stop_strings) > 0 and not state['stream']:
|
||||
state = copy.deepcopy(state)
|
||||
state['stream'] = True
|
||||
|
||||
# Generate
|
||||
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
|
||||
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||
if is_stream:
|
||||
cur_time = time.time()
|
||||
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
||||
last_update = cur_time
|
||||
yield reply
|
||||
|
||||
if stop_found:
|
||||
break
|
||||
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply, state)
|
||||
|
||||
yield reply
|
||||
|
||||
|
||||
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
|
||||
|
@ -61,6 +115,10 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
return input_ids.cuda()
|
||||
|
||||
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
||||
def get_encoded_length(prompt):
|
||||
length_after_extensions = apply_extensions('tokenized_length', prompt)
|
||||
if length_after_extensions is not None:
|
||||
|
@ -69,12 +127,36 @@ def get_encoded_length(prompt):
|
|||
return len(encode(prompt)[0])
|
||||
|
||||
|
||||
def decode(output_ids, skip_special_tokens=True):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
def get_max_prompt_length(state):
|
||||
return state['truncation_length'] - state['max_new_tokens']
|
||||
|
||||
|
||||
def generate_reply_wrapper(question, state, stopping_strings=None):
|
||||
"""
|
||||
Returns formatted outputs for the UI
|
||||
"""
|
||||
reply = question if not shared.is_seq2seq else ''
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
for reply in generate_reply(question, state, stopping_strings, is_chat=False):
|
||||
if not shared.is_seq2seq:
|
||||
reply = question + reply
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, generate_4chan_html(reply)
|
||||
else:
|
||||
return reply, generate_basic_html(reply)
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
"""
|
||||
Removes empty replies from gpt4chan outputs
|
||||
"""
|
||||
for i in range(10):
|
||||
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
|
||||
s = re.sub("--- [0-9]*\n *\n---", "---", s)
|
||||
|
@ -83,8 +165,10 @@ def fix_gpt4chan(s):
|
|||
return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
"""
|
||||
Fix the LaTeX equations in GALACTICA
|
||||
"""
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
s = s.replace(r'\(', r'$')
|
||||
|
@ -109,14 +193,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i
|
|||
return reply
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if any(s in model_name for s in ['gpt-4chan', 'gpt4chan']):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, generate_4chan_html(reply)
|
||||
else:
|
||||
return reply, generate_basic_html(reply)
|
||||
|
||||
|
||||
def set_manual_seed(seed):
|
||||
seed = int(seed)
|
||||
if seed == -1:
|
||||
|
@ -133,17 +209,6 @@ def stop_everything_event():
|
|||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply_wrapper(question, state, stopping_strings=None):
|
||||
reply = question if not shared.is_seq2seq else ''
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
for reply in generate_reply(question, state, stopping_strings, is_chat=False):
|
||||
if not shared.is_seq2seq:
|
||||
reply = question + reply
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
|
||||
def apply_stopping_strings(reply, all_stop_strings):
|
||||
stop_found = False
|
||||
for string in all_stop_strings:
|
||||
|
@ -169,61 +234,6 @@ def apply_stopping_strings(reply, all_stop_strings):
|
|||
return reply, stop_found
|
||||
|
||||
|
||||
def _generate_reply(question, state, stopping_strings=None, is_chat=False):
|
||||
generate_func = apply_extensions('custom_generate_reply')
|
||||
if generate_func is None:
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logger.error("No model is loaded! Select one in the Model tab.")
|
||||
yield ''
|
||||
return
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
||||
generate_func = generate_reply_custom
|
||||
else:
|
||||
generate_func = generate_reply_HF
|
||||
|
||||
# Preparing the input
|
||||
original_question = question
|
||||
if not is_chat:
|
||||
state = apply_extensions('state', state)
|
||||
question = apply_extensions('input', question, state)
|
||||
|
||||
# Finding the stopping strings
|
||||
all_stop_strings = []
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
all_stop_strings += st
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
last_update = -1
|
||||
reply = ''
|
||||
is_stream = state['stream']
|
||||
if len(all_stop_strings) > 0 and not state['stream']:
|
||||
state = copy.deepcopy(state)
|
||||
state['stream'] = True
|
||||
|
||||
for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
|
||||
reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
|
||||
if is_stream:
|
||||
cur_time = time.time()
|
||||
if cur_time - last_update > 0.041666666666666664: # Limit streaming to 24 fps
|
||||
last_update = cur_time
|
||||
yield reply
|
||||
|
||||
if stop_found:
|
||||
break
|
||||
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply, state)
|
||||
|
||||
yield reply
|
||||
|
||||
|
||||
def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
|
||||
|
@ -316,6 +326,9 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
|
|||
|
||||
|
||||
def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
|
||||
"""
|
||||
For models that do not use the transformers library for sampling
|
||||
"""
|
||||
seed = set_manual_seed(state['seed'])
|
||||
|
||||
t0 = time.time()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue