Don't install flash-attention on Windows + CUDA 11
This commit is contained in:
parent
0ced78fdfa
commit
2d97897a25
1 changed file with 3 additions and 0 deletions
|
@ -289,6 +289,9 @@ def update_requirements(initial_installation=False):
|
||||||
textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]
|
textgen_requirements = [req.replace('+cu121', '+cu117').replace('+cu122', '+cu117').replace('torch2.1', 'torch2.0') for req in textgen_requirements]
|
||||||
elif is_cuda118:
|
elif is_cuda118:
|
||||||
textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]
|
textgen_requirements = [req.replace('+cu121', '+cu118').replace('+cu122', '+cu118') for req in textgen_requirements]
|
||||||
|
if is_windows() and (is_cuda117 or is_cuda118): # No flash-attention on Windows for CUDA 11
|
||||||
|
textgen_requirements = [req for req in textgen_requirements if 'bdashore3/flash-attention' not in req]
|
||||||
|
|
||||||
with open('temp_requirements.txt', 'w') as file:
|
with open('temp_requirements.txt', 'w') as file:
|
||||||
file.write('\n'.join(textgen_requirements))
|
file.write('\n'.join(textgen_requirements))
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue