Add an "Evaluate" tab to calculate the perplexities of models (#1322)

This commit is contained in:
oobabooga 2023-04-21 00:20:33 -03:00 committed by GitHub
parent ff0d0ac552
commit c4f4f41389
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 203 additions and 22 deletions

View file

@ -10,9 +10,12 @@ import gradio as gr
import torch
import transformers
from datasets import Dataset, load_dataset
from peft import LoraConfig, get_peft_model, set_peft_model_state_dict, prepare_model_for_int8_training
from peft import (LoraConfig, get_peft_model, prepare_model_for_int8_training,
set_peft_model_state_dict)
from modules import shared, ui
from modules.evaluate import calculate_perplexity, generate_markdown_table, save_past_evaluations
from server import get_available_loras, get_available_models
# This mapping is from a very recent commit, not yet released.
# If not available, default to a backup map for the 3 safe model types.
@ -40,10 +43,6 @@ def get_datasets(path: str, ext: str):
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=str.lower)
def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path(shared.args.lora_dir).glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
with gr.Row():
@ -82,9 +81,9 @@ def create_train_interface():
eval_steps = gr.Number(label='Evaluate every n steps', value=100, info='If an evaluation dataset is given, test it every time this many steps pass.')
with gr.Tab(label='Raw Text File'):
with gr.Tab(label="Raw text file"):
with gr.Row():
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text File', info='The raw text file to use for training.')
raw_text_file = gr.Dropdown(choices=get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': get_datasets('training/datasets', 'txt')}, 'refresh-button')
with gr.Row():
@ -106,11 +105,48 @@ def create_train_interface():
output = gr.Markdown(value="Ready")
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output)
stop_button.click(do_interrupt, None, None, queue=False)
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
with gr.Tab('Perplexity evaluation', elem_id='evaluate-tab'):
with gr.Row():
with gr.Column():
models = gr.Dropdown(get_available_models(), label='Models', multiselect=True)
evaluate_text_file = gr.Dropdown(choices=['wikitext', 'ptb', 'ptb_new'] + get_datasets('training/datasets', 'txt')[1:], value='wikitext', label='Input dataset', info='The raw text file on which the model will be evaluated. The first options are automatically downloaded: wikitext, ptb, and ptb_new. The next options are your local text files under training/datasets.')
with gr.Row():
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
max_length = gr.Slider(label='max_length', minimum=1, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
with gr.Row():
start_current_evaluation = gr.Button("Evaluate loaded model")
start_evaluation = gr.Button("Evaluate selected models")
stop_evaluation = gr.Button("Interrupt")
with gr.Column():
evaluation_log = gr.Markdown(value = '')
evaluation_table = gr.Dataframe(value=generate_markdown_table(), interactive=True)
save_comments = gr.Button('Save comments')
# Training events
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, do_shuffle, higher_rank_limit, warmup_steps, optimizer]
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
start_button.click(do_train, all_params, output)
stop_button.click(do_interrupt, None, None, queue=False)
higher_rank_limit.change(change_rank_limit, [higher_rank_limit], [lora_rank, lora_alpha])
# Evaluation events. For some reason, the interrupt event
# doesn't work with the .then() syntax, so I write them one
# by one in this ugly but functional way.
ev = start_evaluation.click(calculate_perplexity, [models, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
start_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
tmp = gr.State('')
start_current_evaluation.click(lambda: ['current model'], None, tmp)
ev_cur = start_current_evaluation.click(calculate_perplexity, [tmp, evaluate_text_file, stride_length, max_length], evaluation_log, show_progress=False)
start_current_evaluation.click(generate_markdown_table, None, evaluation_table, show_progress=False)
stop_evaluation.click(None, None, None, cancels=[ev, ev_cur], queue=False)
save_comments.click(
save_past_evaluations, evaluation_table, None).then(
lambda: "Comments saved.", None, evaluation_log, show_progress=False)
def do_interrupt():
@ -133,6 +169,7 @@ def do_copy_params(lora_name: str, *args):
result.append(params[key])
else:
result.append(args[i])
return result
@ -155,7 +192,8 @@ def clean_path(base_path: str, path: str):
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, do_shuffle: bool, higher_rank_limit: bool, warmup_steps: int, optimizer: str):
if shared.args.monkey_patch:
from monkeypatch.peft_tuners_lora_monkey_patch import replace_peft_model_with_gptq_lora_model
from monkeypatch.peft_tuners_lora_monkey_patch import \
replace_peft_model_with_gptq_lora_model
replace_peft_model_with_gptq_lora_model()
global WANT_INTERRUPT
@ -300,6 +338,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
if '4bit' in str(type(m)):
if m.is_v1_model:
m.zeros = m.zeros.half()
m.scales = m.scales.half()
class Tracked():