Lora trainer docs (#1493)

This commit is contained in:
Alex "mcmonkey" Goodwin 2023-04-23 08:54:41 -07:00 committed by GitHub
parent 7ff645899e
commit 459e725af9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 182 additions and 42 deletions

View file

@ -18,14 +18,14 @@ from modules.evaluate import calculate_perplexity, generate_markdown_table, save
from server import get_available_loras, get_available_models
# This mapping is from a very recent commit, not yet released.
# If not available, default to a backup map for the 3 safe model types.
# If not available, default to a backup map for some common model types.
try:
from peft.utils.other import \
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
model_to_lora_modules
except:
standard_modules = ["q_proj", "v_proj"]
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules}
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
WANT_INTERRUPT = False
@ -35,7 +35,8 @@ PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size",
MODEL_CLASSES = {
"LlamaForCausalLM": "llama",
"OPTForCausalLM": "opt",
"GPTJForCausalLM": "gptj"
"GPTJForCausalLM": "gptj",
"GPTNeoXForCausalLM": "gpt_neox"
}
@ -45,6 +46,8 @@ def get_datasets(path: str, ext: str):
def create_train_interface():
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
with gr.Row():
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
@ -215,11 +218,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
else:
model_id = "llama"
if model_type == "PeftModelForCausalLM":
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
if len(shared.args.lora_names) > 0:
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
else:
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
else:
yield "LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})")
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
time.sleep(5)
if shared.args.wbits > 0 and not shared.args.monkey_patch: