Remove flexgen support
This commit is contained in:
parent 5134d5b1c6
commit 75c2dd38cf
8 changed files with 3 additions and 233 deletions
@@ -53,8 +53,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
         return input_ids
-    elif shared.args.flexgen:
-        return input_ids.numpy()
     elif shared.args.deepspeed:
         return input_ids.to(device=local_rank)
     elif torch.backends.mps.is_available():
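For illustration only (not part of this commit): a small standalone sketch of the per-backend input handling that encode() keeps after this change. The backend names and the route_input_ids helper are invented for the example; the removed FlexGen branch was the only one that converted the tensor to NumPy.

import torch

def route_input_ids(input_ids, backend, device='cpu'):
    # Custom loaders (llama.cpp, RWKV, ExLlama) and CPU mode take the tensor as-is.
    if backend in ('llama.cpp', 'rwkv', 'exllama', 'cpu'):
        return input_ids
    # DeepSpeed wants the tensor moved to the local rank's device.
    if backend == 'deepspeed':
        return input_ids.to(device=device)
    # Default path (simplified for the sketch): use the GPU when one is available.
    return input_ids.cuda() if torch.cuda.is_available() else input_ids

ids = torch.tensor([[1, 2, 3]])
print(route_input_ids(ids, 'cpu'))
print(route_input_ids(ids, 'deepspeed', device='cpu').device)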
@@ -182,8 +180,6 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
         generate_func = generate_reply_custom
-    elif shared.args.flexgen:
-        generate_func = generate_reply_flexgen
     else:
         generate_func = generate_reply_HF
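For illustration only (not part of this commit): a self-contained sketch of the dispatch pattern that survives in _generate_reply, choosing a generation function from the wrapper class name with the Hugging Face path as the default. The stub class and the pick_generate_func helper are invented for the example.

class LlamaCppModel:
    """Stand-in for the real wrapper class of the same name."""

def generate_reply_custom(prompt):
    return f'[custom backend] {prompt}'

def generate_reply_HF(prompt):
    return f'[transformers backend] {prompt}'

def pick_generate_func(model):
    # Mirrors the surviving branches above: custom loaders get their own
    # generation loop, everything else goes through the transformers path.
    if model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
        return generate_reply_custom
    return generate_reply_HF

print(pick_generate_func(LlamaCppModel())('hello'))  # custom backend
print(pick_generate_func(object())('hello'))         # transformers backend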
@@ -339,66 +335,3 @@ def generate_reply_custom(question, original_question, seed, state, stopping_strings
         new_tokens = len(encode(original_question + reply)[0]) - original_tokens
         print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return
-
-
-def generate_reply_flexgen(question, original_question, seed, state, stopping_strings=None, is_chat=False):
-    generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature']:
-        generate_params[k] = state[k]
-
-    if state['stream']:
-        generate_params['max_new_tokens'] = 8
-
-    # Encode the input
-    input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
-    output = input_ids[0]
-
-    # Find the eos tokens
-    eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
-    if not state['ban_eos_token']:
-        generate_params['stop'] = eos_token_ids[-1]
-
-    # Add the encoded tokens to generate_params
-    question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
-    original_input_ids = input_ids
-    generate_params.update({'inputs': input_ids})
-    if inputs_embeds is not None:
-        generate_params.update({'inputs_embeds': inputs_embeds})
-
-    t0 = time.time()
-    try:
-        if not is_chat:
-            yield ''
-
-        # Generate the entire reply at once.
-        if not state['stream']:
-            with torch.no_grad():
-                output = shared.model.generate(**generate_params)[0]
-
-            yield get_reply_from_output_ids(output, input_ids, original_question, state, is_chat=is_chat)
-
-        # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria'
-        else:
-            for i in range(state['max_new_tokens'] // 8 + 1):
-                if shared.stop_everything:
-                    break
-
-                clear_torch_cache()
-                with torch.no_grad():
-                    output = shared.model.generate(**generate_params)[0]
-
-                if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
-                    break
-
-                yield get_reply_from_output_ids(output, original_input_ids, original_question, state)
-                input_ids = np.reshape(output, (1, output.shape[0]))
-                generate_params.update({'inputs': input_ids})
-
-    except Exception:
-        traceback.print_exc()
-    finally:
-        t1 = time.time()
-        original_tokens = len(original_input_ids[0])
-        new_tokens = len(output) - (original_tokens if not shared.is_seq2seq else 0)
-        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
-        return
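For context on what is being deleted: generate_reply_flexgen streamed output "naively" because FlexGen did not support transformers' stopping_criteria, so it generated 8 tokens per round and re-fed the output as the next prompt until an EOS token appeared or the token budget ran out. Below is a minimal standalone sketch of that pattern, assuming a plain Hugging Face model ('gpt2' is just a placeholder checkpoint) rather than FlexGen.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint for the sketch; the removed code used whatever
# model the web UI had loaded through FlexGen.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

def stream_naively(prompt, max_new_tokens=64, chunk_size=8):
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    prompt_len = input_ids.shape[1]
    eos_id = tokenizer.eos_token_id
    for _ in range(max_new_tokens // chunk_size + 1):
        with torch.no_grad():
            output = model.generate(input_ids, max_new_tokens=chunk_size, do_sample=False)
        new_part = output[0, prompt_len:]
        yield tokenizer.decode(new_part, skip_special_tokens=True)
        # Stop once an EOS token shows up in the generated tail; this is the
        # rough equivalent of the np.isin() count comparison in the removed code.
        if eos_id is not None and (new_part == eos_id).any():
            break
        # Re-feed the full output as the next round's input, mirroring
        # input_ids = np.reshape(output, (1, output.shape[0])) above.
        input_ids = output

for partial in stream_naively('Once upon a time'):
    print(partial)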