Lora Trainer improvements, part 6 - slightly better raw text inputs (#2108)
commit 50c70e28f0 (parent 511470a89b)
2 changed files with 35 additions and 13 deletions
@@ -36,9 +36,9 @@ except:
     "GPTNeoXForCausalLM": "gpt_neox"
 }
 
-WANT_INTERRUPT = False
-
-PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer"]
+WANT_INTERRUPT = False
+
+PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string"]
 
 
 def create_train_interface():
@@ -85,6 +85,7 @@ def create_train_interface():
             with gr.Row():
                 raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
                 ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
+                hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
 
             with gr.Row():
                 overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
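For reference (not part of the diff): the new Textbox holds the escaped form of the separator, and do_train unescapes it before splitting the raw text, skipping empty parts. A minimal illustration of that behaviour, with made-up sample text:

    # Hypothetical example mirroring the cut_string handling added in do_train below.
    hard_cut_string = '\\n\\n\\n'                        # escaped value as entered in the UI
    cut_string = hard_cut_string.replace('\\n', '\n')    # becomes a literal triple newline
    raw_text = "First story.\n\n\nSecond story.\n\n\n\n\n\nThird story."
    parts = [p for p in raw_text.split(cut_string) if p.strip() != '']
    # parts == ['First story.', 'Second story.', 'Third story.']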
@@ -125,7 +126,7 @@ def create_train_interface():
         save_comments = gr.Button('Save comments')
 
         # Training events
-        all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer]
+        all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string]
         copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
         start_button.click(do_train, all_params, output)
         stop_button.click(do_interrupt, None, None, queue=False)
@@ -178,7 +179,7 @@ def change_rank_limit(use_higher_ranks: bool):
 
 
 def clean_path(base_path: str, path: str):
-    """"Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
+    """Strips unusual symbols and forcibly builds a path as relative to the intended directory."""
     # TODO: Probably could do with a security audit to guarantee there's no ways this can be bypassed to target an unwanted path.
     # Or swap it to a strict whitelist of [a-zA-Z_0-9]
     path = path.replace('\\', '/').replace('..', '_')
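A simplified, standalone sketch of the path sanitisation these lines perform (the symbol-stripping in the middle of clean_path is outside this diff, and the closing return shows up as context in the next hunk); the example input is hypothetical:

    from pathlib import Path

    def clean_path_sketch(base_path: str, path: str) -> str:
        # Normalise separators and neutralise '..' so the result stays under base_path.
        path = path.replace('\\', '/').replace('..', '_')
        return f'{Path(base_path).absolute()}/{path}'

    # e.g. clean_path_sketch('training/datasets', '../secrets.txt')
    # -> '<cwd>/training/datasets/_/secrets.txt'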
@@ -188,7 +189,7 @@ def clean_path(base_path: str, path: str):
     return f'{Path(base_path).absolute()}/{path}'
 
 
-def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str):
+def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str):
 
     if shared.args.monkey_patch:
         from monkeypatch.peft_tuners_lora_monkey_patch import \
@@ -254,16 +255,30 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     if raw_text_file not in ['None', '']:
         logging.info("Loading raw text file dataset...")
         with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
-            raw_text = file.read()
+            raw_text = file.read().replace('\r', '')
+
+        cut_string = hard_cut_string.replace('\\n', '\n')
+        out_tokens = []
+        for text_part in raw_text.split(cut_string):
+            if text_part.strip() == '':
+                continue
+
+            tokens = shared.tokenizer.encode(text_part)
+            step = cutoff_len - overlap_len
+            if step <= 0:
+                yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
+                return
+
+            tokens = list(split_chunks(tokens, step))
+            for i in range(1, len(tokens)):
+                tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
+
+            out_tokens.extend(tokens)
+            del tokens
 
-        tokens = shared.tokenizer.encode(raw_text)
         del raw_text  # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
-        tokens = list(split_chunks(tokens, cutoff_len - overlap_len))
-        for i in range(1, len(tokens)):
-            tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
-
-        text_chunks = [shared.tokenizer.decode(x) for x in tokens]
-        del tokens
+        text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
+        del out_tokens
         if newline_favor_len > 0:
             text_chunks = [cut_chunk_for_newline(x, newline_favor_len) for x in text_chunks]
 
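For context: split_chunks and cut_chunk_for_newline are used above but not defined in this diff. The sketch below assumes split_chunks yields fixed-size, non-overlapping windows (an assumption, not taken from this commit) and uses a toy token list in place of shared.tokenizer.encode, to show how the overlap stitching yields chunks of roughly cutoff_len tokens that share overlap_len tokens with their neighbours. cut_chunk_for_newline (presumably trimming chunk edges toward nearby newlines within newline_favor_len characters) is not sketched here.

    # Minimal self-contained sketch of the overlap chunking performed by the new loop.
    def split_chunks(arr, step):
        # Assumed behaviour: consecutive non-overlapping windows of length `step`.
        for i in range(0, len(arr), step):
            yield arr[i:i + step]

    cutoff_len = 8
    overlap_len = 2
    step = cutoff_len - overlap_len          # 6 new tokens per chunk

    toy_tokens = list(range(20))             # stand-in for tokenizer output
    chunks = list(split_chunks(toy_tokens, step))
    for i in range(1, len(chunks)):
        # Prepend the tail of the previous chunk so adjacent chunks share context.
        chunks[i] = chunks[i - 1][-overlap_len:] + chunks[i]

    print(chunks)
    # [[0, 1, 2, 3, 4, 5],
    #  [4, 5, 6, 7, 8, 9, 10, 11],
    #  [10, 11, 12, 13, 14, 15, 16, 17],
    #  [16, 17, 18, 19]]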