Add FlexGen support #92 (experimental)

This commit is contained in:
oobabooga 2023-02-21 21:00:06 -03:00
parent e52b697d5a
commit b83f51ee04
2 changed files with 123 additions and 20 deletions

View file

@ -45,6 +45,7 @@ parser.add_argument('--disk', action='store_true', help='If the model is too lar
parser.add_argument('--disk-cache-dir', type=str, help='Directory to save the disk cache to. Defaults to "cache/".')
parser.add_argument('--gpu-memory', type=int, help='Maximum GPU memory in GiB to allocate. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--cpu-memory', type=int, help='Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.')
parser.add_argument('--flexgen', action='store_true', help='Enable the use of FlexGen offloading.')
parser.add_argument('--deepspeed', action='store_true', help='Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration.')
parser.add_argument('--nvme-offload-dir', type=str, help='DeepSpeed: Directory to use for ZeRO-3 NVME offloading.')
parser.add_argument('--local_rank', type=int, default=0, help='DeepSpeed: Optional argument for distributed setups.')
@ -86,6 +87,9 @@ if args.settings is not None and Path(args.settings).exists():
for item in new_settings:
settings[item] = new_settings[item]
if args.flexgen:
from flexgen.flex_opt import (Policy, OptLM, TorchDevice, TorchDisk, TorchMixedDevice, CompressionConfig, Env, Task, get_opt_config)
if args.deepspeed:
import deepspeed
from transformers.deepspeed import HfDeepSpeedConfig, is_deepspeed_zero3_enabled
@ -107,12 +111,39 @@ def load_model(model_name):
t0 = time.time()
# Default settings
if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed):
if not (args.cpu or args.load_in_8bit or args.auto_devices or args.disk or args.gpu_memory is not None or args.cpu_memory is not None or args.deepspeed or args.flexgen):
if any(size in model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), device_map='auto', load_in_8bit=True)
else:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if args.bf16 else torch.float16).cuda()
# FlexGen
elif args.flexgen:
gpu = TorchDevice("cuda:0")
cpu = TorchDevice("cpu")
disk = TorchDisk("cache")
env = Env(gpu=gpu, cpu=cpu, disk=disk, mixed=TorchMixedDevice([gpu, cpu, disk]))
# Offloading policy
policy = Policy(1, 1,
100, 0,
100, 0,
100, 0,
overlap=True, sep_layer=True, pin_weight=True,
cpu_cache_compute=False, attn_sparsity=1.0,
compress_weight=False,
comp_weight_config=CompressionConfig(
num_bits=4, group_size=64,
group_dim=0, symmetric=False),
compress_cache=False,
comp_cache_config=CompressionConfig(
num_bits=4, group_size=64,
group_dim=2, symmetric=False))
opt_config = get_opt_config(f"facebook/{model_name}")
model = OptLM(opt_config, env, "models", policy)
model.init_all_weights()
# DeepSpeed ZeRO-3
elif args.deepspeed:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), torch_dtype=torch.bfloat16 if args.bf16 else torch.float16)
@ -273,7 +304,7 @@ def get_max_prompt_length(tokens):
def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
input_ids = tokenizer.encode(str(prompt), return_tensors='pt', truncation=True, max_length=get_max_prompt_length(tokens_to_generate), add_special_tokens=add_special_tokens)
if args.cpu:
if args.cpu or args.flexgen:
return input_ids
elif args.deepspeed:
return input_ids.to(device=local_rank)
@ -315,7 +346,7 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
print(f"\n\n{question}\n--------------------\n")
input_ids = encode(question, tokens)
cuda = "" if (args.cpu or args.deepspeed) else ".cuda()"
cuda = "" if (args.cpu or args.deepspeed or args.flexgen) else ".cuda()"
n = tokenizer.eos_token_id if eos_token is None else tokenizer.encode(eos_token, return_tensors='pt')[0][-1]
if stopping_string is not None:
# The stopping_criteria code below was copied from
@ -330,22 +361,28 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
else:
stopping_criteria_list = None
generate_params = [
f"eos_token_id={n}",
f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}",
f"temperature={temperature}",
f"top_p={top_p}",
f"typical_p={typical_p}",
f"repetition_penalty={repetition_penalty}",
f"top_k={top_k}",
f"min_length={min_length if args.no_stream else 0}",
f"no_repeat_ngram_size={no_repeat_ngram_size}",
f"num_beams={num_beams}",
f"penalty_alpha={penalty_alpha}",
f"length_penalty={length_penalty}",
f"early_stopping={early_stopping}",
]
if not args.flexgen:
generate_params = [
f"eos_token_id={n}",
f"stopping_criteria=stopping_criteria_list",
f"do_sample={do_sample}",
f"temperature={temperature}",
f"top_p={top_p}",
f"typical_p={typical_p}",
f"repetition_penalty={repetition_penalty}",
f"top_k={top_k}",
f"min_length={min_length if args.no_stream else 0}",
f"no_repeat_ngram_size={no_repeat_ngram_size}",
f"num_beams={num_beams}",
f"penalty_alpha={penalty_alpha}",
f"length_penalty={length_penalty}",
f"early_stopping={early_stopping}",
]
else:
generate_params = [
f"do_sample={do_sample}",
f"temperature={temperature}",
]
if args.deepspeed:
generate_params.append("synced_gpus=True")
@ -391,7 +428,10 @@ def generate_reply(question, tokens, do_sample, max_new_tokens, temperature, top
reply = original_question + apply_extensions(reply[len(question):], "output")
yield formatted_outputs(reply, model_name)
input_ids = torch.reshape(output, (1, output.shape[0]))
if not args.flexgen:
input_ids = torch.reshape(output, (1, output.shape[0]))
else:
input_ids = np.reshape(output, (1, output.shape[0]))
if soft_prompt:
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)