Prevent unwanted log messages from modules
This commit is contained in:
parent
fb91406e93
commit
e116d31180
20 changed files with 120 additions and 111 deletions
|
@ -1,5 +1,4 @@
|
|||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
@ -10,14 +9,15 @@ import transformers
|
|||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.logging_colors import logger
|
||||
|
||||
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
||||
|
||||
try:
|
||||
import llama_inference_offload
|
||||
except ImportError:
|
||||
logging.error('Failed to load GPTQ-for-LLaMa')
|
||||
logging.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md')
|
||||
logger.error('Failed to load GPTQ-for-LLaMa')
|
||||
logger.error('See https://github.com/oobabooga/text-generation-webui/blob/main/docs/GPTQ-models-(4-bit-mode).md')
|
||||
sys.exit(-1)
|
||||
|
||||
try:
|
||||
|
@ -127,7 +127,7 @@ def find_quantized_model_file(model_name):
|
|||
found = list(path_to_model.glob(f"*{ext}"))
|
||||
if len(found) > 0:
|
||||
if len(found) > 1:
|
||||
logging.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||
logger.warning(f'More than one {ext} model has been found. The last one will be selected. It could be wrong.')
|
||||
|
||||
pt_path = found[-1]
|
||||
break
|
||||
|
@ -138,8 +138,8 @@ def find_quantized_model_file(model_name):
|
|||
# The function that loads the model in modules/models.py
|
||||
def load_quantized(model_name):
|
||||
if shared.args.model_type is None:
|
||||
logging.error("The model could not be loaded because its type could not be inferred from its name.")
|
||||
logging.error("Please specify the type manually using the --model_type argument.")
|
||||
logger.error("The model could not be loaded because its type could not be inferred from its name.")
|
||||
logger.error("Please specify the type manually using the --model_type argument.")
|
||||
return None
|
||||
|
||||
# Select the appropriate load_quant function
|
||||
|
@ -148,21 +148,21 @@ def load_quantized(model_name):
|
|||
load_quant = llama_inference_offload.load_quant
|
||||
elif model_type in ('llama', 'opt', 'gptj'):
|
||||
if shared.args.pre_layer:
|
||||
logging.warning("Ignoring --pre_layer because it only works for llama model type.")
|
||||
logger.warning("Ignoring --pre_layer because it only works for llama model type.")
|
||||
|
||||
load_quant = _load_quant
|
||||
else:
|
||||
logging.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||
logger.error("Unknown pre-quantized model type specified. Only 'llama', 'opt' and 'gptj' are supported")
|
||||
exit()
|
||||
|
||||
# Find the quantized model weights file (.pt/.safetensors)
|
||||
path_to_model = Path(f'{shared.args.model_dir}/{model_name}')
|
||||
pt_path = find_quantized_model_file(model_name)
|
||||
if not pt_path:
|
||||
logging.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||
logger.error("Could not find the quantized model in .pt or .safetensors format, exiting...")
|
||||
exit()
|
||||
else:
|
||||
logging.info(f"Found the following quantized model: {pt_path}")
|
||||
logger.info(f"Found the following quantized model: {pt_path}")
|
||||
|
||||
# qwopqwop200's offload
|
||||
if model_type == 'llama' and shared.args.pre_layer:
|
||||
|
@ -190,7 +190,7 @@ def load_quantized(model_name):
|
|||
max_memory = accelerate.utils.get_balanced_memory(model)
|
||||
|
||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
||||
logging.info("Using the following device map for the quantized model:", device_map)
|
||||
logger.info("Using the following device map for the quantized model:", device_map)
|
||||
# https://huggingface.co/docs/accelerate/package_reference/big_modeling#accelerate.dispatch_model
|
||||
model = accelerate.dispatch_model(model, device_map=device_map, offload_buffers=True)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue