Make the code more like PEP8 for readability (#862)

This commit is contained in:
oobabooga 2023-04-07 00:15:45 -03:00 committed by GitHub
parent 848c4edfd5
commit ea6e77df72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 302 additions and 165 deletions

View file

@ -19,9 +19,11 @@ CURRENT_STEPS = 0
MAX_STEPS = 0
CURRENT_GRADIENT_ACCUM = 1
def get_dataset(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
lora_name = gr.Textbox(label="Name", info="The name of your new LoRA file")
@ -44,16 +46,16 @@ def create_train_interface():
with gr.Tab(label="Formatted Dataset"):
with gr.Row():
dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Dataset', info='The dataset file to use for training.')
ui.create_refresh_button(dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
ui.create_refresh_button(dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
eval_dataset = gr.Dropdown(choices=get_dataset('training/datasets', 'json'), value='None', label='Evaluation Dataset', info='The (optional) dataset file used to evaluate the model after training.')
ui.create_refresh_button(eval_dataset, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
ui.create_refresh_button(eval_dataset, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'json')}, 'refresh-button')
format = gr.Dropdown(choices=get_dataset('training/formats', 'json'), value='None', label='Data Format', info='The format file used to decide how to format the dataset input.')
ui.create_refresh_button(format, lambda : None, lambda : {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
ui.create_refresh_button(format, lambda: None, lambda: {'choices': get_dataset('training/formats', 'json')}, 'refresh-button')
with gr.Tab(label="Raw Text File"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=get_dataset('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda : None, lambda : {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_dataset('training/datasets', 'txt')}, 'refresh-button')
with gr.Row():
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
newline_favor_len = gr.Slider(label='Prefer Newline Cut Length', minimum=0, maximum=512, value=128, step=16, info='Length (in characters, not tokens) of the maximum distance to shift an overlap cut by to ensure chunks cut at newlines. If too low, cuts may occur in the middle of lines.')
@ -67,10 +69,12 @@ def create_train_interface():
cutoff_len, dataset, eval_dataset, format, raw_text_file, overlap_len, newline_favor_len], [output])
stop_button.click(do_interrupt, [], [], cancels=[], queue=False)
def do_interrupt():
global WANT_INTERRUPT
WANT_INTERRUPT = True
class Callbacks(transformers.TrainerCallback):
def on_step_begin(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS, MAX_STEPS
@ -79,6 +83,7 @@ class Callbacks(transformers.TrainerCallback):
if WANT_INTERRUPT:
control.should_epoch_stop = True
control.should_training_stop = True
def on_substep_end(self, args: transformers.TrainingArguments, state: transformers.TrainerState, control: transformers.TrainerControl, **kwargs):
global CURRENT_STEPS
CURRENT_STEPS += 1
@ -86,6 +91,7 @@ class Callbacks(transformers.TrainerCallback):
control.should_epoch_stop = True
control.should_training_stop = True
def clean_path(base_path: str, path: str):
""""Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
# TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
@ -95,6 +101,7 @@ def clean_path(base_path: str, path: str):
return path
return f'{Path(base_path).absolute()}/{path}'
def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lora_rank: int, lora_alpha: int, lora_dropout: float,
cutoff_len: int, dataset: str, eval_dataset: str, format: str, raw_text_file: str, overlap_len: int, newline_favor_len: int):
global WANT_INTERRUPT, CURRENT_STEPS, MAX_STEPS, CURRENT_GRADIENT_ACCUM
@ -124,7 +131,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
elif not shared.args.load_in_8bit:
yield "It is highly recommended you use `--load-in-8bit` for LoRA training. *(Will continue anyway in 2 seconds, press `Interrupt` to stop.)*"
print("Warning: It is highly recommended you use `--load-in-8bit` for LoRA training.")
time.sleep(2) # Give it a moment for the message to show in UI before continuing
time.sleep(2) # Give it a moment for the message to show in UI before continuing
if cutoff_len <= 0 or micro_batch_size <= 0 or batch_size <= 0 or actual_lr <= 0 or lora_rank <= 0 or lora_alpha <= 0:
yield "Cannot input zeroes."
@ -148,7 +155,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r') as file:
raw_text = file.read()
tokens = shared.tokenizer.encode(raw_text)
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
for i in range(1, len(tokens)):
@ -197,18 +204,18 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
else:
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
eval_data = eval_data['train'].shuffle().map(generate_and_tokenize_prompt)
# == Start prepping the model itself ==
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
print("Getting model ready...")
prepare_model_for_int8_training(shared.model)
print("Prepping for training...")
config = LoraConfig(
r=lora_rank,
lora_alpha=lora_alpha,
# TODO: Should target_modules be configurable?
target_modules=[ "q_proj", "v_proj" ],
target_modules=["q_proj", "v_proj"],
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM"
@ -289,7 +296,7 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
timer_info = f"`{its:.2f}` it/s"
else:
timer_info = f"`{1.0/its:.2f}` s/it"
total_time_estimate = (1.0/its) * (MAX_STEPS)
total_time_estimate = (1.0 / its) * (MAX_STEPS)
yield f"Running... **{CURRENT_STEPS}** / **{MAX_STEPS}** ... {timer_info}, {format_time(time_elapsed)} / {format_time(total_time_estimate)} ... {format_time(total_time_estimate - time_elapsed)} remaining"
print("Training complete, saving...")
@ -302,10 +309,12 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
print("Training complete!")
yield f"Done! LoRA saved to `{lora_name}`"
def split_chunks(arr, step):
for i in range(0, len(arr), step):
yield arr[i:i + step]
def cut_chunk_for_newline(chunk: str, max_length: int):
if '\n' not in chunk:
return chunk
@ -319,6 +328,7 @@ def cut_chunk_for_newline(chunk: str, max_length: int):
chunk = chunk[:last_newline]
return chunk
def format_time(seconds: float):
if seconds < 120:
return f"`{seconds:.0f}` seconds"