Better warning messages

This commit is contained in:
oobabooga 2023-05-03 21:43:17 -03:00
parent 0a48b29cd8
commit 95d04d6a8d
13 changed files with 194 additions and 83 deletions

View file

@ -1,4 +1,5 @@
import json
import logging
import math
import sys
import threading
@ -40,7 +41,6 @@ WANT_INTERRUPT = False
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"]
def get_datasets(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
@ -123,7 +123,7 @@ def create_train_interface():
stop_evaluation = gr.Button("Interrupt")
with gr.Column():
evaluation_log = gr.Markdown(value = '')
evaluation_log = gr.Markdown(value='')
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
save_comments = gr.Button('Save comments')
@ -220,13 +220,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if model_type == "PeftModelForCausalLM":
if len(shared.args.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
logging.warning("Training LoRA over top of another LoRA. May have unexpected effects.")
else:
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
logging.warning("Model ID not matched due to LoRA loading. Consider reloading base model.")
else:
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
logging.warning(f"LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
time.sleep(5)
if shared.args.wbits > 0 and not shared.args.monkey_patch:
@ -235,7 +236,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
elif not shared.args.load_in_8bit and shared.args.wbits <= 0:
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
logging.warning("It is highly recommended you use `--load-in-8bit` for LoRA training.")
time.sleep(2) # Give it a moment for the message to show in UI before continuing
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
@ -255,7 +256,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
print("Loading raw text file dataset...")
logging.info("Loading raw text file dataset...")
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
raw_text = file.read()
@ -299,7 +300,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
prompt = generate_prompt(data_point)
return tokenize(prompt)
print("Loading JSON datasets...")
logging.info("Loading JSON datasets...")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt)
@ -311,10 +312,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
print("Getting model ready...")
logging.info("Getting model ready...")
prepare_model_for_int8_training(shared.model)
print("Prepping for training...")
logging.info("Prepping for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
@ -325,10 +326,10 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
)
try:
print("Creating LoRA model...")
logging.info("Creating LoRA model...")
lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
print("Loading existing LoRA data...")
logging.info("Loading existing LoRA data...")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin")
set_peft_model_state_dict(lora_model, state_dict_peft)
except:
@ -406,7 +407,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
json.dump({x: vars[x] for x in PARAMETERS}, file)
# == Main run and monitor loop ==
print("Starting training...")
logging.info("Starting training...")
yield "Starting..."
if WANT_INTERRUPT:
yield "Interrupted before start."
@ -416,7 +417,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
trainer.train()
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
lora_model.save_pretrained(lora_file_path)
print("LoRA training run is completed and saved.")
logging.info("LoRA training run is completed and saved.")
tracked.did_save = True
thread = threading.Thread(target=threaded_run)
@ -448,14 +449,14 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
# Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save:
print("Training complete, saving...")
logging.info("Training complete, saving...")
lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT:
print("Training interrupted.")
logging.info("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{lora_file_path}`"
else:
print("Training complete!")
logging.info("Training complete!")
yield f"Done! LoRA saved to `{lora_file_path}`"