From 7a3ca2c68f1ca49ac4e4b62f016718556fd3805c Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sat, 23 Sep 2023 13:04:27 -0700 Subject: [PATCH] Better detect EXL2 models --- modules/models_settings.py | 2 ++ modules/ui_model_menu.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/modules/models_settings.py b/modules/models_settings.py index bc3ace6..537bf0a 100644 --- a/modules/models_settings.py +++ b/modules/models_settings.py @@ -76,6 +76,8 @@ def infer_loader(model_name, model_settings): loader = 'llama.cpp' elif re.match(r'.*rwkv.*\.pth', model_name.lower()): loader = 'RWKV' + elif re.match(r'.*exl2', model_name.lower()): + loader = 'ExLlamav2_HF' else: loader = 'Transformers' diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py index f965d80..78ac545 100644 --- a/modules/ui_model_menu.py +++ b/modules/ui_model_menu.py @@ -251,7 +251,7 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur def update_truncation_length(current_length, state): - if state['loader'] in ['ExLlama', 'ExLlama_HF']: + if state['loader'].lower().startswith('exllama'): return state['max_seq_len'] elif state['loader'] in ['llama.cpp', 'llamacpp_HF', 'ctransformers']: return state['n_ctx']