Add ctransformers support (#3313)

---------

Co-authored-by: cal066 <cal066@users.noreply.github.com>
Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Co-authored-by: randoentity <137087500+randoentity@users.noreply.github.com>
This commit is contained in:
cal066 2023-08-11 17:41:33 +00:00 committed by GitHub
parent 8dbaa20ca8
commit 7a4fcee069
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 188 additions and 43 deletions

View file

@ -41,7 +41,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
yield ''
return
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel']:
generate_func = generate_reply_custom
else:
generate_func = generate_reply_HF
@ -90,7 +90,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_length=None):
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'CtransformersModel']:
input_ids = shared.tokenizer.encode(str(prompt))
input_ids = np.array(input_ids).reshape(1, len(input_ids))
else:
@ -104,7 +104,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
if truncation_length is not None:
input_ids = input_ids[:, -truncation_length:]
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel', 'CtransformersModel'] or shared.args.cpu:
return input_ids
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)