Handle training exception for unsupported models
This commit is contained in:
parent
a6d0373063
commit
58349f44a0
1 changed files with 8 additions and 1 deletions
|
@ -2,6 +2,7 @@ import json
|
|||
import sys
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
|||
bias="none",
|
||||
task_type="CAUSAL_LM"
|
||||
)
|
||||
|
||||
try:
|
||||
lora_model = get_peft_model(shared.model, config)
|
||||
except:
|
||||
yield traceback.format_exc()
|
||||
return
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
model=lora_model,
|
||||
train_dataset=train_data,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue