Add QuIP# support (#4803)

It has to be installed manually for now.
This commit is contained in:
oobabooga 2023-12-06 00:01:01 -03:00 committed by GitHub
parent 6430acadde
commit 98361af4d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 8 deletions

View file

@ -1,4 +1,5 @@
import gc
import logging
import os
import re
import time
@ -23,6 +24,7 @@ import modules.shared as shared
from modules import RoPE, llama_attn_hijack, sampler_hijack
from modules.logging_colors import logger
from modules.models_settings import get_model_metadata
from modules.relative_imports import RelativeImport
transformers.logging.set_verbosity_error()
@ -69,6 +71,7 @@ def load_model(model_name, loader=None):
'ExLlamav2_HF': ExLlamav2_HF_loader,
'ctransformers': ctransformers_loader,
'AutoAWQ': AutoAWQ_loader,
'QuIP#': QuipSharp_loader,
}
metadata = get_model_metadata(model_name)
@ -321,6 +324,37 @@ def AutoAWQ_loader(model_name):
return model
def QuipSharp_loader(model_name):
try:
with RelativeImport("repositories/quip-sharp"):
from lib.utils.unsafe_import import model_from_hf_path
except:
logger.error(
"\nQuIP# has not been found. It must be installed manually for now.\n"
"For instructions on how to do that, please consult:\n"
"https://github.com/oobabooga/text-generation-webui/pull/4803\n"
)
return None, None
# This fixes duplicate logging messages after the import above.
handlers = logging.getLogger().handlers
if len(handlers) > 1:
logging.getLogger().removeHandler(handlers[1])
model_dir = Path(f'{shared.args.model_dir}/{model_name}')
if not all((model_dir / file).exists() for file in ['tokenizer_config.json', 'special_tokens_map.json', 'tokenizer.model']):
logger.error(f"Could not load the model because the tokenizer files could not be found in the model folder. Please download the following files from the original (unquantized) model into {model_dir}: special_tokens_map.json, tokenizer.json, tokenizer.model, tokenizer_config.json.")
return None, None
model, model_str = model_from_hf_path(
model_dir,
use_cuda_graph=False,
use_flash_attn=not shared.args.no_flash_attn
)
return model
def GPTQ_loader(model_name):
# Monkey patch