diff --git a/download-model.py b/download-model.py
index f8be986..b22405e 100644
--- a/download-model.py
+++ b/download-model.py
@@ -127,10 +127,23 @@ class ModelDownloader:
                 if classifications[i] in ['pytorch', 'pt']:
                     links.pop(i)
 
+        # For GGUF, try to download only the Q4_K_M if no specific file is specified.
+        # If not present, exclude all GGUFs, as that's likely a repository with both
+        # GGUF and fp16 files.
         if has_gguf and specific_file is None:
+            has_q4km = False
             for i in range(len(classifications) - 1, -1, -1):
-                if 'q4_k_m' not in links[i].lower():
-                    links.pop(i)
+                if 'q4_k_m' in links[i].lower():
+                    has_q4km = True
+
+            if has_q4km:
+                for i in range(len(classifications) - 1, -1, -1):
+                    if 'q4_k_m' not in links[i].lower():
+                        links.pop(i)
+            else:
+                for i in range(len(classifications) - 1, -1, -1):
+                    if links[i].lower().endswith('.gguf'):
+                        links.pop(i)
 
         is_llamacpp = has_gguf and specific_file is not None
         return links, sha256, is_lora, is_llamacpp
diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py
index 389466f..273d533 100644
--- a/extensions/openai/completions.py
+++ b/extensions/openai/completions.py
@@ -236,7 +236,7 @@ def chat_completions_common(body: dict, is_legacy: bool = False, stream=False) -
 
     max_tokens = generate_params['max_new_tokens']
     if max_tokens in [None, 0]:
-        generate_params['max_new_tokens'] = 200
+        generate_params['max_new_tokens'] = 512
         generate_params['auto_max_new_tokens'] = True
 
     requested_model = generate_params.pop('model')
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 5a2d40d..695b929 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -10,7 +10,7 @@ class GenerationOptions(BaseModel):
     min_p: float = 0
     top_k: int = 0
    repetition_penalty: float = 1
-    repetition_penalty_range: int = 0
+    repetition_penalty_range: int = 1024
     typical_p: float = 1
     tfs: float = 1
     top_a: float = 0
diff --git a/modules/exllama.py b/modules/exllama.py
index 4257ee0..25c4c99 100644
--- a/modules/exllama.py
+++ b/modules/exllama.py
@@ -165,10 +165,19 @@ class ExllamaModel:
                 if has_leading_space:
                     decoded_text = ' ' + decoded_text
 
-                yield decoded_text
+                # Check the partial unicode character
+                if chr(0xfffd) in decoded_text:
+                    is_last = i == max_new_tokens - 1
+                    is_stopping = token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything
+                    # If we are not at the end of the generation, we skip this token
+                    if not (is_last or is_stopping):
+                        continue
+
                 if token.item() == self.generator.tokenizer.eos_token_id or shared.stop_everything:
                     break
 
+                yield decoded_text
+
         # Case 2: CFG
         # Copied from https://github.com/turboderp/exllama/blob/master/example_cfg.py
         else:
@@ -205,6 +214,14 @@ class ExllamaModel:
                 if has_leading_space:
                     decoded_text = ' ' + decoded_text
 
+                # Check the partial unicode character
+                if chr(0xfffd) in decoded_text:
+                    is_last = i == max_new_tokens - 1
+                    is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
+                    # If we are not at the end of the generation, we skip this token
+                    if not (is_last or is_stopping):
+                        continue
+
                 yield decoded_text
                 if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
                     break
diff --git a/modules/exllamav2.py b/modules/exllamav2.py
index b92e884..d755a36 100644
--- a/modules/exllamav2.py
+++ b/modules/exllamav2.py
@@ -138,11 +138,19 @@ class Exllamav2Model:
             if has_leading_space:
                 decoded_text = ' ' + decoded_text
 
-            yield decoded_text
+            # Check the partial unicode character
+            if chr(0xfffd) in decoded_text:
+                is_last = i == max_new_tokens - 1
+                is_stopping = token.item() == self.tokenizer.eos_token_id or shared.stop_everything
+                # If we are not at the end of the generation, we skip this token
+                if not (is_last or is_stopping):
+                    continue
 
             if token.item() == self.tokenizer.eos_token_id or shared.stop_everything:
                 break
 
+            yield decoded_text
+
     def generate(self, prompt, state):
         output = ''
         for output in self.generate_with_streaming(prompt, state):
diff --git a/modules/loaders.py b/modules/loaders.py
index 062f353..babbe44 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -143,6 +143,11 @@ loaders_and_params = OrderedDict({
         'no_mmap',
         'mlock'
     ],
+    'QuIP#': [
+        'trust_remote_code',
+        'no_use_fast',
+        'no_flash_attn',
+    ]
 })
 
 loaders_samplers = {
@@ -453,6 +458,43 @@
         'skip_special_tokens',
         'auto_max_new_tokens',
     },
+    'QuIP#': {
+        'temperature',
+        'temperature_last',
+        'top_p',
+        'min_p',
+        'top_k',
+        'typical_p',
+        'epsilon_cutoff',
+        'eta_cutoff',
+        'tfs',
+        'top_a',
+        'repetition_penalty',
+        'presence_penalty',
+        'frequency_penalty',
+        'repetition_penalty_range',
+        'encoder_repetition_penalty',
+        'no_repeat_ngram_size',
+        'min_length',
+        'seed',
+        'do_sample',
+        'penalty_alpha',
+        'num_beams',
+        'length_penalty',
+        'early_stopping',
+        'mirostat_mode',
+        'mirostat_tau',
+        'mirostat_eta',
+        'grammar_file_row',
+        'grammar_string',
+        'guidance_scale',
+        'negative_prompt',
+        'ban_eos_token',
+        'custom_token_bans',
+        'add_bos_token',
+        'skip_special_tokens',
+        'auto_max_new_tokens',
+    },
 }
 
 loaders_model_types = {
diff --git a/modules/models.py b/modules/models.py
index c7dd6cc..1df36a6 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -1,4 +1,5 @@
 import gc
+import logging
 import os
 import re
 import time
@@ -23,6 +24,7 @@ import modules.shared as shared
 from modules import RoPE, llama_attn_hijack, sampler_hijack
 from modules.logging_colors import logger
 from modules.models_settings import get_model_metadata
+from modules.relative_imports import RelativeImport
 
 transformers.logging.set_verbosity_error()
 
@@ -69,6 +71,7 @@ def load_model(model_name, loader=None):
         'ExLlamav2_HF': ExLlamav2_HF_loader,
         'ctransformers': ctransformers_loader,
         'AutoAWQ': AutoAWQ_loader,
+        'QuIP#': QuipSharp_loader,
     }
 
     metadata = get_model_metadata(model_name)
@@ -321,6 +324,37 @@
     return model
 
 
+def QuipSharp_loader(model_name):
+    try:
+        with RelativeImport("repositories/quip-sharp"):
+            from lib.utils.unsafe_import import model_from_hf_path
+    except:
+        logger.error(
+            "\nQuIP# has not been found. It must be installed manually for now.\n"
+            "For instructions on how to do that, please consult:\n"
+            "https://github.com/oobabooga/text-generation-webui/pull/4803\n"
+        )
+        return None, None
+
+    # This fixes duplicate logging messages after the import above.
+    handlers = logging.getLogger().handlers
+    if len(handlers) > 1:
+        logging.getLogger().removeHandler(handlers[1])
+
+    model_dir = Path(f'{shared.args.model_dir}/{model_name}')
+    if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
+        logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
+        return None, None
+
+    model, model_str = model_from_hf_path(
+        model_dir,
+        use_cuda_graph=False,
+        use_flash_attn=not shared.args.no_flash_attn
+    )
+
+    return model
+
+
 def GPTQ_loader(model_name):
 
     # Monkey patch
diff --git a/modules/models_settings.py b/modules/models_settings.py
index ebe4fdd..d259a4e 100644
--- a/modules/models_settings.py
+++ b/modules/models_settings.py
@@ -33,14 +33,24 @@ def get_model_metadata(model):
             for k in settings[pat]:
                 model_settings[k] = settings[pat][k]
 
+
+    path = Path(f'{shared.args.model_dir}/{model}/config.json')
+    if path.exists():
+        hf_metadata = json.loads(open(path, 'r').read())
+    else:
+        hf_metadata = None
+
     if 'loader' not in model_settings:
-        loader = infer_loader(model, model_settings)
-        if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
-            loader = 'AutoGPTQ'
+        if hf_metadata is not None and 'quip_params' in hf_metadata:
+            model_settings['loader'] = 'QuIP#'
+        else:
+            loader = infer_loader(model, model_settings)
+            if 'wbits' in model_settings and type(model_settings['wbits']) is int and model_settings['wbits'] > 0:
+                loader = 'AutoGPTQ'
 
-        model_settings['loader'] = loader
+            model_settings['loader'] = loader
 
-    # Read GGUF metadata
+    # GGUF metadata
     if model_settings['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']:
         path = Path(f'{shared.args.model_dir}/{model}')
         if path.is_file():
@@ -57,9 +67,8 @@ def get_model_metadata(model):
             model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
 
     else:
-        # Read transformers metadata
-        path = Path(f'{shared.args.model_dir}/{model}/config.json')
-        if path.exists():
+        # Transformers metadata
+        if hf_metadata is not None:
             metadata = json.loads(open(path, 'r').read())
             if 'max_position_embeddings' in metadata:
                 model_settings['truncation_length'] = metadata['max_position_embeddings']
diff --git a/modules/presets.py b/modules/presets.py
index 842992f..1544362 100644
--- a/modules/presets.py
+++ b/modules/presets.py
@@ -18,7 +18,7 @@ def default_preset():
         'repetition_penalty': 1,
         'presence_penalty': 0,
         'frequency_penalty': 0,
-        'repetition_penalty_range': 0,
+        'repetition_penalty_range': 1024,
         'typical_p': 1,
         'tfs': 1,
         'top_a': 0,
diff --git a/modules/shared.py b/modules/shared.py
index c0899a9..680cd8f 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -36,7 +36,7 @@ settings = {
     'prompt-default': 'QA',
     'prompt-notebook': 'QA',
     'preset': 'simple-1',
-    'max_new_tokens': 200,
+    'max_new_tokens': 512,
     'max_new_tokens_min': 1,
     'max_new_tokens_max': 4096,
     'negative_prompt': '',
@@ -241,6 +241,8 @@ def fix_loader_name(name):
         return 'ctransformers'
     elif name in ['autoawq', 'awq', 'auto-awq']:
         return 'AutoAWQ'
+    elif name in ['quip#', 'quip-sharp', 'quipsharp', 'quip_sharp']:
+        return 'QuIP#'
 
 
 def add_extension(name, last=False):
diff --git a/modules/text_generation.py b/modules/text_generation.py
index ca379fd..417ac19 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -264,14 +264,10 @@ def apply_stopping_strings(reply, all_stop_strings):
 
 
 def get_reply_from_output_ids(output_ids, state, starting_from=0):
-    if shared.is_seq2seq:
-        reply = decode(output_ids, state['skip_special_tokens'])
-    else:
-        reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
-        # Prevent LlamaTokenizer from skipping a space
-        if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > 0:
-            if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
-                reply = ' ' + reply
+    reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
+    if type(shared.tokenizer) in [transformers.LlamaTokenizer, transformers.LlamaTokenizerFast] and len(output_ids) > starting_from:
+        if shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])).startswith('▁'):
+            reply = ' ' + reply
 
     return reply
 
@@ -343,7 +339,8 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
             if cuda:
                 output = output.cuda()
 
-            yield get_reply_from_output_ids(output, state, starting_from=len(input_ids[0]))
+            starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
+            yield get_reply_from_output_ids(output, state, starting_from=starting_from)
 
         # Stream the reply 1 token at a time.
         # This is based on the trick of using 'stopping_criteria' to create an iterator.
@@ -360,12 +357,17 @@
 
         with generate_with_streaming(**generate_params) as generator:
             cumulative_reply = ''
-            starting_from = len(input_ids[0])
+            starting_from = 0 if shared.is_seq2seq else len(input_ids[0])
             for output in generator:
                 if output[-1] in eos_token_ids:
                     break
 
-                cumulative_reply += get_reply_from_output_ids(output, state, starting_from=starting_from)
+                new_content = get_reply_from_output_ids(output, state, starting_from=starting_from)
+                # check the partial unicode character
+                if chr(0xfffd) in new_content:
+                    continue
+
+                cumulative_reply += new_content
                 starting_from = len(output)
                 yield cumulative_reply
 
diff --git a/one_click.py b/one_click.py
index 6febcba..367a6e5 100644
--- a/one_click.py
+++ b/one_click.py
@@ -4,6 +4,7 @@ import hashlib
 import os
 import platform
 import re
+import signal
 import site
 import subprocess
 import sys
@@ -27,6 +28,13 @@ else:
     flags = f"{' '.join([flag for flag in sys.argv[1:] if flag != '--update'])} {CMD_FLAGS}"
 
 
+def signal_handler(sig, frame):
+    sys.exit(0)
+
+
+signal.signal(signal.SIGINT, signal_handler)
+
+
 def is_linux():
     return sys.platform.startswith("linux")
 
@@ -210,7 +218,7 @@ def install_webui():
     elif is_linux() and (choice == "C" or choice == "N"):
         install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu"
     elif choice == "D":
-        install_pytorch = "python -m pip install torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu -f https://developer.intel.com/ipex-whl-stable-xpu"
+        install_pytorch = "python -m pip install torch==2.0.1a0 torchvision==0.15.2a0 intel_extension_for_pytorch==2.0.110+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/"
 
     # Install Git and then Pytorch
     run_cmd(f"{install_git} && {install_pytorch} && python -m pip install py-cpuinfo==9.0.0", assert_success=True, environment=True)
diff --git a/server.py b/server.py
index ae4eceb..0f06f56 100644
--- a/server.py
+++ b/server.py
@@ -21,6 +21,7 @@ matplotlib.use('Agg')  # This fixes LaTeX rendering on some systems
 
 import json
 import os
+import signal
 import sys
 import time
 from functools import partial
@@ -55,6 +56,17 @@ from modules.models_settings import (
 from modules.utils import gradio
 
 
+def signal_handler(sig, frame):
+    logger.info("Received Ctrl+C. Shutting down Text generation web UI gracefully.")
+    if 'interface' in shared.gradio:
+        shared.gradio['interface'].close()
+
+    sys.exit(0)
+
+
+signal.signal(signal.SIGINT, signal_handler)
+
+
 def create_interface():
 
     title = 'Text generation web UI'
diff --git a/settings-template.yaml b/settings-template.yaml
index cb16844..5cd87e0 100644
--- a/settings-template.yaml
+++ b/settings-template.yaml
@@ -6,7 +6,7 @@ chat_style: cai-chat
 prompt-default: QA
 prompt-notebook: QA
 preset: simple-1
-max_new_tokens: 200
+max_new_tokens: 512
 max_new_tokens_min: 1
 max_new_tokens_max: 4096
 seed: -1
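
Below the patch: a condensed, illustrative sketch of the Q4_K_M selection rule introduced in download-model.py above. It assumes a plain list of link strings and a hypothetical filter_gguf_links() helper rather than the real ModelDownloader bookkeeping (has_gguf, classifications), so it sketches the rule, not the implementation.

# Sketch only: with no specific file requested, keep just the Q4_K_M quant when
# one exists; otherwise drop every .gguf link, since the repository most likely
# mixes GGUF and fp16 files.
def filter_gguf_links(links, specific_file=None):
    if specific_file is not None:
        return links

    if any('q4_k_m' in link.lower() for link in links):
        return [link for link in links if 'q4_k_m' in link.lower()]

    return [link for link in links if not link.lower().endswith('.gguf')]


print(filter_gguf_links(['model.Q4_K_M.gguf', 'model.Q8_0.gguf']))
# ['model.Q4_K_M.gguf']
print(filter_gguf_links(['model.safetensors', 'model.Q8_0.gguf', 'config.json']))
# ['model.safetensors', 'config.json']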
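
The same incomplete-character guard recurs in modules/exllama.py, modules/exllamav2.py and modules/text_generation.py above. A minimal standalone sketch of the idea follows, with a hypothetical list of decoded chunks standing in for the real tokenizer loop; it is not code from the patch.

# Sketch only: decoding a partial multi-byte UTF-8 sequence yields U+FFFD, the
# replacement character. Chunks containing it are held back unless generation is
# about to stop, so the UI never flashes a stray "�" mid-stream.
REPLACEMENT_CHAR = chr(0xfffd)


def stream_without_partial_chars(decoded_chunks):
    last_index = len(decoded_chunks) - 1
    for i, decoded_text in enumerate(decoded_chunks):
        # Hold the chunk back if it contains a partial character and more
        # tokens are still coming; the next decode completes the bytes.
        if REPLACEMENT_CHAR in decoded_text and i != last_index:
            continue

        yield decoded_text


# "é" is two bytes in UTF-8, so decoding after the first byte shows "caf�".
chunks = ['caf\ufffd', 'café', 'café!']
print(list(stream_without_partial_chars(chunks)))  # ['café', 'café!']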