Seq2Seq support (including FLAN-T5) (#1535)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
95aa43b9c2
commit
92cdb4f22b
2 changed files with 30 additions and 24 deletions
|
@@ -11,7 +11,8 @@ import torch
|
|||
import transformers
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
||||
AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
|
||||
AutoModelForSeq2SeqLM, AutoTokenizer,
|
||||
BitsAndBytesConfig, LlamaTokenizer)
|
||||
|
||||
import modules.shared as shared
|
||||
from modules import llama_attn_hijack
|
||||
|
@@ -55,7 +56,12 @@ def find_model_type(model_name):
|
|||
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
return 'HF_generic'
|
||||
config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}")
|
||||
# Not a "catch all", but fairly accurate
|
||||
if config.to_dict().get("is_encoder_decoder", False):
|
||||
return 'HF_seq2seq'
|
||||
else:
|
||||
return 'HF_generic'
|
||||
|
||||
|
||||
def load_model(model_name):
|
||||
|
@@ -66,6 +72,9 @@ def load_model(model_name):
|
|||
if shared.model_type == 'chatglm':
|
||||
LoaderClass = AutoModel
|
||||
trust_remote_code = shared.args.trust_remote_code
|
||||
elif shared.model_type == 'HF_seq2seq':
|
||||
LoaderClass = AutoModelForSeq2SeqLM
|
||||
trust_remote_code = False
|
||||
else:
|
||||
LoaderClass = AutoModelForCausalLM
|
||||
trust_remote_code = False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue