diff --git a/README.md b/README.md index 9d2e1b0..24c0471 100644 --- a/README.md +++ b/README.md @@ -212,7 +212,7 @@ Optionally, you can use the following command-line flags: | Flag | Description | |--------------------------------------------|-------------| -| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen | +| `--loader LOADER` | Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen | #### Accelerate/transformers diff --git a/modules/exllama_hf.py b/modules/exllama_hf.py new file mode 100644 index 0000000..27cac37 --- /dev/null +++ b/modules/exllama_hf.py @@ -0,0 +1,82 @@ +import os +import sys +from pathlib import Path +from typing import * + +import torch +from transformers import ( + GenerationConfig, + LlamaTokenizer, + PretrainedConfig, + PreTrainedModel +) +from transformers.modeling_outputs import CausalLMOutputWithPast + +from modules import shared +from modules.logging_colors import logger +from modules.relative_imports import RelativeImport + +with RelativeImport("repositories/exllama"): + from model import ExLlama, ExLlamaCache, ExLlamaConfig + + +class ExllamaHF(PreTrainedModel): + def __init__(self, config: ExLlamaConfig): + super().__init__(PretrainedConfig()) + self.ex_config = config + self.ex_model = ExLlama(self.ex_config) + self.generation_config = GenerationConfig() + + def _validate_model_class(self): + pass + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + pass + + def prepare_inputs_for_generation(self, input_ids, **kwargs): + return {'input_ids': input_ids, **kwargs} + + @property + def device(self) -> torch.device: + # TODO: May cause problem on multi-gpu inference? + return torch.device(0) + + def __call__(self, *args, **kwargs): + # TODO: Some decoding methods (such as Contrastive Search) may not work at this time + assert len(args) == 0, 'no *args should be passed to forward' + use_cache = kwargs['use_cache'] + seq = kwargs['input_ids'][0].tolist() + cache = kwargs['past_key_values'] if 'past_key_values' in kwargs else None + if cache is None: + cache = ExLlamaCache(self.ex_model) + self.ex_model.forward(torch.tensor([seq[:-1]], dtype=torch.long), cache, preprocess_only=True) + logits = self.ex_model.forward(torch.tensor([seq[-1:]], dtype=torch.long), cache).to(self.device) + return CausalLMOutputWithPast(logits=logits, past_key_values=cache if use_cache else None) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + assert len(model_args) == 0 and len(kwargs) == 0, "extra args is currently not supported" + if isinstance(pretrained_model_name_or_path, str): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + + pretrained_model_name_or_path = Path(f'{shared.args.model_dir}') / Path(pretrained_model_name_or_path) + config = ExLlamaConfig(pretrained_model_name_or_path / 'config.json') + + # from 'oobabooga/text-generation-webui/modules/exllama.py' + weight_path = None + for ext in ['.safetensors', '.pt', '.bin']: + found = list(pretrained_model_name_or_path.glob(f"*{ext}")) + if len(found) > 0: + weight_path = found[-1] + break + assert weight_path is not None, f'could not find weight in "{pretrained_model_name_or_path}"' + + config.model_path = str(weight_path) + + # This slowes down a bit but align better with autogptq generation. + # TODO: Should give user choice to tune the exllama config + config.act_order = True + config.fused_attn = False + config.fused_mlp_thd = 0 + + return ExllamaHF(config) \ No newline at end of file diff --git a/modules/loaders.py b/modules/loaders.py index ac6f80b..2164202 100644 --- a/modules/loaders.py +++ b/modules/loaders.py @@ -55,6 +55,10 @@ loaders_and_params = { 'ExLlama' : [ 'gpu_split', 'exllama_info', + ], + 'ExLlama_HF' : [ + 'gpu_split', + 'exllama_HF_info', ] } diff --git a/modules/models.py b/modules/models.py index 1aba66c..574e164 100644 --- a/modules/models.py +++ b/modules/models.py @@ -49,7 +49,8 @@ def load_model(model_name, loader=None): 'llama.cpp': llamacpp_loader, 'FlexGen': flexgen_loader, 'RWKV': RWKV_loader, - 'ExLlama': ExLlama_loader + 'ExLlama': ExLlama_loader, + 'ExLlama_HF': ExLlama_HF_loader } if loader is None: @@ -278,6 +279,12 @@ def ExLlama_loader(model_name): return model, tokenizer +def ExLlama_HF_loader(model_name): + from modules.exllama_hf import ExllamaHF + + return ExllamaHF.from_pretrained(model_name) + + def get_max_memory_dict(): max_memory = {} if shared.args.gpu_memory: diff --git a/modules/shared.py b/modules/shared.py index ecc03fc..e065b76 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -98,7 +98,7 @@ parser.add_argument('--extensions', type=str, nargs="+", help='The list of exten parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') # Model loader -parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, llamacpp, rwkv, flexgen') +parser.add_argument('--loader', type=str, help='Choose the model loader manually, otherwise, it will get autodetected. Valid options: transformers, autogptq, gptq-for-llama, exllama, exllama_hf, llamacpp, rwkv, flexgen') # Accelerate/transformers parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.') @@ -218,6 +218,8 @@ def fix_loader_name(name): return 'GPTQ-for-LLaMa' elif name in ['exllama', 'ex-llama', 'ex_llama', 'exlama']: return 'ExLlama' + elif name in ['exllama-hf', 'exllama_hf', 'exllama hf', 'ex-llama-hf', 'ex_llama_hf']: + return 'ExLlama_HF' if args.loader is not None: diff --git a/modules/text_generation.py b/modules/text_generation.py index 0d2f55c..d0965b8 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -104,9 +104,8 @@ def get_reply_from_output_ids(output_ids, input_ids, original_question, state, i else: new_tokens = len(output_ids) - len(input_ids[0]) reply = decode(output_ids[-new_tokens:], state['skip_special_tokens']) - # Prevent LlamaTokenizer from skipping a space - if type(shared.tokenizer) is transformers.LlamaTokenizer and len(output_ids) > 0: + if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0: if shared.tokenizer.convert_ids_to_tokens(int(output_ids[-new_tokens])).startswith('▁'): reply = ' ' + reply diff --git a/server.py b/server.py index b6699f1..ff4f2c3 100644 --- a/server.py +++ b/server.py @@ -197,7 +197,7 @@ def create_model_menus(): with gr.Row(): with gr.Column(): - shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "llama.cpp"], value=None) + shared.gradio['loader'] = gr.Dropdown(label="Model loader", choices=["Transformers", "AutoGPTQ", "GPTQ-for-LLaMa", "ExLlama", "ExLlama_HF", "llama.cpp"], value=None) with gr.Box(): with gr.Row(): with gr.Column(): @@ -237,6 +237,7 @@ def create_model_menus(): shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.') shared.gradio['gptq_for_llama_info'] = gr.Markdown('GPTQ-for-LLaMa is currently 2x faster than AutoGPTQ on some systems. It is installed by default with the one-click installers. Otherwise, it has to be installed manually following the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md#installation-1).') shared.gradio['exllama_info'] = gr.Markdown('ExLlama has to be installed manually. See the instructions here: [instructions](https://github.com/oobabooga/text-generation-webui/blob/main/docs/ExLlama.md).') + shared.gradio['exllama_HF_info'] = gr.Markdown('ExLlama_HF is a wrapper that lets you use ExLlama like a Transformers model, which means it can use the Transformers samplers. It\'s still a bit buggy, so feel free to help out by fixing issues.\n\nCheck out PR [#2777](https://github.com/oobabooga/text-generation-webui/pull/2777) for more details.') with gr.Column(): with gr.Row():