Implement auto_max_new_tokens for ExLlama
This commit is contained in:
parent
e931844fe2
commit
32a2bbee4a
2 changed files with 6 additions and 1 deletions
|
@ -94,11 +94,15 @@ class ExllamaModel:
|
||||||
# Tokenizing the input
|
# Tokenizing the input
|
||||||
ids = self.generator.tokenizer.encode(prompt)
|
ids = self.generator.tokenizer.encode(prompt)
|
||||||
ids = ids[:, -get_max_prompt_length(state):]
|
ids = ids[:, -get_max_prompt_length(state):]
|
||||||
|
if state['auto_max_new_tokens']:
|
||||||
|
max_new_tokens = state['truncation_length'] - ids.shape[-1]
|
||||||
|
else:
|
||||||
|
max_new_tokens = state['max_new_tokens']
|
||||||
|
|
||||||
self.generator.gen_begin_reuse(ids)
|
self.generator.gen_begin_reuse(ids)
|
||||||
initial_len = self.generator.sequence[0].shape[0]
|
initial_len = self.generator.sequence[0].shape[0]
|
||||||
has_leading_space = False
|
has_leading_space = False
|
||||||
for i in range(state['max_new_tokens']):
|
for i in range(max_new_tokens):
|
||||||
token = self.generator.gen_single_token()
|
token = self.generator.gen_single_token()
|
||||||
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
if i == 0 and self.generator.tokenizer.tokenizer.IdToPiece(int(token)).startswith('▁'):
|
||||||
has_leading_space = True
|
has_leading_space = True
|
||||||
|
|
|
@ -151,6 +151,7 @@ loaders_samplers = {
|
||||||
'repetition_penalty_range',
|
'repetition_penalty_range',
|
||||||
'seed',
|
'seed',
|
||||||
'ban_eos_token',
|
'ban_eos_token',
|
||||||
|
'auto_max_new_tokens',
|
||||||
},
|
},
|
||||||
'AutoGPTQ': {
|
'AutoGPTQ': {
|
||||||
'temperature',
|
'temperature',
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue