Refactor text_generation.py, add support for custom generation functions (#1817)

This commit is contained in:
oobabooga 2023-05-05 18:53:03 -03:00 committed by GitHub
parent 876fbb97c0
commit 8aafb1f796
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 289 additions and 195 deletions

View file

@ -21,6 +21,7 @@ def get_max_prompt_length(state):
max_length = state['truncation_length'] - state['max_new_tokens']
if shared.soft_prompt:
max_length -= shared.soft_prompt_tensor.shape[1]
return max_length
@ -62,6 +63,36 @@ def decode(output_ids, skip_special_tokens=True):
return shared.tokenizer.decode(output_ids, skip_special_tokens)
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
s = re.sub("--- [0-9]*\n *\n---", "---", s)
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
s = s.replace(r'\(', r'$')
s = s.replace(r'\)', r'$')
s = s.replace(r'$$', r'$')
s = re.sub(r'\n', r'\n\n', s)
s = re.sub(r"\n{3,}", "\n\n", s)
return s
def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
if shared.model_type == 'HF_seq2seq':
reply = decode(output_ids, state['skip_special_tokens'])
@ -81,35 +112,6 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
return reply
def generate_softprompt_input_tensors(input_ids):
inputs_embeds = shared.model.transformer.wte(input_ids)
inputs_embeds = torch.cat((shared.soft_prompt_tensor, inputs_embeds), dim=1)
filler_input_ids = torch.zeros((1, inputs_embeds.shape[1]), dtype=input_ids.dtype).to(shared.model.device)
# filler_input_ids += shared.model.config.bos_token_id # setting dummy input_ids to bos tokens
return inputs_embeds, filler_input_ids
# Removes empty replies from gpt4chan outputs
def fix_gpt4chan(s):
for i in range(10):
s = re.sub("--- [0-9]*\n>>[0-9]*\n---", "---", s)
s = re.sub("--- [0-9]*\n *\n---", "---", s)
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
s = s.replace(r'\(', r'$')
s = s.replace(r'\)', r'$')
s = s.replace(r'$$', r'$')
s = re.sub(r'\n', r'\n\n', s)
s = re.sub(r"\n{3,}", "\n\n", s)
return s
def formatted_outputs(reply, model_name):
if not shared.is_chat():
if shared.model_type == 'galactica':
@ -140,51 +142,21 @@ def stop_everything_event():
shared.stop_everything = True
def get_generate_params(state):
generate_params = {}
# Models that are not on transformers
if shared.model_type in ['rwkv', 'llamacpp']:
generate_params['token_count'] = state['max_new_tokens']
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = state[k]
else:
# FlexGen
if shared.args.flexgen:
for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = state[k]
if not shared.args.no_stream:
generate_params['max_new_tokens'] = 8
# transformers
else:
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
if state['ban_eos_token']:
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
if shared.args.no_cache:
generate_params.update({'use_cache': False})
if shared.args.deepspeed:
generate_params.update({'synced_gpus': True})
return generate_params
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.model_name == 'None' or shared.model is None:
logging.error("No model is loaded! Select one in the Model tab.")
yield formatted_outputs(question, shared.model_name)
return
state = apply_extensions('state', state)
generate_func = apply_extensions('custom_generate_reply')
if generate_func is None:
if shared.model_name == 'None' or shared.model is None:
logging.error("No model is loaded! Select one in the Model tab.")
yield formatted_outputs(question, shared.model_name)
return
clear_torch_cache()
seed = set_manual_seed(state['seed'])
shared.stop_everything = False
generate_params = get_generate_params(state)
t0 = time.time()
if shared.model_type in ['rwkv', 'llamacpp']:
generate_func = generate_reply_custom
elif shared.args.flexgen:
generate_func = generate_reply_flexgen
else:
generate_func = generate_reply_HF
# Preparing the input
original_question = question
@ -194,42 +166,31 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.args.verbose:
print(f'\n\n{question}\n--------------------\n')
# If the model is not on transformers, handle it separately and end this
# function call earlier.
if shared.model_type in ['rwkv', 'llamacpp']:
shared.stop_everything = False
clear_torch_cache()
seed = set_manual_seed(state['seed'])
for reply in generate_func(question, original_question, seed, state, eos_token, stopping_strings):
yield formatted_outputs(reply, shared.model_name)
try:
if shared.args.no_stream:
reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name)
else:
if not shared.is_chat():
yield formatted_outputs(question, shared.model_name)
def generate_reply_HF(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
generate_params[k] = state[k]
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
if state['ban_eos_token']:
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
yield formatted_outputs(reply, shared.model_name)
if shared.args.no_cache:
generate_params.update({'use_cache': False})
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
if shared.args.deepspeed:
generate_params.update({'synced_gpus': True})
# Encode the input
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
output = input_ids[0]
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
cuda = not any((shared.args.cpu, shared.args.deepspeed))
# Find the eos tokens
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
@ -259,15 +220,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
break
# Update generate_params with the eos token and the stopping strings
if shared.args.flexgen:
generate_params['stop'] = eos_token_ids[-1]
else:
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
generate_params['eos_token_id'] = eos_token_ids
generate_params['stopping_criteria'] = stopping_criteria_list
t0 = time.time()
try:
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
yield original_question
# Generate the entire reply at once.
if shared.args.no_stream:
if not state['stream']:
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if cuda:
@ -276,12 +238,11 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
yield formatted_outputs(reply, shared.model_name)
yield get_reply_from_output_ids(output, input_ids, original_question, state)
# Stream the reply 1 token at a time.
# This is based on the trick of using 'stopping_criteria' to create an iterator.
elif not shared.args.flexgen:
else:
def generate_with_callback(callback=None, **kwargs):
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
@ -292,45 +253,118 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
yield formatted_outputs(original_question, shared.model_name)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
yield get_reply_from_output_ids(output, input_ids, original_question, state)
if output[-1] in eos_token_ids:
break
yield formatted_outputs(reply, shared.model_name)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(state['max_new_tokens'] // 8 + 1):
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = get_reply_from_output_ids(output, input_ids, original_question, state)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield formatted_outputs(reply, shared.model_name)
input_ids = np.reshape(output, (1, output.shape[0]))
if shared.soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
generate_params.update({'inputs_embeds': inputs_embeds})
generate_params.update({'inputs': filler_input_ids})
else:
generate_params.update({'inputs': input_ids})
yield formatted_outputs(reply, shared.model_name)
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(original_input_ids[0])
new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
seed = set_manual_seed(state['seed'])
generate_params = {'token_count': state['max_new_tokens']}
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = state[k]
t0 = time.time()
try:
if not shared.is_chat():
yield question
if not state['stream']:
reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield reply
else:
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield reply
except Exception:
traceback.print_exc()
finally:
t1 = time.time()
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return
def generate_reply_flexgen(question, original_question, seed, state, eos_token=None, stopping_strings=[]):
generate_params = {}
for k in ['max_new_tokens', 'do_sample', 'temperature']:
generate_params[k] = state[k]
if state['stream']:
generate_params['max_new_tokens'] = 8
# Encode the input
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
output = input_ids[0]
# Find the eos tokens
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
if eos_token is not None:
eos_token_ids.append(int(encode(eos_token)[0][-1]))
# Add the encoded tokens to generate_params
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
original_input_ids = input_ids
generate_params.update({'inputs': input_ids})
if inputs_embeds is not None:
generate_params.update({'inputs_embeds': inputs_embeds})
# Update generate_params with the eos token and the stopping strings
generate_params['stop'] = eos_token_ids[-1]
t0 = time.time()
try:
if not shared.is_chat():
yield question
# Generate the entire reply at once.
if not state['stream']:
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
yield get_reply_from_output_ids(output, input_ids, original_question, state)
# Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
else:
for i in range(state['max_new_tokens'] // 8 + 1):
if shared.stop_everything:
break
clear_torch_cache()
with torch.no_grad():
output = shared.model.generate(**generate_params)[0]
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break
yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
input_ids = np.reshape(output, (1, output.shape[0]))
generate_params.update({'inputs': input_ids})
except Exception:
traceback.print_exc()
finally: