Handle training exception for unsupported models
This commit is contained in:
parent
a6d0373063
commit
58349f44a0
1 changed files with 8 additions and 1 deletions
|
@ -2,6 +2,7 @@ import json
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
@ -184,7 +185,13 @@ def do_train(lora_name: str, micro_batch_size: int, batch_size: int, epochs: int
|
||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM"
|
task_type="CAUSAL_LM"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
lora_model = get_peft_model(shared.model, config)
|
lora_model = get_peft_model(shared.model, config)
|
||||||
|
except:
|
||||||
|
yield traceback.format_exc()
|
||||||
|
return
|
||||||
|
|
||||||
trainer = transformers.Trainer(
|
trainer = transformers.Trainer(
|
||||||
model=lora_model,
|
model=lora_model,
|
||||||
train_dataset=train_data,
|
train_dataset=train_data,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue