Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)
This commit is contained in:
parent
5134878344
commit
0af10ab49b
17 changed files with 131 additions and 42 deletions
|
@ -49,12 +49,11 @@ class LlamacppHF(PreTrainedModel):
|
|||
return torch.device(0)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
|
||||
assert len(args) == 0, 'no *args should be passed to forward'
|
||||
input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
|
||||
use_cache = kwargs.get('use_cache', True)
|
||||
labels = kwargs.get('labels', None)
|
||||
seq = kwargs['input_ids'][0].tolist()
|
||||
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
||||
cache = kwargs.get('past_key_values', None)
|
||||
seq = input_ids[0].tolist()
|
||||
|
||||
# Make the forward call
|
||||
seq_tensor = torch.tensor(seq)
|
||||
|
@ -70,7 +69,7 @@ class LlamacppHF(PreTrainedModel):
|
|||
self.model.reset()
|
||||
self.model.eval(seq)
|
||||
logits = torch.tensor(self.model.eval_logits)
|
||||
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
|
||||
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device)
|
||||
|
||||
self.cache = seq_tensor
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue