Improve the imports

This commit is contained in:
oobabooga 2023-02-23 14:41:42 -03:00
parent 364529d0c7
commit 7224343a70
10 changed files with 30 additions and 29 deletions

View file

@ -3,6 +3,7 @@
Converts a transformers model to a format compatible with flexgen.
'''
import argparse
import os
from pathlib import Path
@ -10,9 +11,8 @@ from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('MODEL', type=str, default=None, nargs='?', help="Path to the input model.")
args = parser.parse_args()
@ -31,7 +31,6 @@ def disable_torch_init():
torch_layer_norm_init_backup = torch.nn.LayerNorm.reset_parameters
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
def restore_torch_init():
"""Rollback the change made by disable_torch_init."""
import torch