Refactor text_generation.py, add support for custom generation functions (#1817)
This commit is contained in:
parent
876fbb97c0
commit
8aafb1f796
12 changed files with 289 additions and 195 deletions
|
@ -21,6 +21,7 @@ def get_max_prompt_length(state):
|
|||
max_length = state['truncation_length'] - state['max_new_tokens']
|
||||
if shared.soft_prompt:
|
||||
max_length -= shared.soft_prompt_tensor.shape[1]
|
||||
|
||||
return max_length
|
||||
|
||||
|
||||
|
@ -62,6 +63,36 @@ def decode(output_ids, skip_special_tokens=True):
|
|||
return shared.tokenizer.decode(output_ids, skip_special_tokens)
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
    """Build the tensors passed to the model when a soft prompt is active.

    The token ids are embedded with the model's ``transformer.wte`` layer and
    ``shared.soft_prompt_tensor`` is concatenated in front of the result along
    dim 1.

    Returns a tuple ``(inputs_embeds, filler_input_ids)`` where
    ``filler_input_ids`` is an all-zero tensor of shape
    ``(1, inputs_embeds.shape[1])`` with the dtype of ``input_ids``, moved to
    ``shared.model.device``.
    """
    token_embeds = shared.model.transformer.wte(input_ids)
    combined_embeds = torch.cat((shared.soft_prompt_tensor, token_embeds), dim=1)

    # Placeholder ids that only need to match the combined sequence length.
    filler_shape = (1, combined_embeds.shape[1])
    placeholder_ids = torch.zeros(filler_shape, dtype=input_ids.dtype)
    placeholder_ids = placeholder_ids.to(shared.model.device)
    # Left disabled, as in the original:
    # placeholder_ids += shared.model.config.bos_token_id  # setting dummy input_ids to bos tokens
    return combined_embeds, placeholder_ids
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    """Remove empty posts from gpt4chan-style output.

    A post starts with a ``--- <digits>`` header line. A post whose body is
    only a ``>>digits`` reference, only spaces, or only blank lines is
    collapsed so that just the following ``---`` separator remains.

    ``re.sub`` does not re-scan text it has just produced, so consecutive
    empty posts need several passes. At most 10 passes are made, and the loop
    stops as soon as a full pass leaves the string unchanged.

    Args:
        s: The generated text.

    Returns:
        The text with empty posts removed.
    """
    empty_post_patterns = (
        r"--- [0-9]*\n>>[0-9]*\n---",  # body is only a >>reference
        r"--- [0-9]*\n *\n---",        # body is only spaces (or nothing)
        r"--- [0-9]*\n\n\n---",        # body is only blank lines
    )
    for _ in range(10):
        before = s
        for pattern in empty_post_patterns:
            s = re.sub(pattern, "---", s)
        if s == before:
            # Fixed point reached: further passes would be no-ops.
            break

    return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
def fix_galactica(s):
    """Normalize LaTeX delimiters and paragraph spacing in galactica output.

    The delimiters ``\\[``, ``\\]``, ``\\(`` and ``\\)`` are each replaced by
    ``$``, after which every ``$$`` is replaced by ``$``. Finally every run of
    one or more newlines becomes exactly two newlines.

    Args:
        s: The generated text.

    Returns:
        The normalized text.
    """
    for delimiter in ('\\[', '\\]', '\\(', '\\)', '$$'):
        s = s.replace(delimiter, '$')

    # Doubling every newline and then squeezing runs of three or more down to
    # two is the same as mapping each newline run directly to two newlines.
    return re.sub(r'\n+', '\n\n', s)
|
||||
|
||||
|
||||
def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
|
||||
if shared.model_type == 'HF_seq2seq':
|
||||
reply = decode(output_ids, state['skip_special_tokens'])
|
||||
|
@ -81,35 +112,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
|
|||
return reply
|
||||
|
||||
|
||||
def generate_softprompt_input_tensors(input_ids):
    """Build the tensors passed to the model when a soft prompt is active.

    The token ids are embedded with the model's ``transformer.wte`` layer and
    ``shared.soft_prompt_tensor`` is concatenated in front of the result along
    dim 1.

    Returns a tuple ``(inputs_embeds, filler_input_ids)`` where
    ``filler_input_ids`` is an all-zero tensor of shape
    ``(1, inputs_embeds.shape[1])`` with the dtype of ``input_ids``, moved to
    ``shared.model.device``.
    """
    token_embeds = shared.model.transformer.wte(input_ids)
    combined_embeds = torch.cat((shared.soft_prompt_tensor, token_embeds), dim=1)

    # Placeholder ids that only need to match the combined sequence length.
    filler_shape = (1, combined_embeds.shape[1])
    placeholder_ids = torch.zeros(filler_shape, dtype=input_ids.dtype)
    placeholder_ids = placeholder_ids.to(shared.model.device)
    # Left disabled, as in the original:
    # placeholder_ids += shared.model.config.bos_token_id  # setting dummy input_ids to bos tokens
    return combined_embeds, placeholder_ids
|
||||
|
||||
|
||||
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
    """Remove empty posts from gpt4chan-style output.

    A post starts with a ``--- <digits>`` header line. A post whose body is
    only a ``>>digits`` reference, only spaces, or only blank lines is
    collapsed so that just the following ``---`` separator remains.

    ``re.sub`` does not re-scan text it has just produced, so consecutive
    empty posts need several passes. At most 10 passes are made, and the loop
    stops as soon as a full pass leaves the string unchanged.

    Args:
        s: The generated text.

    Returns:
        The text with empty posts removed.
    """
    empty_post_patterns = (
        r"--- [0-9]*\n>>[0-9]*\n---",  # body is only a >>reference
        r"--- [0-9]*\n *\n---",        # body is only spaces (or nothing)
        r"--- [0-9]*\n\n\n---",        # body is only blank lines
    )
    for _ in range(10):
        before = s
        for pattern in empty_post_patterns:
            s = re.sub(pattern, "---", s)
        if s == before:
            # Fixed point reached: further passes would be no-ops.
            break
    return s
|
||||
|
||||
|
||||
# Fix the LaTeX equations in galactica
def fix_galactica(s):
    """Normalize LaTeX delimiters and paragraph spacing in galactica output.

    The delimiters ``\\[``, ``\\]``, ``\\(`` and ``\\)`` are each replaced by
    ``$``, after which every ``$$`` is replaced by ``$``. Finally every run of
    one or more newlines becomes exactly two newlines.

    Args:
        s: The generated text.

    Returns:
        The normalized text.
    """
    for delimiter in ('\\[', '\\]', '\\(', '\\)', '$$'):
        s = s.replace(delimiter, '$')

    # Doubling every newline and then squeezing runs of three or more down to
    # two is the same as mapping each newline run directly to two newlines.
    return re.sub(r'\n+', '\n\n', s)
|
||||
|
||||
|
||||
def formatted_outputs(reply, model_name):
|
||||
if not shared.is_chat():
|
||||
if shared.model_type == 'galactica':
|
||||
|
@ -140,51 +142,21 @@ def stop_everything_event():
|
|||
shared.stop_everything = True
|
||||
|
||||
|
||||
def get_generate_params(state):
    """Pick the generation parameters out of ``state`` for the active backend.

    Three cases are distinguished:

    * ``shared.model_type`` is ``'rwkv'`` or ``'llamacpp'``: ``token_count``
      (taken from ``state['max_new_tokens']``) plus the basic sampling keys.
    * ``shared.args.flexgen`` is set: a reduced key set, with
      ``max_new_tokens`` forced to 8 unless ``shared.args.no_stream`` is set.
    * otherwise (transformers): the full key set, plus ``suppress_tokens``,
      ``use_cache`` and ``synced_gpus`` depending on the state and
      command-line flags.

    NOTE(review): the nesting was reconstructed from a view of the file in
    which indentation was lost; the three trailing flag checks are assumed to
    belong to the transformers branch only -- confirm against the original.

    Returns:
        A dict of keyword arguments for the backend's generate call.
    """
    # Models that are not on transformers
    if shared.model_type in ['rwkv', 'llamacpp']:
        params = {'token_count': state['max_new_tokens']}
        for name in ('temperature', 'top_p', 'top_k', 'repetition_penalty'):
            params[name] = state[name]
        return params

    # FlexGen
    if shared.args.flexgen:
        params = {name: state[name] for name in ('max_new_tokens', 'do_sample', 'temperature')}
        if not shared.args.no_stream:
            params['max_new_tokens'] = 8
        return params

    # transformers
    hf_keys = (
        'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p',
        'repetition_penalty', 'encoder_repetition_penalty', 'top_k',
        'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha',
        'length_penalty', 'early_stopping',
    )
    params = {name: state[name] for name in hf_keys}

    if state['ban_eos_token']:
        params['suppress_tokens'] = [shared.tokenizer.eos_token_id]

    if shared.args.no_cache:
        params['use_cache'] = False

    if shared.args.deepspeed:
        params['synced_gpus'] = True

    return params
|
||||
|
||||
|
||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logging.error("No model is loaded! Select one in the Model tab.")
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
return
|
||||
state = apply_extensions('state', state)
|
||||
generate_func = apply_extensions('custom_generate_reply')
|
||||
if generate_func is None:
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
logging.error("No model is loaded! Select one in the Model tab.")
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
return
|
||||
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
shared.stop_everything = False
|
||||
generate_params = get_generate_params(state)
|
||||
t0 = time.time()
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
generate_func = generate_reply_custom
|
||||
elif shared.args.flexgen:
|
||||
generate_func = generate_reply_flexgen
|
||||
else:
|
||||
generate_func = generate_reply_HF
|
||||
|
||||
# Preparing the input
|
||||
original_question = question
|
||||
|
@ -194,42 +166,31 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
# If the model is not on transformers, handle it separately and end this
|
||||
# function call earlier.
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
shared.stop_everything = False
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings):
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
if state['ban_eos_token']:
|
||||
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
if shared.args.no_cache:
|
||||
generate_params.update({'use_cache': False})
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(encode(original_question)[0])
|
||||
new_tokens = len(encode(output)[0]) - original_tokens
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
if shared.args.deepspeed:
|
||||
generate_params.update({'synced_gpus': True})
|
||||
|
||||
# Encode the input
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed))
|
||||
|
||||
# Find the eos tokens
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
|
@ -259,15 +220,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
break
|
||||
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
if shared.args.flexgen:
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
else:
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
|
||||
yield original_question
|
||||
|
||||
# Generate the entire reply at once.
|
||||
if shared.args.no_stream:
|
||||
if not state['stream']:
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
if cuda:
|
||||
|
@ -276,12 +238,11 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
|
||||
# Stream the reply 1 token at a time.
|
||||
# This is based on the trick of using 'stopping_criteria' to create an iterator.
|
||||
elif not shared.args.flexgen:
|
||||
else:
|
||||
|
||||
def generate_with_callback(callback=None, **kwargs):
|
||||
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||
|
@ -292,45 +253,118 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
|||
def generate_with_streaming(**kwargs):
|
||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||
|
||||
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
generate_params.update({'inputs': input_ids})
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
t1 = time.time()
|
||||
original_tokens = len(original_input_ids[0])
|
||||
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
|
||||
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
    """Generate a reply with a model that exposes its own ``generate`` API.

    This is a generator. Outside chat mode it first yields ``question``, then
    yields either the whole reply once (``state['stream']`` false) or every
    partial reply produced by ``shared.model.generate_with_streaming``.
    Outside chat mode each yielded reply is ``original_question`` followed by
    the reply after the ``'output'`` extensions have been applied.

    Exceptions raised during generation are printed and swallowed; timing and
    token statistics are always printed at the end.

    Args:
        question: The prompt passed to the model as ``context``.
        original_question: The prompt text used as prefix for yielded replies
            and for the token statistics.
        seed: Overwritten by the result of ``set_manual_seed(state['seed'])``.
        state: Dict of generation settings; the keys read here are ``seed``,
            ``max_new_tokens``, ``temperature``, ``top_p``, ``top_k``,
            ``repetition_penalty`` and ``stream``.
        eos_token: Accepted for signature parity; not used in this function.
        stopping_strings: Accepted for signature parity; not used in this
            function (and never mutated, so the shared default is harmless).
    """
    seed = set_manual_seed(state['seed'])
    generate_params = {'token_count': state['max_new_tokens']}
    for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
        generate_params[k] = state[k]

    # Start from the prompt alone so that the statistics in ``finally`` can
    # still be computed when generation raises before producing any reply.
    # Without this, ``output`` would be unbound there and the resulting
    # UnboundLocalError would escape the generator.
    output = original_question

    t0 = time.time()
    try:
        if not shared.is_chat():
            yield question

        if not state['stream']:
            # Generate the entire reply at once.
            reply = shared.model.generate(context=question, **generate_params)
            output = original_question + reply
            if not shared.is_chat():
                reply = original_question + apply_extensions('output', reply)

            yield reply
        else:
            # Stream partial replies as the model produces them.
            for reply in shared.model.generate_with_streaming(context=question, **generate_params):
                output = original_question + reply
                if not shared.is_chat():
                    reply = original_question + apply_extensions('output', reply)

                yield reply

    except Exception:
        traceback.print_exc()
    finally:
        t1 = time.time()
        original_tokens = len(encode(original_question)[0])
        new_tokens = len(encode(output)[0]) - original_tokens
        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
        return
|
||||
|
||||
|
||||
def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
|
||||
generate_params = {}
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if state['stream']:
|
||||
generate_params['max_new_tokens'] = 8
|
||||
|
||||
# Encode the input
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
output = input_ids[0]
|
||||
|
||||
# Find the eos tokens
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
|
||||
# Add the encoded tokens to generate_params
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
|
||||
t0 = time.time()
|
||||
try:
|
||||
if not shared.is_chat():
|
||||
yield question
|
||||
|
||||
# Generate the entire reply at once.
|
||||
if not state['stream']:
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
yield get_reply_from_output_ids(output, input_ids, original_question, state)
|
||||
|
||||
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
|
||||
else:
|
||||
for i in range(state['max_new_tokens'] // 8 + 1):
|
||||
if shared.stop_everything:
|
||||
break
|
||||
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
||||
yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
|
||||
input_ids = np.reshape(output, (1, output.shape[0]))
|
||||
generate_params.update({'inputs': input_ids})
|
||||
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue