Add Classifier Free Guidance (CFG) for Transformers/ExLlama (#3325)
This commit is contained in:
parent 5134878344
commit 0af10ab49b

17 changed files with 131 additions and 42 deletions
@@ -1,9 +1,11 @@
 from pathlib import Path
 
+import torch.nn.functional as F
 from torch import version as torch_version
 
 from modules import shared
 from modules.logging_colors import logger
+from modules.models import clear_torch_cache
 from modules.text_generation import get_max_prompt_length
 
 try:
@@ -78,6 +80,21 @@ class ExllamaModel:
         return result, result
 
     def generate_with_streaming(self, prompt, state):
+
+        # The cache batch size must be 2 for CFG and 1 otherwise
+        if state['guidance_scale'] == 1:
+            if self.cache.batch_size == 2:
+                del self.cache
+                clear_torch_cache()
+                self.cache = ExLlamaCache(self.model)
+                self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+        else:
+            if self.cache.batch_size == 1:
+                del self.cache
+                clear_torch_cache()
+                self.cache = ExLlamaCache(self.model, batch_size=2)
+                self.generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
+
         self.generator.settings.temperature = state['temperature']
         self.generator.settings.top_p = state['top_p']
         self.generator.settings.top_k = state['top_k']
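Aside (not part of the diff): the reason for the batch_size=2 cache above is that CFG decodes two sequences in lockstep, the prompt and the negative prompt, inside a single ExLlama forward pass. Below is a minimal PyTorch sketch of packing two prompts of different lengths into one two-row batch with a padding mask, roughly the role played by tokenizer.encode(prompts, return_mask=True) in the diff; the left-padding convention and the made-up token ids are assumptions for illustration only.

import torch

pos = torch.tensor([15, 27, 3, 88])     # made-up token ids for the prompt
neg = torch.tensor([15, 27])            # made-up token ids for the negative prompt
max_len = max(len(pos), len(neg))

ids = torch.zeros(2, max_len, dtype=torch.long)
mask = torch.zeros(2, max_len, dtype=torch.bool)
for row, seq in enumerate((pos, neg)):
    ids[row, -len(seq):] = seq          # left-pad so the newest token sits in the last column
    mask[row, -len(seq):] = True        # True marks real tokens, False marks padding

print(ids)
print(mask)
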
@@ -89,31 +106,71 @@ class ExllamaModel:
         else:
             self.generator.disallow_tokens(None)
 
-        self.generator.end_beam_search()
-
-        # Tokenizing the input
-        ids = self.generator.tokenizer.encode(prompt)
-        ids = ids[:, -get_max_prompt_length(state):]
-        if state['auto_max_new_tokens']:
-            max_new_tokens = state['truncation_length'] - ids.shape[-1]
-        else:
-            max_new_tokens = state['max_new_tokens']
-
-        self.generator.gen_begin_reuse(ids)
-        initial_len = self.generator.sequence[0].shape[0]
-        has_leading_space = False
-        for i in range(max_new_tokens):
-            token = self.generator.gen_single_token()
-            if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
-                has_leading_space = True
-
-            decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
-            if has_leading_space:
-                decoded_text = ' ' + decoded_text
-
-            yield decoded_text
-            if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
-                break
+        # Case 1: no CFG
+        if state['guidance_scale'] == 1:
+            self.generator.end_beam_search()
+
+            # Tokenizing the input
+            ids = self.generator.tokenizer.encode(prompt)
+            ids = ids[:, -get_max_prompt_length(state):]
+            if state['auto_max_new_tokens']:
+                max_new_tokens = state['truncation_length'] - ids.shape[-1]
+            else:
+                max_new_tokens = state['max_new_tokens']
+
+            self.generator.gen_begin_reuse(ids)
+            initial_len = self.generator.sequence[0].shape[0]
+            has_leading_space = False
+
+            for i in range(max_new_tokens):
+                token = self.generator.gen_single_token()
+                if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+                    has_leading_space = True
+
+                decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
+                if has_leading_space:
+                    decoded_text = ' ' + decoded_text
+
+                yield decoded_text
+                if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
+                    break
+
+        # Case 2: CFG
+        else:
+            alpha = state['guidance_scale']
+            prompts = [prompt, state['negative_prompt'] or '']
+
+            ids, mask = self.tokenizer.encode(prompts, return_mask=True)
+            if state['auto_max_new_tokens']:
+                max_new_tokens = state['truncation_length'] - ids[0].shape[-1]
+            else:
+                max_new_tokens = state['max_new_tokens']
+
+            self.generator.gen_begin(ids, mask=mask)
+            initial_len = self.generator.sequence[0].shape[0]
+            has_leading_space = False
+
+            for i in range(max_new_tokens):
+                logits = self.model.forward(self.generator.sequence[:, -1:], self.cache, input_mask=mask)
+                self.generator.apply_rep_penalty(logits)
+
+                logits = F.log_softmax(logits, dim=-1)
+                logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1]
+
+                token, _ = self.generator.sample_current(logits_mixed)
+                if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
+                    has_leading_space = True
+
+                decoded_text = self.generator.tokenizer.decode(self.generator.sequence[0][initial_len:])
+                if has_leading_space:
+                    decoded_text = ' ' + decoded_text
+
+                yield decoded_text
+                if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
+                    break
+
+                batch_token = token.repeat(2, 1)
+                self.generator.gen_accept_token(batch_token)
 
     def generate(self, prompt, state):
         output = ''
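Aside (not part of the diff): the heart of Case 2 is the mix logits_mixed = alpha * logits[0] + (1 - alpha) * logits[1], where row 0 comes from the prompt and row 1 from the negative prompt; with alpha = 1 it collapses to the prompt's own distribution, which is why guidance_scale == 1 takes the plain streaming path instead. A runnable sketch of that step on dummy logits:

import torch
import torch.nn.functional as F

alpha = 1.5                               # plays the role of state['guidance_scale']
logits = torch.randn(2, 1, 32000)         # (batch=2: [prompt, negative prompt], seq=1, vocab)

log_probs = F.log_softmax(logits, dim=-1)
logits_mixed = alpha * log_probs[0] + (1 - alpha) * log_probs[1]

next_token = logits_mixed[-1].argmax()    # greedy pick for illustration; the diff samples
                                          # with generator.sample_current() instead
print(int(next_token))

Values of alpha above 1 push the mixed distribution away from whatever the negative prompt makes likely.
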
@@ -47,12 +47,11 @@ class ExllamaHF(PreTrainedModel):
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
-        # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
-        assert len(args) == 0, 'no *args should be passed to forward'
+        input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
         use_cache = kwargs.get('use_cache', True)
         labels = kwargs.get('labels', None)
-        seq = kwargs['input_ids'][0].tolist()
-        cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
+        cache = kwargs.get('past_key_values', None)
+        seq = input_ids[0].tolist()
 
         if labels is None:
             if cache is None:
@@ -60,7 +59,7 @@ class ExllamaHF(PreTrainedModel):
                 cache = self.ex_cache
                 self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True, lora=self.lora)
 
-            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(kwargs['input_ids'].device)
+            logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache, lora=self.lora).to(input_ids.device)
         else:
             if cache is None:
                 self.ex_cache.current_seq_len = 0
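Aside (not part of the diff): dropping the assert lets the HF-style wrapper be called either as model(input_ids) or model(input_ids=...), and past_key_values is now read with .get() so its absence is not an error; Transformers' CFG path runs an extra forward pass on the negative prompt, and that call does not necessarily follow the old keyword-only convention. FakeWrapper below is a stand-in class, not the real ExllamaHF:

import torch

class FakeWrapper:
    # Stand-in HF-style wrapper; only the argument handling from the diff is mirrored.
    def __call__(self, *args, **kwargs):
        input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
        cache = kwargs.get('past_key_values', None)
        return input_ids.shape, cache is not None

wrapper = FakeWrapper()
ids = torch.tensor([[1, 2, 3]])
print(wrapper(ids))                                   # positional call
print(wrapper(input_ids=ids, past_key_values=None))   # keyword call
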
@@ -49,12 +49,11 @@ class LlamacppHF(PreTrainedModel):
         return torch.device(0)
 
     def __call__(self, *args, **kwargs):
-        # TODO: Some decoding methods (such as Contrastive Search) may not work at this time
-        assert len(args) == 0, 'no *args should be passed to forward'
+        input_ids = args[0] if len(args) > 0 else kwargs['input_ids']
        use_cache = kwargs.get('use_cache', True)
         labels = kwargs.get('labels', None)
-        seq = kwargs['input_ids'][0].tolist()
-        cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None
+        cache = kwargs.get('past_key_values', None)
+        seq = input_ids[0].tolist()
 
         # Make the forward call
         seq_tensor = torch.tensor(seq)
@@ -70,7 +69,7 @@ class LlamacppHF(PreTrainedModel):
                 self.model.reset()
                 self.model.eval(seq)
                 logits = torch.tensor(self.model.eval_logits)
-                logits = logits.view(1, logits.shape[0], logits.shape[1]).to(kwargs['input_ids'].device)
+                logits = logits.view(1, logits.shape[0], logits.shape[1]).to(input_ids.device)
 
             self.cache = seq_tensor
 
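Aside (not part of the diff): llama.cpp exposes one row of logits per evaluated token, so the wrapper stacks them into the (batch, seq_len, vocab_size) tensor Transformers expects and moves it to the same device as input_ids. A tiny sketch with dummy values (the sizes are arbitrary):

import torch

vocab_size, seq_len = 32000, 5
eval_logits = [[0.0] * vocab_size for _ in range(seq_len)]    # stand-in for model.eval_logits

logits = torch.tensor(eval_logits)
logits = logits.view(1, logits.shape[0], logits.shape[1])     # -> (1, seq_len, vocab_size)
print(logits.shape)
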
@@ -115,6 +115,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
@@ -152,6 +154,8 @@ loaders_samplers = {
         'repetition_penalty',
         'repetition_penalty_range',
         'seed',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'auto_max_new_tokens',
     },
@@ -178,6 +182,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
@@ -206,6 +212,8 @@ loaders_samplers = {
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'guidance_scale',
+        'negative_prompt',
         'ban_eos_token',
         'add_bos_token',
         'skip_special_tokens',
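Aside (not part of the diff): loaders_samplers whitelists the sampling parameters each loader supports, so listing 'guidance_scale' and 'negative_prompt' here is what surfaces the CFG controls for these loaders. A hypothetical illustration of how such a whitelist can be queried; the dict below is a shortened stand-in, not the real contents of the module:

loaders_samplers = {
    'ExLlama': {'temperature', 'top_p', 'top_k', 'guidance_scale', 'negative_prompt'},
    'llama.cpp': {'temperature', 'top_p', 'top_k'},
}

def supported(loader, params):
    # Keep only the parameters this loader knows how to use
    return [p for p in params if p in loaders_samplers.get(loader, set())]

print(supported('ExLlama', ['guidance_scale', 'negative_prompt', 'mirostat_mode']))
print(supported('llama.cpp', ['guidance_scale', 'negative_prompt', 'mirostat_mode']))
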
@@ -9,6 +9,7 @@ def default_preset():
         'do_sample': True,
         'temperature': 1,
         'top_p': 1,
+        'top_k': 0,
         'typical_p': 1,
         'epsilon_cutoff': 0,
         'eta_cutoff': 0,
@@ -17,19 +18,23 @@ def default_preset():
         'repetition_penalty': 1,
         'repetition_penalty_range': 0,
         'encoder_repetition_penalty': 1,
-        'top_k': 0,
-        'num_beams': 1,
-        'penalty_alpha': 0,
-        'min_length': 0,
-        'length_penalty': 1,
         'no_repeat_ngram_size': 0,
-        'early_stopping': False,
+        'min_length': 0,
+        'guidance_scale': 1,
         'mirostat_mode': 0,
         'mirostat_tau': 5.0,
         'mirostat_eta': 0.1,
+        'penalty_alpha': 0,
+        'num_beams': 1,
+        'length_penalty': 1,
+        'early_stopping': False,
     }
 
 
+def presets_params():
+    return [k for k in default_preset()]
+
+
 def load_preset(name):
     generate_params = default_preset()
     if name not in ['None', None, '']:
@@ -51,12 +56,12 @@ def load_preset_memoized(name):
 def load_preset_for_ui(name, state):
     generate_params = load_preset(name)
     state.update(generate_params)
-    return state, *[generate_params[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']]
+    return state, *[generate_params[k] for k in presets_params()]
 
 
 def generate_preset_yaml(state):
     defaults = default_preset()
-    data = {k: state[k] for k in ['do_sample', 'temperature', 'top_p', 'typical_p', 'epsilon_cutoff', 'eta_cutoff', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'tfs', 'top_a']}
+    data = {k: state[k] for k in presets_params()}
 
     # Remove entries that are identical to the defaults
     for k in list(data.keys()):
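Aside (not part of the diff): presets_params() makes the keys of default_preset() the single source of truth, so load_preset_for_ui and generate_preset_yaml no longer repeat a hard-coded key list and new options such as guidance_scale flow through automatically. A self-contained sketch of the same pattern on a shortened stand-in preset:

def default_preset():
    # Shortened stand-in for the real default_preset()
    return {'temperature': 1, 'top_p': 1, 'guidance_scale': 1}

def presets_params():
    return [k for k in default_preset()]

def generate_preset_yaml(state):
    defaults = default_preset()
    data = {k: state[k] for k in presets_params()}
    # Keep only values that differ from the defaults
    return {k: v for k, v in data.items() if v != defaults[k]}

print(generate_preset_yaml({'temperature': 0.7, 'top_p': 1, 'guidance_scale': 1.5}))
# {'temperature': 0.7, 'guidance_scale': 1.5}
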
@@ -42,6 +42,7 @@ settings = {
     'max_new_tokens_max': 4096,
     'auto_max_new_tokens': False,
     'seed': -1,
+    'negative_prompt': '',
     'character': 'None',
     'name1': 'You',
     'name2': 'Assistant',
@@ -226,9 +226,12 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False):
 
 def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
     generate_params = {}
-    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
+    for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'repetition_penalty_range', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'tfs', 'top_a', 'mirostat_mode', 'mirostat_tau', 'mirostat_eta', 'guidance_scale']:
         generate_params[k] = state[k]
 
+    if state['negative_prompt'] != '':
+        generate_params['negative_prompt_ids'] = encode(state['negative_prompt'])
+
     for k in ['epsilon_cutoff', 'eta_cutoff']:
         if state[k] > 0:
             generate_params[k] = state[k] * 1e-4
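Aside (not part of the diff): for the Transformers loader, CFG rides entirely on generate() keyword arguments assembled above. The sketch below assumes a recent transformers release that accepts guidance_scale and negative_prompt_ids, and uses a tiny public checkpoint purely to keep the example light:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = 'sshleifer/tiny-gpt2'   # tiny public model, chosen only for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer('A cheerful poem about rain', return_tensors='pt').input_ids
negative_ids = tokenizer('gloomy, depressing', return_tensors='pt').input_ids

output = model.generate(
    input_ids,
    max_new_tokens=20,
    guidance_scale=1.5,                  # corresponds to state['guidance_scale'] above
    negative_prompt_ids=negative_ids,    # corresponds to encode(state['negative_prompt'])
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
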
@@ -100,6 +100,8 @@ def list_interface_input_elements():
         'mirostat_mode',
         'mirostat_tau',
         'mirostat_eta',
+        'negative_prompt',
+        'guidance_scale',
         'add_bos_token',
         'ban_eos_token',
         'truncation_length',