File size: 1,738 Bytes
8b617cc
 
37293dc
77fca25
3355706
 
 
 
 
34c0a86
3355706
 
 
 
 
c25ba79
 
 
 
 
 
3355706
 
8d288a2
 
 
 
 
7f2027d
8d288a2
 
3355706
 
 
 
 
77fca25
 
2bc1a5b
772cd87
6c5fbe6
 
2bc1a5b
77fca25
 
3355706
77fca25
cf66547
a045db0
cf66547
c25ba79
2bc1a5b
 
77fca25
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""setup.py for axolotl"""

from setuptools import find_packages, setup


def parse_requirements():
    """Split ``./requirements.txt`` into install requirements and index URLs.

    Each line is stripped and then classified:

    * Lines starting with ``--extra-index-url`` contribute their URL to the
      dependency links. Both ``--extra-index-url URL`` and
      ``--extra-index-url=URL`` are accepted; an option without a URL is
      ignored.
    * Blank lines, lines starting with ``#`` and lines mentioning
      ``flash-attn`` or ``deepspeed`` are skipped.
    * Every other line is kept verbatim as an install requirement.

    When ``torch==2.1.0`` is required, the ``xformers>=0.0.22`` pin (if
    present) is replaced by a direct reference to the xformers ``main``
    branch.

    Returns:
        tuple[list[str], list[str]]: ``(install_requires, dependency_links)``.

    Raises:
        OSError: If ``./requirements.txt`` (resolved against the current
            working directory) cannot be opened.
    """
    index_option = "--extra-index-url"
    xformers_pin = "xformers>=0.0.22"

    _install_requires = []
    _dependency_links = []
    with open("./requirements.txt", encoding="utf-8") as requirements_file:
        for raw_line in requirements_file:
            line = raw_line.strip()
            if line.startswith(index_option):
                # Handle custom index URLs. The value may be separated from
                # the option by whitespace or "="; take the first token so a
                # malformed line cannot abort the whole install.
                tokens = line[len(index_option) :].lstrip("= \t").split()
                if tokens:
                    _dependency_links.append(tokens[0])
            elif (
                "flash-attn" not in line
                and "deepspeed" not in line
                and line
                and line[0] != "#"
            ):
                # Handle standard packages
                _install_requires.append(line)

    # TODO(wing) remove once xformers release supports torch 2.1.0
    # Only swap the pin when it is actually listed; a missing or edited pin
    # must not raise ValueError during installation.
    if "torch==2.1.0" in _install_requires and xformers_pin in _install_requires:
        _install_requires.remove(xformers_pin)
        _install_requires.append(
            "xformers @ git+https://github.com/facebookresearch/xformers.git@main"
        )

    return _install_requires, _dependency_links


# Read requirements.txt once, at import time, so setup() receives plain lists.
install_requires, dependency_links = parse_requirements()


setup(
    name="axolotl",
    version="0.3.0",
    description="LLM Trainer",
    long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
    # The root package namespace is mapped to src/, so package discovery has
    # to start there too; a bare find_packages() scans the project root.
    package_dir={"": "src"},
    packages=find_packages(where="src"),
    install_requires=install_requires,
    # NOTE(review): newer pip releases reportedly ignore dependency_links;
    # confirm whether users must pass --extra-index-url themselves.
    dependency_links=dependency_links,
    # Optional accelerators that parse_requirements() filters out of the
    # default requirement list; install with e.g. `axolotl[flash-attn]`.
    extras_require={
        "flash-attn": [
            "flash-attn>=2.3.0",
        ],
        "deepspeed": [
            "deepspeed",
        ],
    },
)