Add ExLlama support (#2444)
This commit is contained in:
parent
dea43685b0
commit
9f40032d32
12 changed files with 156 additions and 47 deletions
|
|
@ -38,31 +38,31 @@ class RWKVModel:
|
|||
result.cached_output_logits = None
|
||||
return result
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=None, alpha_frequency=0.1, alpha_presence=0.1, token_ban=None, token_stop=None, callback=None):
|
||||
def generate(self, prompt, state, callback=None):
|
||||
args = PIPELINE_ARGS(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
alpha_frequency=alpha_frequency, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence=alpha_presence, # Presence Penalty (as in GPT-3)
|
||||
token_ban=token_ban or [0], # ban the generation of some tokens
|
||||
token_stop=token_stop or []
|
||||
temperature=state['temperature'],
|
||||
top_p=state['top_p'],
|
||||
top_k=state['top_k'],
|
||||
alpha_frequency=0.1, # Frequency Penalty (as in GPT-3)
|
||||
alpha_presence=0.1, # Presence Penalty (as in GPT-3)
|
||||
token_ban=[0], # ban the generation of some tokens
|
||||
token_stop=[]
|
||||
)
|
||||
|
||||
if self.cached_context != "":
|
||||
if context.startswith(self.cached_context):
|
||||
context = context[len(self.cached_context):]
|
||||
if prompt.startswith(self.cached_context):
|
||||
prompt = prompt[len(self.cached_context):]
|
||||
else:
|
||||
self.cached_context = ""
|
||||
self.cached_model_state = None
|
||||
self.cached_output_logits = None
|
||||
|
||||
# out = self.pipeline.generate(context, token_count=token_count, args=args, callback=callback)
|
||||
out = self.generate_from_cached_state(context, token_count=token_count, args=args, callback=callback)
|
||||
# out = self.pipeline.generate(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
|
||||
out = self.generate_from_cached_state(prompt, token_count=state['max_new_tokens'], args=args, callback=callback)
|
||||
return out
|
||||
|
||||
def generate_with_streaming(self, **kwargs):
|
||||
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||
def generate_with_streaming(self, *args, **kwargs):
|
||||
with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
|
||||
reply = ''
|
||||
for token in generator:
|
||||
reply += token
|
||||
|
|
@ -81,6 +81,7 @@ class RWKVModel:
|
|||
if ctx == "":
|
||||
out = self.cached_output_logits
|
||||
|
||||
token = None
|
||||
for i in range(token_count):
|
||||
# forward
|
||||
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
|
||||
|
|
|
|||
|
|
@ -55,11 +55,12 @@ class Iteratorize:
|
|||
Adapted from: https://stackoverflow.com/a/9969000
|
||||
"""
|
||||
|
||||
def __init__(self, func, kwargs=None, callback=None):
|
||||
def __init__(self, func, args=None, kwargs=None, callback=None):
|
||||
self.mfunc = func
|
||||
self.c_callback = callback
|
||||
self.q = Queue()
|
||||
self.sentinel = object()
|
||||
self.args = args or []
|
||||
self.kwargs = kwargs or {}
|
||||
self.stop_now = False
|
||||
|
||||
|
|
@ -70,7 +71,7 @@ class Iteratorize:
|
|||
|
||||
def gentask():
|
||||
try:
|
||||
ret = self.mfunc(callback=_callback, **self.kwargs)
|
||||
ret = self.mfunc(callback=_callback, *args, **self.kwargs)
|
||||
except ValueError:
|
||||
pass
|
||||
except:
|
||||
|
|
|
|||
81
modules/exllama.py
Normal file
81
modules/exllama.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path("repositories/exllama")))
|
||||
|
||||
from modules.logging_colors import logger
|
||||
from repositories.exllama.generator import ExLlamaGenerator
|
||||
from repositories.exllama.model import ExLlama, ExLlamaCache, ExLlamaConfig
|
||||
from repositories.exllama.tokenizer import ExLlamaTokenizer
|
||||
|
||||
|
||||
class ExllamaModel:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, path_to_model):
|
||||
|
||||
path_to_model = Path("models") / Path(path_to_model)
|
||||
tokenizer_model_path = path_to_model / "tokenizer.model"
|
||||
model_config_path = path_to_model / "config.json"
|
||||
|
||||
# Find the model checkpoint
|
||||
model_path = None
|
||||
for ext in ['.safetensors', '.pt', '.bin']:
|
||||
found = list(path_to_model.glob(f"*{ext}"))
|
||||
if len(found) > 0:
|
||||
if len(found) > 1:
|
||||
logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||
|
||||
model_path = found[-1]
|
||||
break
|
||||
|
||||
config = ExLlamaConfig(str(model_config_path))
|
||||
config.model_path = str(model_path)
|
||||
model = ExLlama(config)
|
||||
tokenizer = ExLlamaTokenizer(str(tokenizer_model_path))
|
||||
cache = ExLlamaCache(model)
|
||||
|
||||
result = self()
|
||||
result.config = config
|
||||
result.model = model
|
||||
result.cache = cache
|
||||
result.tokenizer = tokenizer
|
||||
return result, result
|
||||
|
||||
def generate(self, prompt, state, callback=None):
|
||||
generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
||||
generator.settings.temperature = state['temperature']
|
||||
generator.settings.top_p = state['top_p']
|
||||
generator.settings.top_k = state['top_k']
|
||||
generator.settings.typical = state['typical_p']
|
||||
generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
||||
if state['ban_eos_token']:
|
||||
generator.disallow_tokens([self.tokenizer.eos_token_id])
|
||||
|
||||
text = generator.generate_simple(prompt, max_new_tokens=state['max_new_tokens'])
|
||||
return text
|
||||
|
||||
def generate_with_streaming(self, prompt, state, callback=None):
|
||||
generator = ExLlamaGenerator(self.model, self.tokenizer, self.cache)
|
||||
generator.settings.temperature = state['temperature']
|
||||
generator.settings.top_p = state['top_p']
|
||||
generator.settings.top_k = state['top_k']
|
||||
generator.settings.typical = state['typical_p']
|
||||
generator.settings.token_repetition_penalty_max = state['repetition_penalty']
|
||||
if state['ban_eos_token']:
|
||||
generator.disallow_tokens([self.tokenizer.eos_token_id])
|
||||
|
||||
generator.end_beam_search()
|
||||
ids = generator.tokenizer.encode(prompt)
|
||||
generator.gen_begin(ids)
|
||||
initial_len = generator.sequence[0].shape[0]
|
||||
for i in range(state['max_new_tokens']):
|
||||
token = generator.gen_single_token()
|
||||
yield (generator.tokenizer.decode(generator.sequence[0][initial_len:]))
|
||||
if token.item() == generator.tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
def encode(self, string, **kwargs):
|
||||
return self.tokenizer.encode(string)
|
||||
|
|
@ -59,18 +59,18 @@ class LlamaCppModel:
|
|||
|
||||
return self.model.tokenize(string)
|
||||
|
||||
def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, mirostat_mode=0, mirostat_tau=5, mirostat_eta=0.1, callback=None):
|
||||
context = context if type(context) is str else context.decode()
|
||||
def generate(self, prompt, state, callback=None):
|
||||
prompt = prompt if type(prompt) is str else prompt.decode()
|
||||
completion_chunks = self.model.create_completion(
|
||||
prompt=context,
|
||||
max_tokens=token_count,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
repeat_penalty=repetition_penalty,
|
||||
mirostat_mode=int(mirostat_mode),
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
prompt=prompt,
|
||||
max_tokens=state['max_new_tokens'],
|
||||
temperature=state['temperature'],
|
||||
top_p=state['top_p'],
|
||||
top_k=state['top_k'],
|
||||
repeat_penalty=state['repetition_penalty'],
|
||||
mirostat_mode=int(state['mirostat_mode']),
|
||||
mirostat_tau=state['mirostat_tau'],
|
||||
mirostat_eta=state['mirostat_eta'],
|
||||
stream=True
|
||||
)
|
||||
|
||||
|
|
@ -83,8 +83,8 @@ class LlamaCppModel:
|
|||
|
||||
return output
|
||||
|
||||
def generate_with_streaming(self, **kwargs):
|
||||
with Iteratorize(self.generate, kwargs, callback=None) as generator:
|
||||
def generate_with_streaming(self, *args, **kwargs):
|
||||
with Iteratorize(self.generate, args, kwargs, callback=None) as generator:
|
||||
reply = ''
|
||||
for token in generator:
|
||||
reply += token
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ loaders_and_params = {
|
|||
'trust_remote_code',
|
||||
'transformers_info'
|
||||
],
|
||||
'ExLlama' : [
|
||||
'exllama_info',
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,8 @@ def load_model(model_name, loader=None):
|
|||
'GPTQ-for-LLaMa': GPTQ_loader,
|
||||
'llama.cpp': llamacpp_loader,
|
||||
'FlexGen': flexgen_loader,
|
||||
'RWKV': RWKV_loader
|
||||
'RWKV': RWKV_loader,
|
||||
'ExLlama': ExLlama_loader
|
||||
}
|
||||
|
||||
if loader is None:
|
||||
|
|
@ -270,6 +271,13 @@ def AutoGPTQ_loader(model_name):
|
|||
return modules.AutoGPTQ_loader.load_quantized(model_name)
|
||||
|
||||
|
||||
def ExLlama_loader(model_name):
|
||||
from modules.exllama import ExllamaModel
|
||||
|
||||
model, tokenizer = ExllamaModel.from_pretrained(model_name)
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_max_memory_dict():
|
||||
max_memory = {}
|
||||
if shared.args.gpu_memory:
|
||||
|
|
|
|||
|
|
@ -94,7 +94,7 @@ def apply_model_settings_to_state(model, state):
|
|||
loader = 'AutoGPTQ'
|
||||
|
||||
# If the user is using an alternative GPTQ loader, let them keep using it
|
||||
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'exllama']):
|
||||
if not (loader == 'AutoGPTQ' and state['loader'] in ['GPTQ-for-LLaMa', 'ExLlama']):
|
||||
state['loader'] = loader
|
||||
|
||||
for k in model_settings:
|
||||
|
|
|
|||
|
|
@ -97,7 +97,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten
|
|||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
|
||||
# Model loader
|
||||
parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: autogptq, gptq-for-llama, transformers, llamacpp, rwkv, flexgen')
|
||||
parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen')
|
||||
|
||||
# Accelerate/transformers
|
||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
|
||||
|
|
@ -212,6 +212,8 @@ def fix_loader_name(name):
|
|||
return 'AutoGPTQ'
|
||||
elif name in ['gptq-for-llama', 'gptqforllama', 'gptqllama', 'gptq for llama', 'gptq_for_llama']:
|
||||
return 'GPTQ-for-LLaMa'
|
||||
elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']:
|
||||
return 'ExLlama'
|
||||
|
||||
|
||||
if args.loader is not None:
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
|
|||
if truncation_length is not None:
|
||||
input_ids = input_ids[:, -truncation_length:]
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel'] or shared.args.cpu:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel'] or shared.args.cpu:
|
||||
return input_ids
|
||||
elif shared.args.flexgen:
|
||||
return input_ids.numpy()
|
||||
|
|
@ -157,7 +157,7 @@ def _generate_reply(question, state, eos_token=None, stopping_strings=None, is_c
|
|||
yield ''
|
||||
return
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel', 'ExllamaModel']:
|
||||
generate_func = generate_reply_custom
|
||||
elif shared.args.flexgen:
|
||||
generate_func = generate_reply_flexgen
|
||||
|
|
@ -283,13 +283,6 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
|
|||
|
||||
def generate_reply_custom(question, original_question, seed, state, eos_token=None, stopping_strings=None, is_chat=False):
|
||||
seed = set_manual_seed(state['seed'])
|
||||
generate_params = {'token_count': state['max_new_tokens']}
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if shared.model.__class__.__name__ in ['LlamaCppModel']:
|
||||
for k in ['mirostat_mode', 'mirostat_tau', 'mirostat_eta']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
t0 = time.time()
|
||||
reply = ''
|
||||
|
|
@ -298,13 +291,13 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
|
|||
yield ''
|
||||
|
||||
if not state['stream']:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
reply = shared.model.generate(question, state)
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply)
|
||||
|
||||
yield reply
|
||||
else:
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
for reply in shared.model.generate_with_streaming(question, state):
|
||||
if not is_chat:
|
||||
reply = apply_extensions('output', reply)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue