Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)
This commit is contained in:
parent
5134878344
commit
0af10ab49b
17 changed files with 131 additions and 42 deletions
|
@ -47,12 +47,11 @@ class ExllamaHF(PreTrainedModel):
|
|||
return torch.device(0)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# TODO: Some decoding methods (such as Contrastive Search) may not work at this time
|
||||
assert len(args) == 0, 'no *args should be passed to forward'
|
||||
input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
|
||||
use_cache = kwargs.get('use_cache', True)
|
||||
labels = kwargs.get('labels', None)
|
||||
seq = kwargs['input_ids'][0].tolist()
|
||||
cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
|
||||
cache = kwargs.get('past_key_values', None)
|
||||
seq = input_ids[0].tolist()
|
||||
|
||||
if labels is None:
|
||||
if cache is None:
|
||||
|
@ -60,7 +59,7 @@ class ExllamaHF(PreTrainedModel):
|
|||
cache = self.ex_cache
|
||||
self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
|
||||
|
||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
|
||||
logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(input_ids.device)
|
||||
else:
|
||||
if cache is None:
|
||||
self.ex_cache.current_seq_len = 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue