Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)

This commit is contained in:
oobabooga 2023-08-06 17:22:48 -03:00 committed by GitHub
parent 5134878344
commit 0af10ab49b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 131 additions and 42 deletions

View file

@@ -47,12 +47,11 @@ class ExllamaHF(PreTrainedModel):
return torch.device(0)
def __call__(self, *args, **kwargs):
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
assert len(args) == 0, 'no *args should be passed to forward'
input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
use_cache = kwargs.get('use_cache', True)
labels = kwargs.get('labels', None)
seq = kwargs['input_ids'][0].tolist()
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
cache = kwargs.get('past_key_values', None)
seq = input_ids[0].tolist()
if labels is None:
if cache is None:
@@ -60,7 +59,7 @@ class ExllamaHF(PreTrainedModel):
cache = self.ex_cache
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(input_ids.device)
else:
if cache is None:
self.ex_cache.current_seq_len = 0