From bb4cb2245373acb950e1c8dbaa73caf75920723d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 24 Mar 2023 00:49:04 -0300 Subject: [PATCH] Download .pt files using download-model.py (for 4-bit models) --- download-model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/download-model.py b/download-model.py index 7c2965f..7ca33b7 100644 --- a/download-model.py +++ b/download-model.py @@ -116,10 +116,11 @@ def get_download_links_from_huggingface(model, branch): is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname) is_safetensors = re.match("model.*\.safetensors", fname) + is_pt = re.match(".*\.pt", fname) is_tokenizer = re.match("tokenizer.*\.model", fname) is_text = re.match(".*\.(txt|json|py)", fname) or is_tokenizer - if any((is_pytorch, is_safetensors, is_text, is_tokenizer)): + if any((is_pytorch, is_safetensors, is_pt, is_tokenizer, is_text)): if is_text: links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") classifications.append('text') @@ -132,7 +133,8 @@ def get_download_links_from_huggingface(model, branch): elif is_pytorch: has_pytorch = True classifications.append('pytorch') - + elif is_pt: + classifications.append('pt') cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50' cursor = base64.b64encode(cursor)