Refactor the UI
A single dictionary called 'interface_state' is now passed as input to all functions. The values are updated only when necessary. The goal is to make it easier to add new elements to the UI.
This commit is contained in:
parent
64f5c90ee7
commit
0f212093a3
3 changed files with 136 additions and 100 deletions
|
@ -69,6 +69,7 @@ def generate_softprompt_input_tensors(input_ids):
|
|||
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
|
||||
return inputs_embeds, filler_input_ids
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
|
||||
def fix_gpt4chan(s):
|
||||
for i in range(10):
|
||||
|
@ -77,6 +78,7 @@ def fix_gpt4chan(s):
|
|||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||
return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
|
@ -117,9 +119,9 @@ def stop_everything_event():
|
|||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]):
|
||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(generate_state['seed'])
|
||||
seed = set_manual_seed(state['seed'])
|
||||
shared.stop_everything = False
|
||||
generate_params = {}
|
||||
t0 = time.time()
|
||||
|
@ -134,8 +136,8 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
# separately and terminate the function call earlier
|
||||
if any((shared.is_RWKV, shared.is_llamacpp)):
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params['token_count'] = generate_state['max_new_tokens']
|
||||
generate_params[k] = state[k]
|
||||
generate_params['token_count'] = state['max_new_tokens']
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
|
@ -164,7 +166,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
input_ids = encode(question, generate_state['max_new_tokens'], add_bos_token=generate_state['add_bos_token'])
|
||||
input_ids = encode(question, state['max_new_tokens'], add_bos_token=state['add_bos_token'])
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
|
@ -179,13 +181,13 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
|
||||
if not shared.args.flexgen:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params[k] = state[k]
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
else:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = generate_state[k]
|
||||
generate_params['stop'] = generate_state['eos_token_ids'][-1]
|
||||
generate_params[k] = state[k]
|
||||
generate_params['stop'] = state['eos_token_ids'][-1]
|
||||
if not shared.args.no_stream:
|
||||
generate_params['max_new_tokens'] = 8
|
||||
|
||||
|
@ -248,7 +250,7 @@ def generate_reply(question, generate_state, eos_token=None, stopping_strings=[]
|
|||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(generate_state['max_new_tokens'] // 8 + 1):
|
||||
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue