Small style changes

This commit is contained in:
oobabooga 2023-03-27 21:24:39 -03:00
parent c2cad30772
commit 2f0571bfa4
3 changed files with 20 additions and 7 deletions

View file

@ -1,10 +1,17 @@
import sys, torch, json, threading, time
import json
import sys
import threading
import time
from pathlib import Path
import gradio as gr
from datasets import load_dataset
import torch
import transformers
from modules import ui, shared
from peft import prepare_model_for_int8_training, LoraConfig, get_peft_model, get_peft_model_state_dict
from datasets import load_dataset
from peft import (LoraConfig, get_peft_model, get_peft_model_state_dict,
prepare_model_for_int8_training)
from modules import shared, ui
WANT_INTERRUPT = False
CURRENT_STEPS = 0
@ -44,7 +51,7 @@ def create_train_interface():
with gr.Row():
startButton = gr.Button("Start LoRA Training")
stopButton = gr.Button("Interrupt")
output = gr.Markdown(value="(...)")
output = gr.Markdown(value="Ready")
startEvent = startButton.click(do_train, [loraName, microBatchSize, batchSize, epochs, learningRate, loraRank, loraAlpha, loraDropout, cutoffLen, dataset, evalDataset, format], [output])
stopButton.click(doInterrupt, [], [], cancels=[], queue=False)
@ -169,16 +176,20 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
).__get__(loraModel, type(loraModel))
if torch.__version__ >= "2" and sys.platform != "win32":
loraModel = torch.compile(loraModel)
# == Main run and monitor loop ==
# TODO: save/load checkpoints to resume from?
print("Starting training...")
yield "Starting..."
def threadedRun():
trainer.train()
thread = threading.Thread(target=threadedRun)
thread.start()
lastStep = 0
startTime = time.perf_counter()
while thread.is_alive():
time.sleep(0.5)
if WANT_INTERRUPT:
@ -197,8 +208,10 @@ def do_train(loraName: str, microBatchSize: int, batchSize: int, epochs: int, le
timerInfo = f"`{1.0/its:.2f}` s/it"
totalTimeEstimate = (1.0/its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timerInfo}, `{timeElapsed:.0f}`/`{totalTimeEstimate:.0f}` seconds"
print("Training complete, saving...")
loraModel.save_pretrained(loraName)
if WANT_INTERRUPT:
print("Training interrupted.")
yield f"Interrupted. Incomplete LoRA saved to `{loraName}`"