Add CFG to llamacpp_HF (second attempt) (#3678)
This commit is contained in:
parent
d6934bc7bc
commit
3320accfdc
3 changed files with 14 additions and 6 deletions
|
@ -38,16 +38,17 @@ class LlamacppHF(PreTrainedModel):
|
|||
self.llamacpp_cache = {
|
||||
'n_tokens': self.model.n_tokens,
|
||||
'input_ids': self.model.input_ids,
|
||||
'scores': self.model.scores
|
||||
'scores': self.model.scores,
|
||||
'ctx': self.model.ctx
|
||||
}
|
||||
|
||||
if shared.args.cfg_cache:
|
||||
logger.warning('CFG is currently bugged and not functional for llamacpp_HF. Contributions are welcome.')
|
||||
self.past_seq_negative = None
|
||||
self.llamacpp_cache_negative = {
|
||||
'n_tokens': self.model.n_tokens,
|
||||
'input_ids': self.model.input_ids.copy(),
|
||||
'scores': self.model.scores.copy()
|
||||
'scores': self.model.scores.copy(),
|
||||
'ctx': llama_cpp_lib().llama_new_context_with_model(model.model, model.params)
|
||||
}
|
||||
|
||||
def _validate_model_class(self):
|
||||
|
@ -63,25 +64,29 @@ class LlamacppHF(PreTrainedModel):
|
|||
self.llamacpp_cache.update({
|
||||
'n_tokens': self.model.n_tokens,
|
||||
'input_ids': self.model.input_ids,
|
||||
'scores': self.model.scores
|
||||
'scores': self.model.scores,
|
||||
'ctx': self.model.ctx
|
||||
})
|
||||
|
||||
def save_negative_cache(self):
|
||||
self.llamacpp_cache_negative.update({
|
||||
'n_tokens': self.model.n_tokens,
|
||||
'input_ids': self.model.input_ids,
|
||||
'scores': self.model.scores
|
||||
'scores': self.model.scores,
|
||||
'ctx': self.model.ctx
|
||||
})
|
||||
|
||||
def load_cache(self):
|
||||
self.model.n_tokens = self.llamacpp_cache['n_tokens']
|
||||
self.model.input_ids = self.llamacpp_cache['input_ids']
|
||||
self.model.scores = self.llamacpp_cache['scores']
|
||||
self.model.ctx = self.llamacpp_cache['ctx']
|
||||
|
||||
def load_negative_cache(self):
|
||||
self.model.n_tokens = self.llamacpp_cache_negative['n_tokens']
|
||||
self.model.input_ids = self.llamacpp_cache_negative['input_ids']
|
||||
self.model.scores = self.llamacpp_cache_negative['scores']
|
||||
self.model.ctx = self.llamacpp_cache_negative['ctx']
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
|
@ -95,7 +100,6 @@ class LlamacppHF(PreTrainedModel):
|
|||
if len(args) > 0:
|
||||
if not shared.args.cfg_cache:
|
||||
logger.error("Please enable the cfg-cache option to use CFG with llamacpp_HF.")
|
||||
logger.warning('CFG is currently bugged and not functional for llamacpp_HF. Contributions are welcome.')
|
||||
return
|
||||
|
||||
input_ids = args[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue