diff --git a/modules/training.py b/modules/training.py
index b9f3d19..c83427d 100644
--- a/modules/training.py
+++ b/modules/training.py
@@ -1,4 +1,4 @@
-import sys, torch, json
+import sys, torch, json, threading, time
 from pathlib import Path
 import gradio as gr
 from datasets import load_dataset
@@ -6,6 +6,9 @@
 import transformers
 from modules import ui, shared
 from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
+CURRENT_STEPS = 0
+MAX_STEPS = 0
+
 def get_json_dataset(path: str):
     def get_set():
         return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path(path).glob('*.json'))), key=str.lower)
@@ -40,6 +43,12 @@ def create_train_interface():
         output = gr.Markdown(value="(...)")
         startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
 
+class Callbacks(transformers.TrainerCallback):
+    def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
+        global CURRENT_STEPS, MAX_STEPS
+        CURRENT_STEPS = state.global_step
+        MAX_STEPS = state.max_steps
+
 def cleanPath(basePath: str, path: str):
     """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
     # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@@ -50,8 +59,11 @@ def cleanPath(basePath: str, path: str):
     return f'{Path(basePath).absolute()}/{path}'
 
 def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, learningRate: float, loraRank: int, loraAlpha: int, loraDropout: float, cutoffLen: int, dataset: str, evalDataset: str, format: str):
+    global CURRENT_STEPS, MAX_STEPS
+    CURRENT_STEPS = 0
+    MAX_STEPS = 0
     yield "Prepping..."
-    # Input validation / processing
+    # == Input validation / processing ==
    # TODO: --lora-dir PR once pulled will need to be applied here
     loraName = f"loras/{cleanPath(None, loraName)}"
     if dataset is None:
@@ -62,7 +74,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     actualLR = float(learningRate)
     shared.tokenizer.pad_token = 0
     shared.tokenizer.padding_side = "left"
-    # Prep the dataset, format, etc
+    # == Prep the dataset, format, etc ==
     with open(cleanPath('training/formats', f'{format}.json'), 'r') as formatFile:
         formatData: dict[str, str] = json.load(formatFile)
     def tokenize(prompt):
@@ -89,7 +101,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     else:
         evalData = load_dataset("json", data_files=cleanPath('training/datasets', f'{evalDataset}.json'))
         evalData = evalData['train'].shuffle().map(generate_and_tokenize_prompt)
-    # Start prepping the model itself
+    # == Start prepping the model itself ==
     if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
         print("Getting model ready...")
         prepare_model_for_int8_training(shared.model)
@@ -128,6 +140,7 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
             ddp_find_unused_parameters=None
         ),
         data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
+        callbacks=list([Callbacks()])
     )
     loraModel.config.use_cache = False
     old_state_dict = loraModel.state_dict
@@ -136,12 +149,31 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
     ).__get__(loraModel, type(loraModel))
     if torch.__version__ >= "2" and sys.platform != "win32":
         loraModel = torch.compile(loraModel)
-    # Actually start and run and save at the end
+    # == Main run and monitor loop ==
     # TODO: save/load checkpoints to resume from?
     print("Starting training...")
-    yield "Running..."
-    trainer.train()
+    yield "Starting..."
+    def threadedRun():
+        trainer.train()
+    thread = threading.Thread(target=threadedRun)
+    thread.start()
+    lastStep = 0
+    startTime = time.perf_counter()
+    while thread.is_alive():
+        time.sleep(0.5)
+        if CURRENT_STEPS != lastStep:
+            lastStep = CURRENT_STEPS
+            timeElapsed = time.perf_counter() - startTime
+            if timeElapsed <= 0:
+                timerInfo = ""
+            else:
+                its = CURRENT_STEPS / timeElapsed
+                if its > 1:
+                    timerInfo = f"`{its:.2f}` it/s"
+                else:
+                    timerInfo = f"`{1.0/its:.2f}` s/it"
+            yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.1f}` seconds"
     print("Training complete, saving...")
     loraModel.save_pretrained(loraName)
     print("Training complete!")
-    yield f"Done! Lora saved to `{loraName}`"
+    yield f"Done! LoRA saved to `{loraName}`"
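For context (not part of the change itself): the progress reporting added above follows a generic pattern. The blocking trainer.train() call runs on a background thread, a transformers.TrainerCallback writes state.global_step / state.max_steps into module-level counters, and the UI-facing generator polls those counters to yield status text. Below is a minimal standalone sketch of that pattern; the fake_train worker, its step count, and the prints are hypothetical stand-ins for the trainer and its callback, not code from this module.

```python
import threading
import time

# Shared progress counters, mirroring CURRENT_STEPS / MAX_STEPS in the diff.
# In the real module they are written by Callbacks.on_step_begin (a
# transformers.TrainerCallback hook); here a fake worker updates them directly.
CURRENT_STEPS = 0
MAX_STEPS = 0

def fake_train(total_steps: int = 20):
    """Stand-in for trainer.train(): sleeps instead of doing real work."""
    global CURRENT_STEPS, MAX_STEPS
    MAX_STEPS = total_steps
    for step in range(1, total_steps + 1):
        time.sleep(0.1)
        CURRENT_STEPS = step

def run_with_progress():
    """Generator shaped like do_train: run the worker on a background
    thread and poll the shared counters, yielding status strings."""
    thread = threading.Thread(target=fake_train)
    thread.start()
    last_step = 0
    start_time = time.perf_counter()
    while thread.is_alive():
        time.sleep(0.5)
        if CURRENT_STEPS != last_step:
            last_step = CURRENT_STEPS
            elapsed = time.perf_counter() - start_time
            its = CURRENT_STEPS / elapsed
            rate = f"{its:.2f} it/s" if its > 1 else f"{1.0 / its:.2f} s/it"
            yield f"Running... {CURRENT_STEPS} / {MAX_STEPS} ... {rate}, {elapsed:.1f} seconds"
    thread.join()
    yield "Done!"

if __name__ == "__main__":
    for status in run_with_progress():
        print(status)
```

Yielding from the polling loop rather than from the training thread is what lets the Gradio output keep streaming status updates while trainer.train() blocks.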