Allow full model URL to be used for download (#3919)
--------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
parent
ed6b6411fb
commit
7c9664ed35
2 changed files with 12 additions and 7 deletions
|
@ -22,6 +22,9 @@ from requests.adapters import HTTPAdapter
|
|||
from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
|
||||
base = "https://huggingface.co"
|
||||
|
||||
|
||||
class ModelDownloader:
|
||||
def __init__(self, max_retries=5):
|
||||
self.session = requests.Session()
|
||||
|
@ -37,6 +40,13 @@ class ModelDownloader:
|
|||
if model[-1] == '/':
|
||||
model = model[:-1]
|
||||
|
||||
if model.startswith(base + '/'):
|
||||
model = model[len(base) + 1:]
|
||||
|
||||
model_parts = model.split(":")
|
||||
model = model_parts[0] if len(model_parts) > 0 else model
|
||||
branch = model_parts[1] if len(model_parts) > 1 else branch
|
||||
|
||||
if branch is None:
|
||||
branch = "main"
|
||||
else:
|
||||
|
@ -48,7 +58,6 @@ class ModelDownloader:
|
|||
return model, branch
|
||||
|
||||
def get_download_links_from_huggingface(self, model, branch, text_only=False, specific_file=None):
|
||||
base = "https://huggingface.co"
|
||||
page = f"/api/models/{model}/tree/{branch}"
|
||||
cursor = b""
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue