Improve several log messages

This commit is contained in:
oobabooga 2023-12-19 20:54:32 -08:00
parent 23818dc098
commit 9992f7d8c0
7 changed files with 37 additions and 28 deletions

View file

@ -249,7 +249,7 @@ def backup_adapter(input_folder):
adapter_file = Path(f"{input_folder}/adapter_model.bin")
if adapter_file.is_file():
logger.info("Backing up existing LoRA adapter...")
logger.info("Backing up existing LoRA adapter")
creation_date = datetime.fromtimestamp(adapter_file.stat().st_ctime)
creation_date_str = creation_date.strftime("Backup-%Y-%m-%d")
@ -406,7 +406,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# == Prep the dataset, format, etc ==
if raw_text_file not in ['None', '']:
train_template["template_type"] = "raw_text"
logger.info("Loading raw text file dataset...")
logger.info("Loading raw text file dataset")
fullpath = clean_path('training/datasets', f'{raw_text_file}')
fullpath = Path(fullpath)
if fullpath.is_dir():
@ -486,7 +486,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
prompt = generate_prompt(data_point)
return tokenize(prompt, add_eos_token)
logger.info("Loading JSON datasets...")
logger.info("Loading JSON datasets")
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
train_data = data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
@ -516,13 +516,13 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
logger.info("Getting model ready...")
logger.info("Getting model ready")
prepare_model_for_kbit_training(shared.model)
# base model is now frozen and should not be reused for any other LoRA training than this one
shared.model_dirty_from_training = True
logger.info("Preparing for training...")
logger.info("Preparing for training")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
@ -540,10 +540,10 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
model_trainable_params, model_all_params = calc_trainable_parameters(shared.model)
try:
logger.info("Creating LoRA model...")
logger.info("Creating LoRA model")
lora_model = get_peft_model(shared.model, config)
if not always_override and Path(f"{lora_file_path}/adapter_model.bin").is_file():
logger.info("Loading existing LoRA data...")
logger.info("Loading existing LoRA data")
state_dict_peft = torch.load(f"{lora_file_path}/adapter_model.bin", weights_only=True)
set_peft_model_state_dict(lora_model, state_dict_peft)
except:
@ -648,7 +648,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
json.dump(train_template, file, indent=2)
# == Main run and monitor loop ==
logger.info("Starting training...")
logger.info("Starting training")
yield "Starting..."
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
@ -730,7 +730,7 @@ def do_train(lora_name: str, always_override: bool, q_proj_en: bool, v_proj_en:
# Saving in the train thread might fail if an error occurs, so save here if so.
if not tracked.did_save:
logger.info("Training complete, saving...")
logger.info("Training complete, saving")
lora_model.save_pretrained(lora_file_path)
if WANT_INTERRUPT: