diff --git a/modules/training.py b/modules/training.py index 1f8e5e5..c98fded 100644 --- a/modules/training.py +++ b/modules/training.py @@ -445,9 +445,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch def generate_prompt(data_point: dict[str, str]): for options, data in format_data.items(): - if set(options.split(',')) == set(x[0] for x in data_point.items() if (x[1] is not None and len(x[1].strip()) > 0)): + if set(options.split(',')) == set(x[0] for x in data_point.items() if (type(x[1]) is str and len(x[1].strip()) > 0)): for key, val in data_point.items(): - if val is not None: + if type(val) is str: data = data.replace(f'%{key}%', val) return data raise RuntimeError(f'Data-point "{data_point}" has no keyset match within format "{list(format_data.keys())}"')