Lora trainer docs (#1493)
This commit is contained in:
parent
7ff645899e
commit
459e725af9
3 changed files with 182 additions and 42 deletions
|
@ -18,14 +18,14 @@ from modules.evaluate import calculate_perplexity, generate_markdown_table, save
|
|||
from server import get_available_loras, get_available_models
|
||||
|
||||
# This mapping is from a very recent commit, not yet released.
|
||||
# If not available, default to a backup map for the 3 safe model types.
|
||||
# If not available, default to a backup map for some common model types.
|
||||
try:
|
||||
from peft.utils.other import \
|
||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
|
||||
model_to_lora_modules
|
||||
except:
|
||||
standard_modules = ["q_proj", "v_proj"]
|
||||
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules}
|
||||
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
|
||||
|
||||
WANT_INTERRUPT = False
|
||||
|
||||
|
@ -35,7 +35,8 @@ PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size",
|
|||
MODEL_CLASSES = {
|
||||
"LlamaForCausalLM": "llama",
|
||||
"OPTForCausalLM": "opt",
|
||||
"GPTJForCausalLM": "gptj"
|
||||
"GPTJForCausalLM": "gptj",
|
||||
"GPTNeoXForCausalLM": "gpt_neox"
|
||||
}
|
||||
|
||||
|
||||
|
@ -45,6 +46,8 @@ def get_datasets(path: str, ext: str):
|
|||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
|
||||
|
||||
with gr.Row():
|
||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
|
||||
|
@ -215,11 +218,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||
else:
|
||||
model_id = "llama"
|
||||
if model_type == "PeftModelForCausalLM":
|
||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||
if len(shared.args.lora_names) > 0:
|
||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||
else:
|
||||
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
|
||||
else:
|
||||
yield "LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})")
|
||||
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
||||
time.sleep(5)
|
||||
|
||||
if shared.args.wbits > 0 and not shared.args.monkey_patch:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue