diff --git a/download-model.py b/download-model.py index 880eeb4..cd5a3f3 100644 --- a/download-model.py +++ b/download-model.py @@ -2,7 +2,7 @@ Downloads models from Hugging Face to models/model-name. Example: -python download-model.py facebook/opt-1.3b +python download_model.py facebook/opt-1.3b ''' @@ -19,6 +19,7 @@ import requests import tqdm from tqdm.contrib.concurrent import thread_map + parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str, default=None, nargs='?') parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') @@ -30,40 +31,6 @@ parser.add_argument('--check', action='store_true', help='Validates the checksum args = parser.parse_args() -def get_file(url, output_folder): - filename = Path(url.rsplit('/', 1)[1]) - output_path = output_folder / filename - if output_path.exists() and not args.clean: - # Check if the file has already been downloaded completely - r = requests.get(url, stream=True) - total_size = int(r.headers.get('content-length', 0)) - if output_path.stat().st_size >= total_size: - return - # Otherwise, resume the download from where it left off - headers = {'Range': f'bytes={output_path.stat().st_size}-'} - mode = 'ab' - else: - headers = {} - mode = 'wb' - - r = requests.get(url, stream=True, headers=headers) - with open(output_path, mode) as f: - total_size = int(r.headers.get('content-length', 0)) - block_size = 1024 - with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: - for data in r.iter_content(block_size): - t.update(len(data)) - f.write(data) - - -def sanitize_branch_name(branch_name): - pattern = re.compile(r"^[a-zA-Z0-9._-]+$") - if pattern.match(branch_name): - return branch_name - else: - raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") - - def select_model_from_default_options(): models = { "OPT 6.7B": ("facebook", "opt-6.7b", "main"), @@ -110,7 +77,20 @@ EleutherAI/pythia-1.4b-deduped return model, branch -def get_download_links_from_huggingface(model, branch): +def sanitize_model_and_branch_names(model, branch): + if model[-1] == '/': + model = model[:-1] + if branch is None: + branch = "main" + else: + pattern = re.compile(r"^[a-zA-Z0-9._-]+$") + if not pattern.match(branch): + raise ValueError("Invalid branch name. Only alphanumeric characters, period, underscore and dash are allowed.") + + return model, branch + + +def get_download_links_from_huggingface(model, branch, text_only=False): base = "https://huggingface.co" page = f"/api/models/{model}/tree/{branch}?cursor=" cursor = b"" @@ -149,7 +129,7 @@ def get_download_links_from_huggingface(model, branch): links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") classifications.append('text') continue - if not args.text_only: + if not text_only: links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") if is_safetensors: has_safetensors = True @@ -177,80 +157,114 @@ def get_download_links_from_huggingface(model, branch): return links, sha256, is_lora -def download_files(file_list, output_folder, num_threads=8): - thread_map(lambda url: get_file(url, output_folder), file_list, max_workers=num_threads, disable=True) - - -if __name__ == '__main__': - model = args.MODEL - branch = args.branch - if model is None: - model, branch = select_model_from_default_options() - else: - if model[-1] == '/': - model = model[:-1] - branch = args.branch - if branch is None: - branch = "main" - else: - try: - branch = sanitize_branch_name(branch) - except ValueError as err_branch: - print(f"Error: {err_branch}") - sys.exit() - - links, sha256, is_lora = get_download_links_from_huggingface(model, branch) - - if args.output is not None: - base_folder = args.output - else: +def get_output_folder(model, branch, is_lora, base_folder=None): + if base_folder is None: base_folder = 'models' if not is_lora else 'loras' output_folder = f"{'_'.join(model.split('/')[-2:])}" if branch != 'main': output_folder += f'_{branch}' output_folder = Path(base_folder) / output_folder + return output_folder + + +def get_single_file(url, output_folder, start_from_scratch=False): + filename = Path(url.rsplit('/', 1)[1]) + output_path = output_folder / filename + if output_path.exists() and not start_from_scratch: + # Check if the file has already been downloaded completely + r = requests.get(url, stream=True) + total_size = int(r.headers.get('content-length', 0)) + if output_path.stat().st_size >= total_size: + return + # Otherwise, resume the download from where it left off + headers = {'Range': f'bytes={output_path.stat().st_size}-'} + mode = 'ab' + else: + headers = {} + mode = 'wb' + + r = requests.get(url, stream=True, headers=headers) + with open(output_path, mode) as f: + total_size = int(r.headers.get('content-length', 0)) + block_size = 1024 + with tqdm.tqdm(total=total_size, unit='iB', unit_scale=True, bar_format='{l_bar}{bar}| {n_fmt:6}/{total_fmt:6} {rate_fmt:6}') as t: + for data in r.iter_content(block_size): + t.update(len(data)) + f.write(data) + + +def start_download_threads(file_list, output_folder, start_from_scratch=False, threads=1): + thread_map(lambda url: get_single_file(url, output_folder, start_from_scratch=start_from_scratch), file_list, max_workers=threads, disable=True) + + +def download_model_files(model, branch, links, sha256, output_folder, start_from_scratch=False, threads=1): + # Creating the folder and writing the metadata + if not output_folder.exists(): + output_folder.mkdir() + with open(output_folder / 'huggingface-metadata.txt', 'w') as f: + f.write(f'url: https://huggingface.co/{model}\n') + f.write(f'branch: {branch}\n') + f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n') + sha256_str = '' + for i in range(len(sha256)): + sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n' + if sha256_str != '': + f.write(f'sha256sum:\n{sha256_str}') + + # Downloading the files + print(f"Downloading the model to {output_folder}") + start_download_threads(links, output_folder, start_from_scratch=start_from_scratch, threads=threads) + + +def check_model_files(model, branch, links, sha256, output_folder): + # Validate the checksums + validated = True + for i in range(len(sha256)): + fpath = (output_folder / sha256[i][0]) + + if not fpath.exists(): + print(f"The following file is missing: {fpath}") + validated = False + continue + + with open(output_folder / sha256[i][0], "rb") as f: + bytes = f.read() + file_hash = hashlib.sha256(bytes).hexdigest() + if file_hash != sha256[i][1]: + print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}') + validated = False + else: + print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}') + + if validated: + print('[+] Validated checksums of all model files!') + else: + print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.') + + +if __name__ == '__main__': + branch = args.branch + model = args.MODEL + if model is None: + model, branch = select_model_from_default_options() + + # Cleaning up the model/branch names + try: + model, branch = sanitize_model_and_branch_names(model, branch) + except ValueError as err_branch: + print(f"Error: {err_branch}") + sys.exit() + + # Getting the download links from Hugging Face + links, sha256, is_lora = get_download_links_from_huggingface(model, branch, text_only=args.text_only) + + # Getting the output folder + output_folder = get_output_folder(model, branch, is_lora, base_folder=args.output) if args.check: - # Validate the checksums - validated = True - for i in range(len(sha256)): - fpath = (output_folder / sha256[i][0]) - - if not fpath.exists(): - print(f"The following file is missing: {fpath}") - validated = False - continue - - with open(output_folder / sha256[i][0], "rb") as f: - bytes = f.read() - file_hash = hashlib.sha256(bytes).hexdigest() - if file_hash != sha256[i][1]: - print(f'Checksum failed: {sha256[i][0]} {sha256[i][1]}') - validated = False - else: - print(f'Checksum validated: {sha256[i][0]} {sha256[i][1]}') - - if validated: - print('[+] Validated checksums of all model files!') - else: - print('[-] Invalid checksums. Rerun download-model.py with the --clean flag.') - + # Check previously downloaded files + check_model_files(model, branch, links, sha256, output_folder) else: - - # Creating the folder and writing the metadata - if not output_folder.exists(): - output_folder.mkdir() - with open(output_folder / 'huggingface-metadata.txt', 'w') as f: - f.write(f'url: https://huggingface.co/{model}\n') - f.write(f'branch: {branch}\n') - f.write(f'download date: {str(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}\n') - sha256_str = '' - for i in range(len(sha256)): - sha256_str += f' {sha256[i][1]} {sha256[i][0]}\n' - if sha256_str != '': - f.write(f'sha256sum:\n{sha256_str}') - - # Downloading the files - print(f"Downloading the model to {output_folder}") - download_files(links, output_folder, args.threads) + # Download files + download_model_files(model, branch, links, sha256, output_folder, threads=args.threads)