Add flash-attention 2 for Windows (#4235)

This commit is contained in:
Brian Dashore 2023-10-21 02:46:23 -04:00 committed by GitHub
parent 258d046218
commit 3345da2ea4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 139 additions and 110 deletions

View file

@ -171,8 +171,19 @@ def install_webui():
install_git = "conda install -y -k ninja git"
install_pytorch = "python -m pip install torch torchvision torchaudio"
use_cuda118 = "N"
if any((is_windows(), is_linux())) and choice == "A":
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118"
# Choose between the CUDA 11.8 and CUDA 12.1 PyTorch builds.
# use_cuda118 is always normalized to the string "Y" or "N": the wheel index
# URL below and the later cuda-runtime install both compare it against 'N',
# so storing a bool here would silently force the CUDA 11.8 path.
if "USE_CUDA118" in os.environ:
    # Non-interactive override: accept the usual truthy spellings.
    cuda118_requested = os.environ.get("USE_CUDA118", "").strip().lower() in ("yes", "y", "true", "1", "t", "on")
    use_cuda118 = "Y" if cuda118_requested else "N"
else:
    # Ask for CUDA version if using NVIDIA
    print("\nWould you like to use CUDA 11.8 instead of 12.1? This is only necessary for older GPUs like Kepler.\nIf unsure, say \"N\".\n")
    use_cuda118 = input("Input (Y/N)> ").upper().strip('"\'').strip()
    # Membership in a tuple, not a substring test on 'YN': the substring form
    # accepts '' (plain Enter) and 'YN', both of which would select cu118.
    while use_cuda118 not in ('Y', 'N'):
        print("Invalid choice. Please try again.")
        use_cuda118 = input("Input> ").upper().strip('"\'').strip()

install_pytorch = f"python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/{'cu121' if use_cuda118 == 'N' else 'cu118'}"
elif not is_macos() and choice == "B":
if is_linux():
install_pytorch = "python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.6"
@ -189,7 +200,7 @@ def install_webui():
# Install CUDA libraries (this wasn't necessary for Pytorch before...)
if choice == "A":
run_cmd("conda install -y -c \"nvidia/label/cuda-11.8.0\" cuda-runtime", assert_success=True, environment=True)
run_cmd(f"conda install -y -c \"nvidia/label/{'cuda-12.1.0' if use_cuda118 == 'N' else 'cuda-11.8.0'}\" cuda-runtime", assert_success=True, environment=True)
# Install the webui requirements
update_requirements(initial_installation=True)
@ -236,9 +247,11 @@ def update_requirements(initial_installation=False):
elif initial_installation:
print_big_message("Will not install extensions due to INSTALL_EXTENSIONS environment variable.")
# Detect the PyTorch version
# Detect the Python and PyTorch versions
torver = torch_version()
is_cuda = '+cu' in torver # 2.0.1+cu118
print(f"TORCH: {torver}")
is_cuda = '+cu' in torver
is_cuda118 = '+cu118' in torver # 2.1.0+cu118
is_cuda117 = '+cu117' in torver # 2.0.1+cu117
is_rocm = '+rocm' in torver # 2.0.1+rocm5.4.2
is_intel = '+cxx11' in torver # 2.0.1a0+cxx11.abi
@ -269,7 +282,9 @@ def update_requirements(initial_installation=False):
print_big_message(f"Installing webui requirements from file: {requirements_file}")
textgen_requirements = open(requirements_file).read().splitlines()
if is_cuda117:
textgen_requirements = [req.replace('+cu118', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]
textgen_requirements = [req.replace('+cu121', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]
elif is_cuda118:
textgen_requirements = [req.replace('+cu121', '+cu118') for req in textgen_requirements]
with open('temp_requirements.txt', 'w') as file:
file.write('\n'.join(textgen_requirements))