From 3d854ee5167152cd17f3cc9f5196ae6606abd213 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Thu, 4 Jan 2024 23:50:23 -0300 Subject: [PATCH] Pin PyTorch version to 2.1 (#5056) --- one_click.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/one_click.py b/one_click.py index 75f7247..ccb4cb2 100644 --- a/one_click.py +++ b/one_click.py @@ -89,6 +89,7 @@ def torch_version(): torver = [line for line in torch_version_file if '__version__' in line][0].split('__version__ = ')[1].strip("'") else: from torch import __version__ as torver + return torver @@ -203,7 +204,7 @@ def install_webui(): # Find the proper Pytorch installation command install_git = "conda install -y -k ninja git" - install_pytorch = "python -m pip install torch torchvision torchaudio" + install_pytorch = "python -m pip install torch==2.1.* torchvision==0.16.* torchaudio==2.1.* " use_cuda118 = "N" if any((is_windows(), is_linux())) and selected_gpu == "NVIDIA": @@ -219,20 +220,20 @@ def install_webui(): if use_cuda118 == 'Y': print("CUDA: 11.8") + install_pytorch += "--index-url https://download.pytorch.org/whl/cu118" else: print("CUDA: 12.1") - - install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}" + install_pytorch += "--index-url https://download.pytorch.org/whl/cu121" elif not is_macos() and selected_gpu == "AMD": if is_linux(): - install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6" + install_pytorch += "--index-url https://download.pytorch.org/whl/rocm5.6" else: print("AMD GPUs are only supported on Linux. Exiting...") sys.exit(1) elif is_linux() and selected_gpu in ["APPLE", "NONE"]: - install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu" + install_pytorch += "--index-url https://download.pytorch.org/whl/cpu" elif selected_gpu == "INTEL": - install_pytorch = "python -m pip install torch==2.1.0a0 torchvision==0.16.0a0 intel_extension_for_pytorch==2.1.10+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" + install_pytorch += "intel_extension_for_pytorch==2.1.* --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/" # Install Git and then Pytorch print_big_message("Installing PyTorch.")