Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)

This commit is contained in:
oobabooga 2023-08-06 17:22:48 -03:00 committed by GitHub
parent 5134878344
commit 0af10ab49b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 131 additions and 42 deletions

View file

@ -49,12 +49,11 @@ class LlamacppHF(PreTrainedModel):
return torch.device(0)
def __call__(self, *args, **kwargs):
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
assert len(args) == 0, 'no *args should be passed to forward'
input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
use_cache = kwargs.get('use_cache', True)
labels = kwargs.get('labels', None)
seq = kwargs['input_ids'][0].tolist()
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
cache = kwargs.get('past_key_values', None)
seq = input_ids[0].tolist()
# Make the forward call
seq_tensor = torch.tensor(seq)
@ -70,7 +69,7 @@ class LlamacppHF(PreTrainedModel):
self.model.reset()
self.model.eval(seq)
logits = torch.tensor(self.model.eval_logits)
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device)
self.cache = seq_tensor