Add ctransformers support (#3313)

---------

Co-authored-by: cal066 <cal066@users.noreply.github.com>
Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
Co-authored-by: randoentity <137087500+randoentity@users.noreply.github.com>
cal066 2023-08-11 17:41:33 +00:00 committed by GitHub
parent 8dbaa20ca8
commit 7a4fcee069
9 changed files with 188 additions and 43 deletions


@@ -58,7 +58,8 @@ def load_model(model_name, loader=None):
         'llamacpp_HF': llamacpp_HF_loader,
         'RWKV': RWKV_loader,
         'ExLlama': ExLlama_loader,
-        'ExLlama_HF': ExLlama_HF_loader
+        'ExLlama_HF': ExLlama_HF_loader,
+        'ctransformers': ctransformers_loader,
     }
 
     p = Path(model_name)
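
The new map entry only registers the loader. Below is a hedged, self-contained sketch of how load_model is assumed to dispatch through load_func_map; the stub loader and the model name used here are illustrative, not code from this commit:

# Hedged sketch, not the webui's actual implementation: how a name -> function
# map like load_func_map dispatches model loading to the loader registered above.
def ctransformers_loader(model_name):
    # Stub standing in for the real loader added by this commit.
    return f'<ctransformers model for {model_name}>', '<tokenizer>'

load_func_map = {
    'ctransformers': ctransformers_loader,
}

def load_model(model_name, loader=None):
    load_func = load_func_map[loader]  # select the loader function by name
    return load_func(model_name)       # every loader returns (model, tokenizer)

model, tokenizer = load_model('llama-2-7b.ggmlv3.q4_0.bin', loader='ctransformers')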
@@ -242,7 +243,7 @@ def llamacpp_loader(model_name):
     else:
         model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*ggml*.bin'))[0]
 
-    logger.info(f"llama.cpp weights detected: {model_file}\n")
+    logger.info(f"llama.cpp weights detected: {model_file}")
     model, tokenizer = LlamaCppModel.from_pretrained(model_file)
     return model, tokenizer
@@ -268,6 +269,24 @@ def llamacpp_HF_loader(model_name):
     return model, tokenizer
 
 
+def ctransformers_loader(model_name):
+    from modules.ctransformers_model import CtransformersModel
+
+    path = Path(f'{shared.args.model_dir}/{model_name}')
+    ctrans = CtransformersModel()
+    if ctrans.model_type_is_auto():
+        model_file = path
+    else:
+        if path.is_file():
+            model_file = path
+        else:
+            model_file = list(Path(f'{shared.args.model_dir}/{model_name}').glob('*.bin'))[0]
+
+    logger.info(f'ctransformers weights detected: {model_file}')
+    model, tokenizer = ctrans.from_pretrained(model_file)
+    return model, tokenizer
+
+
 def GPTQ_loader(model_name):
 
     # Monkey patch
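
The CtransformersModel class added in this commit wraps the ctransformers library. As a reference, a minimal, hedged sketch of loading a GGML file with that library directly; the file path, model_type, and generation settings below are illustrative, not values taken from this commit:

# Minimal sketch: load a local GGML model with the ctransformers library and
# generate a short completion. The path and settings here are hypothetical.
from ctransformers import AutoModelForCausalLM

llm = AutoModelForCausalLM.from_pretrained(
    'models/llama-2-7b.ggmlv3.q4_0.bin',  # hypothetical local GGML file
    model_type='llama',                    # tells ctransformers how to interpret the weights
)

print(llm('Quantized GGML models are', max_new_tokens=64))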