winglian committed
Commit 8f2b591
Parent(s): 5787e1a

set torch version to what is installed during axolotl install (#1234)

Files changed (1)
  1. setup.py +2 -1
setup.py CHANGED
@@ -27,6 +27,7 @@ def parse_requirements():
 
     try:
         torch_version = version("torch")
+        _install_requires.append(f"torch=={torch_version}")
         if torch_version.startswith("2.1."):
             _install_requires.pop(_install_requires.index("xformers==0.0.22"))
             _install_requires.append("xformers>=0.0.23")
@@ -50,7 +51,7 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.3.3",
+            "flash-attn==2.5.0",
         ],
         "fused-dense-lib": [
             "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",