From 8aeae3b3f40f293c608283ddb34d33c2b5a4113c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 26 Aug 2023 22:15:06 -0700 Subject: [PATCH] Fix llamacpp_HF loading --- modules/llamacpp_hf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modules/llamacpp_hf.py b/modules/llamacpp_hf.py index ce8c6d1..918ce7f 100644 --- a/modules/llamacpp_hf.py +++ b/modules/llamacpp_hf.py @@ -45,7 +45,7 @@ def llama_cpp_lib(model_file: Union[str, Path] = None): class LlamacppHF(PreTrainedModel): - def __init__(self, model): + def __init__(self, model, path): super().__init__(PretrainedConfig()) self.model = model self.generation_config = GenerationConfig() @@ -64,7 +64,7 @@ class LlamacppHF(PreTrainedModel): 'n_tokens': self.model.n_tokens, 'input_ids': self.model.input_ids.copy(), 'scores': self.model.scores.copy(), - 'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params) + 'ctx': llama_cpp_lib(path).llama_new_context_with_model(model.model, model.params) } def _validate_model_class(self): @@ -217,4 +217,4 @@ class LlamacppHF(PreTrainedModel): Llama = llama_cpp_lib(model_file).Llama model = Llama(**params) - return LlamacppHF(model) + return LlamacppHF(model, model_file)