Prevent unwanted log messages from modules
parent fb91406e93
commit e116d31180

20 changed files with 120 additions and 111 deletions
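The change, repeated across the files below, is mechanical: drop the stdlib root-logging calls (`logging.info(...)`, `logging.warning(...)`) and route everything through the single `logger` object exported by `modules.logging_colors`. A dedicated named logger with its own handler and `propagate = False` keeps the project's output under its own formatting and level control, so log records from other imported modules no longer mix in through the root logger. `modules/logging_colors.py` itself is not part of this excerpt; the sketch below is a minimal stand-in, and the logger name, format, and level are assumptions for illustration only.

```python
# Minimal stand-in for a dedicated project logger (illustration only; the
# real modules/logging_colors.py is not shown in this diff).
import logging

logger = logging.getLogger('text-generation-webui')  # assumed name
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(handler)

# Keep records from bubbling up to the root logger, so whatever other
# modules configure there neither restyles nor duplicates these messages.
logger.propagate = False
```

Call sites then switch one-for-one from `logging.warning(...)` to `logger.warning(...)`, as the hunks below show.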
@@ -1,5 +1,4 @@
 import json
-import logging
 import math
 import sys
 import threading
@@ -15,8 +14,9 @@ from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
                   set_peft_model_state_dict)
 
 from modules import shared, ui, utils
-from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
+from modules.evaluate import (calculate_perplexity, generate_markdown_table,
+                              save_past_evaluations)
+from modules.logging_colors import logger
 
 # This mapping is from a very recent commit, not yet released.
 # If not available, default to a backup map for some common model types.
@@ -24,7 +24,8 @@ try:
     from peft.utils.other import \
         TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
         model_to_lora_modules
-    from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
+    from transformers.models.auto.modeling_auto import \
+        MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
     MODEL_CLASSES = {v: k for k, v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES}
 except:
     standard_modules = ["q_proj", "v_proj"]
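The comment lines at the end of the earlier hunk ("This mapping is from a very recent commit, not yet released...") explain the intent of the try/except above: use the per-architecture LoRA target-module map that recent peft releases ship, and fall back to a small hand-written map when it is not available. A self-contained sketch of how such a fallback can be consumed is below; the helper name and the backup map contents are assumptions for the example, since the real except-branch body and the later call site are outside this excerpt.

```python
# Illustration only: prefer peft's per-architecture LoRA target-module map,
# otherwise fall back to the common q_proj/v_proj projection layers.
standard_modules = ["q_proj", "v_proj"]

try:
    from peft.utils.other import \
        TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as model_to_lora_modules
except ImportError:
    # Assumed backup map; the real fallback in training.py is not shown here.
    model_to_lora_modules = {"llama": standard_modules}


def lora_target_modules(model_id):
    """Module names that LoRA adapters should wrap for this architecture."""
    return list(model_to_lora_modules.get(model_id, standard_modules))


print(lora_target_modules("llama"))  # typically ['q_proj', 'v_proj']
```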
@@ -217,13 +218,13 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         if model_type == "PeftModelForCausalLM":
             if len(shared.args.lora_names) > 0:
                 yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-                logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
+                logger.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
             else:
                 yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-                logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
+                logger.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
         else:
             yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-            logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
+            logger.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
 
         time.sleep(5)
 
@@ -233,7 +234,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
         yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
-        logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
+        logger.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
         time.sleep(2)  # Give it a moment for the message to show in UI before continuing
 
     if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
@@ -253,7 +254,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     # == Prep the dataset, format, etc ==
     if raw_text_file not in ['None', '']:
-        logging.info("Loading raw text file dataset...")
+        logger.info("Loading raw text file dataset...")
         with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
             raw_text = file.read().replace('\r', '')
 
@@ -311,7 +312,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
             prompt = generate_prompt(data_point)
             return tokenize(prompt)
 
-        logging.info("Loading JSON datasets...")
+        logger.info("Loading JSON datasets...")
         data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
         train_data = data['train'].map(generate_and_tokenize_prompt)
 
@@ -323,10 +324,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
-        logging.info("Getting model ready...")
+        logger.info("Getting model ready...")
         prepare_model_for_int8_training(shared.model)
 
-    logging.info("Prepping for training...")
+    logger.info("Prepping for training...")
     config = LoraConfig(
         r=lora_rank,
         lora_alpha=lora_alpha,
@@ -337,10 +338,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     )
 
     try:
-        logging.info("Creating LoRA model...")
+        logger.info("Creating LoRA model...")
         lora_model = get_peft_model(shared.model, config)
         if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
-            logging.info("Loading existing LoRA data...")
+            logger.info("Loading existing LoRA data...")
             state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
             set_peft_model_state_dict(lora_model, state_dict_peft)
     except:
@@ -418,7 +419,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         json.dump({x: vars[x] for x in PARAMETERS}, file)
 
     # == Main run and monitor loop ==
-    logging.info("Starting training...")
+    logger.info("Starting training...")
     yield "Starting..."
     if WANT_INTERRUPT:
         yield "Interrupted before start."
@@ -428,7 +429,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         trainer.train()
         # Note: save in the thread in case the gradio thread breaks (eg browser closed)
         lora_model.save_pretrained(lora_file_path)
-        logging.info("LoRA training run is completed and saved.")
+        logger.info("LoRA training run is completed and saved.")
         tracked.did_save = True
 
     thread = threading.Thread(target=threaded_run)
@@ -460,14 +461,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     # Saving in the train thread might fail if an error occurs, so save here if so.
     if not tracked.did_save:
-        logging.info("Training complete, saving...")
+        logger.info("Training complete, saving...")
         lora_model.save_pretrained(lora_file_path)
 
     if WANT_INTERRUPT:
-        logging.info("Training interrupted.")
+        logger.info("Training interrupted.")
         yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
     else:
-        logging.info("Training complete!")
+        logger.info("Training complete!")
         yield f"Done! LoRA saved to `{lora_file_path}`"
 
 