Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)

This commit is contained in:
oobabooga 2023-08-06 17:22:48 -03:00 committed by GitHub
parent 5134878344
commit 0af10ab49b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 131 additions and 42 deletions

View file

@ -1,9 +1,11 @@
from pathlib import Path
import torch.nn.functional as F
from torch import version as torch_version
from modules import shared
from modules.logging_colors import logger
from modules.models import clear_torch_cache
from modules.text_generation import get_max_prompt_length
try:
@ -78,6 +80,21 @@ class ExllamaModel:
return result, result
def generate_with_streaming(self, prompt, state):
# The cache batch size must be 2 for CFG and 1 otherwise
if state['guidance_scale'] == 1:
if self.cache.batch_size == 2:
del self.cache
clear_torch_cache()
self.cache = ExLlamaCache(self.model)
self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
else:
if self.cache.batch_size == 1:
del self.cache
clear_torch_cache()
self.cache = ExLlamaCache(self.model, batch_size=2)
self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
self.generator.settings.temperature = state['temperature']
self.generator.settings.top_p = state['top_p']
self.generator.settings.top_k = state['top_k']
@ -89,31 +106,71 @@ class ExllamaModel:
else:
self.generator.disallow_tokens(None)
self.generator.end_beam_search()
# Case 1: no CFG
if state['guidance_scale'] == 1:
self.generator.end_beam_search()
# Tokenizing the input
ids = self.generator.tokenizer.encode(prompt)
ids = ids[:, -get_max_prompt_length(state):]
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - ids.shape[-1]
# Tokenizing the input
ids = self.generator.tokenizer.encode(prompt)
ids = ids[:, -get_max_prompt_length(state):]
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - ids.shape[-1]
else:
max_new_tokens = state['max_new_tokens']
self.generator.gen_begin_reuse(ids)
initial_len = self.generator.sequence[0].shape[0]
has_leading_space = False
for i in range(max_new_tokens):
token = self.generator.gen_single_token()
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):
has_leading_space = True
decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
if has_leading_space:
decoded_text = ' ' + decoded_text
yield decoded_text
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
break
# Case 2: CFG
else:
max_new_tokens = state['max_new_tokens']
alpha = state['guidance_scale']
prompts = [prompt, state['negative_prompt'] or '']
self.generator.gen_begin_reuse(ids)
initial_len = self.generator.sequence[0].shape[0]
has_leading_space = False
for i in range(max_new_tokens):
token = self.generator.gen_single_token()
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):
has_leading_space = True
ids, mask = self.tokenizer.encode(prompts, return_mask=True)
if state['auto_max_new_tokens']:
max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
else:
max_new_tokens = state['max_new_tokens']
decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
if has_leading_space:
decoded_text = ' ' + decoded_text
self.generator.gen_begin(ids, mask=mask)
initial_len = self.generator.sequence[0].shape[0]
has_leading_space = False
yield decoded_text
if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
break
for i in range(max_new_tokens):
logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask)
self.generator.apply_rep_penalty(logits)
logits = F.log_softmax(logits, dim=-1)
logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1]
token, _ = self.generator.sample_current(logits_mixed)
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith(''):
has_leading_space = True
decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
if has_leading_space:
decoded_text = ' ' + decoded_text
yield decoded_text
if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
break
batch_token = token.repeat(2, 1)
self.generator.gen_accept_token(batch_token)
def generate(self, prompt, state):
output = ''