Add flash-attention 2 for windows (#4235)
This commit is contained in:
parent
258d046218
commit
3345da2ea4
10 changed files with 139 additions and 110 deletions
25
one_click.py
25
one_click.py
|
@ -171,8 +171,19 @@ def install_webui():
|
|||
install_git = "conda install -y -k ninja git"
|
||||
install_pytorch = "python -m pip install torch torchvision torchaudio"
|
||||
|
||||
use_cuda118 = "N"
|
||||
if any((is_windows(), is_linux())) and choice == "A":
|
||||
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"
|
||||
if "USE_CUDA118" in os.environ:
|
||||
use_cuda118 = os.environ.get("USE_CUDA118", "").lower() in ("yes", "y", "trye", "1", "t", "on")
|
||||
else:
|
||||
# Ask for CUDA version if using NVIDIA
|
||||
print("\nWould you like to use CUDA 11.8 instead of 12.1? This is only necessary for older GPUs like Kepler.\nIf unsure, say \"N\".\n")
|
||||
use_cuda118 = input("Input (Y/N)> ").upper().strip('"\'').strip()
|
||||
while use_cuda118 not in 'YN':
|
||||
print("Invalid choice. Please try again.")
|
||||
use_cuda118 = input("Input> ").upper().strip('"\'').strip()
|
||||
|
||||
install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}"
|
||||
elif not is_macos() and choice == "B":
|
||||
if is_linux():
|
||||
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6"
|
||||
|
@ -189,7 +200,7 @@ def install_webui():
|
|||
|
||||
# Install CUDA libraries (this wasn't necessary for Pytorch before...)
|
||||
if choice == "A":
|
||||
run_cmd("conda install -y -c \"nvidia/label/cuda-11.8.0\" cuda-runtime", assert_success=True, environment=True)
|
||||
run_cmd(f"conda install -y -c \"nvidia/label/{'cuda-12.1.0' if use_cuda118 == 'N' else 'cuda-11.8.0'}\" cuda-runtime", assert_success=True, environment=True)
|
||||
|
||||
# Install the webui requirements
|
||||
update_requirements(initial_installation=True)
|
||||
|
@ -236,9 +247,11 @@ def update_requirements(initial_installation=False):
|
|||
elif initial_installation:
|
||||
print_big_message("Will not install extensions due to INSTALL_EXTENSIONS environment variable.")
|
||||
|
||||
# Detect the PyTorch version
|
||||
# Detect the Python and PyTorch versions
|
||||
torver = torch_version()
|
||||
is_cuda = '+cu' in torver # 2.0.1+cu118
|
||||
print(f"TORCH: {torver}")
|
||||
is_cuda = '+cu' in torver
|
||||
is_cuda118 = '+cu118' in torver # 2.1.0+cu118
|
||||
is_cuda117 = '+cu117' in torver # 2.0.1+cu117
|
||||
is_rocm = '+rocm' in torver # 2.0.1+rocm5.4.2
|
||||
is_intel = '+cxx11' in torver # 2.0.1a0+cxx11.abi
|
||||
|
@ -269,7 +282,9 @@ def update_requirements(initial_installation=False):
|
|||
print_big_message(f"Installing webui requirements from file: {requirements_file}")
|
||||
textgen_requirements = open(requirements_file).read().splitlines()
|
||||
if is_cuda117:
|
||||
textgen_requirements = [req.replace('+cu118', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]
|
||||
textgen_requirements = [req.replace('+cu121', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]
|
||||
elif is_cuda118:
|
||||
textgen_requirements = [req.replace('+cu121', '+cu118') for req in textgen_requirements]
|
||||
with open('temp_requirements.txt', 'w') as file:
|
||||
file.write('\n'.join(textgen_requirements))
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue