Add an "Evaluate" tab to calculate the perplexities of models (#1322)
This commit is contained in:
parent
ff0d0ac552
commit
c4f4f41389
5 changed files with 203 additions and 22 deletions
|
@ -10,9 +10,12 @@ import gradio as gr
|
|||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, prepare_model_for_int8_training
|
||||
from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
|
||||
set_peft_model_state_dict)
|
||||
|
||||
from modules import shared, ui
|
||||
from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
|
||||
from server import get_available_loras, get_available_models
|
||||
|
||||
# This mapping is from a very recent commit, not yet released.
|
||||
# If not available, default to a backup map for the 3 safe model types.
|
||||
|
@ -40,10 +43,6 @@ def get_datasets(path: str, ext: str):
|
|||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
|
||||
|
||||
|
||||
def get_available_loras():
|
||||
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
with gr.Row():
|
||||
|
@ -82,9 +81,9 @@ def create_train_interface():
|
|||
|
||||
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
|
||||
|
||||
with gr.Tab(label='Raw Text File'):
|
||||
with gr.Tab(label="Raw text file"):
|
||||
with gr.Row():
|
||||
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
|
||||
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
|
||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
|
@ -106,11 +105,48 @@ def create_train_interface():
|
|||
|
||||
output = gr.Markdown(value="Ready")
|
||||
|
||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
|
||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||
start_button.click(do_train, all_params, output)
|
||||
stop_button.click(do_interrupt, None, None, queue=False)
|
||||
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
|
||||
with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
models = gr.Dropdown(get_available_models(), label='Models', multiselect=True)
|
||||
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
|
||||
with gr.Row():
|
||||
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
|
||||
max_length = gr.Slider(label='max_length', minimum=1, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
|
||||
|
||||
with gr.Row():
|
||||
start_current_evaluation = gr.Button("Evaluate loaded model")
|
||||
start_evaluation = gr.Button("Evaluate selected models")
|
||||
stop_evaluation = gr.Button("Interrupt")
|
||||
|
||||
with gr.Column():
|
||||
evaluation_log = gr.Markdown(value = '')
|
||||
|
||||
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
|
||||
save_comments = gr.Button('Save comments')
|
||||
|
||||
# Training events
|
||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
|
||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||
start_button.click(do_train, all_params, output)
|
||||
stop_button.click(do_interrupt, None, None, queue=False)
|
||||
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
|
||||
|
||||
# Evaluation events. For some reason, the interrupt event
|
||||
# doesn't work with the .then() syntax, so I write them one
|
||||
# by one in this ugly but functional way.
|
||||
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
||||
start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
||||
|
||||
tmp = gr.State('')
|
||||
start_current_evaluation.click(lambda: ['current model'], None, tmp)
|
||||
ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
|
||||
start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
|
||||
|
||||
stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
|
||||
save_comments.click(
|
||||
save_past_evaluations, evaluation_table, None).then(
|
||||
lambda: "Comments saved.", None, evaluation_log, show_progress=False)
|
||||
|
||||
|
||||
def do_interrupt():
|
||||
|
@ -133,6 +169,7 @@ def do_copy_params(lora_name: str, *args):
|
|||
result.append(params[key])
|
||||
else:
|
||||
result.append(args[i])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
@ -155,7 +192,8 @@ def clean_path(base_path: str, path: str):
|
|||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool, warmup_steps: int, optimizer: str):
|
||||
|
||||
if shared.args.monkey_patch:
|
||||
from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model
|
||||
from monkeypatch.peft_tuners_lora_monkey_patch import \
|
||||
replace_peft_model_with_gptq_lora_model
|
||||
replace_peft_model_with_gptq_lora_model()
|
||||
|
||||
global WANT_INTERRUPT
|
||||
|
@ -300,6 +338,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||
if '4bit' in str(type(m)):
|
||||
if m.is_v1_model:
|
||||
m.zeros = m.zeros.half()
|
||||
|
||||
m.scales = m.scales.half()
|
||||
|
||||
class Tracked():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue