Add ExLlama support (#2444)

This commit is contained in:
oobabooga 2023-06-16 20:35:38 -03:00 committed by GitHub
parent dea43685b0
commit 9f40032d32
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 156 additions and 47 deletions

View file

@ -51,7 +51,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel'] or shared.args.cpu:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
return input_ids
elif shared.args.flexgen:
return input_ids.numpy()
@ -157,7 +157,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
yield ''
return
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
generate_func = generate_reply_custom
elif shared.args.flexgen:
generate_func = generate_reply_flexgen
@ -283,13 +283,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
seed = set_manual_seed(state['seed'])
generate_params = {'token_count': state['max_new_tokens']}
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
generate_params[k] = state[k]
if shared.model.__class__.__name__ in ['LlamaCppModel']:
for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
generate_params[k] = state[k]
t0 = time.time()
reply = ''
@ -298,13 +291,13 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
yield ''
if not state['stream']:
reply = shared.model.generate(context=question, **generate_params)
reply = shared.model.generate(question, state)
if not is_chat:
reply = apply_extensions('output', reply)
yield reply
else:
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
for reply in shared.model.generate_with_streaming(question, state):
if not is_chat:
reply = apply_extensions('output', reply)