From bd531c2dc2e37d89c92d330cd9593a253ce16638 Mon Sep 17 00:00:00 2001 From: Mylo <36931363+gitmylo@users.noreply.github.com> Date: Thu, 4 May 2023 07:01:28 +0200 Subject: [PATCH] Make --trust-remote-code work for all models (#1772) --- modules/models.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/modules/models.py b/modules/models.py index ff61a4f..8151c5e 100644 --- a/modules/models.py +++ b/modules/models.py @@ -57,7 +57,7 @@ def find_model_type(model_name): elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])): return 'gpt4chan' else: - config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}')) + config = AutoConfig.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), trust_remote_code=shared.args.trust_remote_code) # Not a "catch all", but fairly accurate if config.to_dict().get("is_encoder_decoder", False): return 'HF_seq2seq' @@ -70,15 +70,13 @@ def load_model(model_name): t0 = time.time() shared.model_type = find_model_type(model_name) + trust_remote_code = shared.args.trust_remote_code if shared.model_type == 'chatglm': LoaderClass = AutoModel - trust_remote_code = shared.args.trust_remote_code elif shared.model_type == 'HF_seq2seq': LoaderClass = AutoModelForSeq2SeqLM - trust_remote_code = False else: LoaderClass = AutoModelForCausalLM - trust_remote_code = False # Load the model in simple 16-bit mode by default if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.wbits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.model_type in ['rwkv', 'llamacpp']]):