Add LoRA support
This commit is contained in:
parent
ee164d1821
commit
104293f411
6 changed files with 51 additions and 8 deletions
|
@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
|
|||
classifications = []
|
||||
has_pytorch = False
|
||||
has_safetensors = False
|
||||
is_lora = False
|
||||
while True:
|
||||
content = requests.get(f"{base}{page}{cursor.decode()}").content
|
||||
|
||||
|
@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
|
|||
|
||||
for i in range(len(dict)):
|
||||
fname = dict[i]['path']
|
||||
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
|
||||
is_lora = True
|
||||
|
||||
is_pytorch = re.match("pytorch_model.*\.bin", fname)
|
||||
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
|
||||
is_safetensors = re.match("model.*\.safetensors", fname)
|
||||
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
||||
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
|
||||
|
@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
|
|||
has_pytorch = True
|
||||
classifications.append('pytorch')
|
||||
|
||||
|
||||
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
|
||||
cursor = base64.b64encode(cursor)
|
||||
cursor = cursor.replace(b'=', b'%3D')
|
||||
|
@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
|
|||
if classifications[i] == 'pytorch':
|
||||
links.pop(i)
|
||||
|
||||
return links
|
||||
return links, is_lora
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = args.MODEL
|
||||
|
@ -159,15 +163,16 @@ if __name__ == '__main__':
|
|||
except ValueError as err_branch:
|
||||
print(f"Error: {err_branch}")
|
||||
sys.exit()
|
||||
|
||||
links, is_lora = get_download_links_from_huggingface(model, branch)
|
||||
base_folder = 'models' if not is_lora else 'loras'
|
||||
if branch != 'main':
|
||||
output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
|
||||
output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
|
||||
else:
|
||||
output_folder = Path("models") / model.split('/')[-1]
|
||||
output_folder = Path(base_folder) / model.split('/')[-1]
|
||||
if not output_folder.exists():
|
||||
output_folder.mkdir()
|
||||
|
||||
links = get_download_links_from_huggingface(model, branch)
|
||||
|
||||
# Downloading the files
|
||||
print(f"Downloading the model to {output_folder}")
|
||||
pool = multiprocessing.Pool(processes=args.threads)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue