Seq2Seq support (including FLAN-T5) (#1535)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Vincent Brouwers 2023-04-26 03:39:04 +02:00 committed by GitHub
parent 95aa43b9c2
commit 92cdb4f22b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 24 deletions

View file

@@ -11,7 +11,8 @@ import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer)
AutoModelForSeq2SeqLM, AutoTokenizer,
BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared
from modules import llama_attn_hijack
@@ -55,7 +56,12 @@ def find_model_type(model_name):
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
return 'gpt4chan'
else:
return 'HF_generic'
config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}")
# Not a "catch all", but fairly accurate
if config.to_dict().get("is_encoder_decoder", False):
return 'HF_seq2seq'
else:
return 'HF_generic'
def load_model(model_name):
@@ -66,6 +72,9 @@ def load_model(model_name):
if shared.model_type == 'chatglm':
LoaderClass = AutoModel
trust_remote_code = shared.args.trust_remote_code
elif shared.model_type == 'HF_seq2seq':
LoaderClass = AutoModelForSeq2SeqLM
trust_remote_code = False
else:
LoaderClass = AutoModelForCausalLM
trust_remote_code = False