Prevent unwanted log messages from modules

This commit is contained in:
oobabooga 2023-05-21 22:42:34 -03:00
parent fb91406e93
commit e116d31180
20 changed files with 120 additions and 111 deletions

View file

@ -1,5 +1,4 @@
import json
import logging
import math
import sys
import threading
@ -15,8 +14,9 @@ from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
set_peft_model_state_dict)
from modules import shared, ui, utils
from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
from modules.evaluate import (calculate_perplexity, generate_markdown_table,
save_past_evaluations)
from modules.logging_colors import logger
# This mapping is from a very recent commit, not yet released.
# If not available, default to a backup map for some common model types.
@ -24,7 +24,8 @@ try:
from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
model_to_lora_modules
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.models.auto.modeling_auto import \
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
except:
standard_modules = ["q_proj", "v_proj"]
@ -217,13 +218,13 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if model_type == "PeftModelForCausalLM":
if len(shared.args.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
else:
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
else:
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
time.sleep(5)
@ -233,7 +234,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
logger.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
time.sleep(2) # Give it a moment for the message to show in UI before continuing
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
@ -253,7 +254,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
logging.info("Loading raw text file dataset...")
logger.info("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read().replace('\r', '')
@ -311,7 +312,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
prompt = generate_prompt(data_point)
return tokenize(prompt)
logging.info("Loading JSON datasets...")
logger.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt)
@ -323,10 +324,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logging.info("Getting model ready...")
logger.info("Getting model ready...")
prepare_model_for_int8_training(shared.model)
logging.info("Prepping for training...")
logger.info("Prepping for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
@ -337,10 +338,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
)
try:
logging.info("Creating LoRA model...")
logger.info("Creating LoRA model...")
lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
logging.info("Loading existing LoRA data...")
logger.info("Loading existing LoRA data...")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
set_peft_model_state_dict(lora_model, state_dict_peft)
except:
@ -418,7 +419,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
json.dump({x: vars[x] for x in PARAMETERS}, file)
# == Main run and monitor loop ==
logging.info("Starting training...")
logger.info("Starting training...")
yield "Starting..."
if WANT_INTERRUPT:
yield "Interrupted before start."
@ -428,7 +429,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
trainer.train()
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
logging.info("LoRA training run is completed and saved.")
logger.info("LoRA training run is completed and saved.")
tracked.did_save = True
thread = threading.Thread(target=threaded_run)
@ -460,14 +461,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save:
logging.info("Training complete, saving...")
logger.info("Training complete, saving...")
lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT:
logging.info("Training interrupted.")
logger.info("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
else:
logging.info("Training complete!")
logger.info("Training complete!")
yield f"Done! LoRA saved to `{lora_file_path}`"