Add LoRA support

This commit is contained in:
oobabooga 2023-03-16 21:31:39 -03:00
parent ee164d1821
commit 104293f411
6 changed files with 51 additions and 8 deletions

View file

@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
classifications = []
has_pytorch = False
has_safetensors = False
is_lora = False
while True:
content = requests.get(f"{base}{page}{cursor.decode()}").content
@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
for i in range(len(dict)):
fname = dict[i]['path']
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True
is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True
classifications.append('pytorch')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
if classifications[i] == 'pytorch':
links.pop(i)
return links
return links, is_lora
if __name__ == '__main__':
model = args.MODEL
@ -159,15 +163,16 @@ if __name__ == '__main__':
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
links, is_lora = get_download_links_from_huggingface(model, branch)
base_folder = 'models' if not is_lora else 'loras'
if branch != 'main':
output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
else:
output_folder = Path("models") / model.split('/')[-1]
output_folder = Path(base_folder) / model.split('/')[-1]
if not output_folder.exists():
output_folder.mkdir()
links = get_download_links_from_huggingface(model, branch)
# Downloading the files
print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=args.threads)