Make the code more like PEP8 for readability (#862)

This commit is contained in:
oobabooga 2023-04-07 00:15:45 -03:00 committed by GitHub
parent 848c4edfd5
commit ea6e77df72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 302 additions and 165 deletions

View file

@ -13,10 +13,11 @@ import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog, max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()
def disable_torch_init():
"""
Disable the redundant torch default initialization to accelerate model creation.
@ -31,20 +32,22 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init():
"""Rollback the change made by disable_torch_init."""
import torch
setattr(torch.nn.Linear, "reset_parameters", torch_linear_init_backup)
setattr(torch.nn.LayerNorm, "reset_parameters", torch_layer_norm_init_backup)
if __name__ == '__main__':
path = Path(args.MODEL)
model_name = path.name
print(f"Loading {model_name}...")
#disable_torch_init()
# disable_torch_init()
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
#restore_torch_init()
# restore_torch_init()
tokenizer = AutoTokenizer.from_pretrained(path)