File size: 3,396 Bytes
14ebd2e
 
 
1d5ab84
14ebd2e
2734e3f
14ebd2e
 
 
 
2734e3f
 
 
14ebd2e
 
 
3f11b47
14ebd2e
 
 
 
 
 
 
 
 
 
 
 
2734e3f
ece46b2
 
2734e3f
 
 
 
 
 
 
 
 
 
1d5ab84
 
 
8d6a289
1d5ab84
8d6a289
1d5ab84
8d6a289
2734e3f
 
 
 
 
 
 
 
 
 
 
 
21b7439
 
 
 
 
 
2734e3f
 
 
1d5ab84
1fc9b44
 
1d5ab84
2734e3f
c3d2562
21b7439
66fa751
0c46806
 
990bec6
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Global build arguments.
# They are declared before the first FROM, so they are in scope for FROM lines
# only; a stage that needs one of them must redeclare it with a bare `ARG NAME`.

# Components of the nvidia/cuda base image tag.
ARG CUDA_VERSION=11.8.0
ARG CUDNN_VERSION=8
ARG UBUNTU_VERSION=22.04

# Default number of parallel compile jobs.
ARG MAX_JOBS=4

# Stage: base-builder
# CUDA + cuDNN development image with Miniconda, a conda environment for the
# requested Python version, and PyTorch installed from the matching CUDA wheel
# index. The other stages in this file are built FROM this one.
FROM nvidia/cuda:${CUDA_VERSION}-cudnn${CUDNN_VERSION}-devel-ubuntu${UBUNTU_VERSION} AS base-builder

# Put the Miniconda install directory on PATH so `conda` resolves below.
ENV PATH="/root/miniconda3/bin:${PATH}"

# Python version of the conda env, the torch release to install, and the
# PyTorch wheel-index suffix selecting the CUDA build.
ARG PYTHON_VERSION="3.9"
ARG PYTORCH="2.0.0"
ARG CUDA="cu118"

# Persist the Python version into the image environment (ARG alone is
# build-time only).
ENV PYTHON_VERSION=$PYTHON_VERSION

# `apt-get update` and `install` share one layer so a cached, stale package
# index is never reused by a later install. ca-certificates is listed
# explicitly because --no-install-recommends stops it being pulled in as a
# recommended package, and the HTTPS downloads below need it.
RUN apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        git \
        git-lfs \
        libaio-dev \
        ninja-build \
        wget \
    && rm -rf /var/lib/apt/lists/*

# Download, run (batch mode) and delete the Miniconda installer in one layer.
# NOTE(review): the "latest" installer is unpinned, so rebuilds can pick up a
# different conda release — consider pinning a specific installer version.
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
    && mkdir /root/.conda \
    && bash Miniconda3-latest-Linux-x86_64.sh -b \
    && rm -f Miniconda3-latest-Linux-x86_64.sh

# `-y` answers the confirmation prompt explicitly instead of relying on how
# conda treats a closed stdin; the package cache is cleaned in the same layer.
RUN conda create -y -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}" \
    && conda clean --all --yes

# Make the new environment's interpreter and pip the default ones.
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"

WORKDIR /workspace

# Upgrade pip, then install torch at the pinned version from the CUDA-specific
# wheel index. torchvision and torchaudio are left unpinned, so pip chooses
# releases compatible with the pinned torch.
RUN python3 -m pip install --no-cache-dir --upgrade pip \
    && python3 -m pip install --no-cache-dir packaging \
    && python3 -m pip install --no-cache-dir -U torch==${PYTORCH} torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/$CUDA


# Stage: flash-attn-builder
# Clones flash-attention and builds a wheel for the main package and for four
# extension packages under csrc/. Each wheel is written to the `dist/`
# directory next to the setup.py that produced it.
FROM base-builder AS flash-attn-builder

WORKDIR /workspace

# CUDA architectures to compile for; an ARG is exposed to RUN steps of this
# stage as an environment variable.
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

# The MAX_JOBS default at the top of the file is declared before the first
# FROM and is therefore invisible here unless redeclared. A bare ARG picks up
# that default (or a --build-arg override) so the job limit actually reaches
# the extension builds below.
ARG MAX_JOBS

# NOTE(review): the clone tracks the default branch HEAD, so the build is not
# reproducible — consider pinning a tag or commit.
RUN git clone https://github.com/HazyResearch/flash-attention.git \
    && cd flash-attention \
    && python3 setup.py bdist_wheel \
    && cd csrc/fused_dense_lib \
    && python3 setup.py bdist_wheel \
    && cd ../xentropy \
    && python3 setup.py bdist_wheel \
    && cd ../rotary \
    && python3 setup.py bdist_wheel \
    && cd ../layer_norm \
    && python3 setup.py bdist_wheel

# Stage: deepspeed-builder
# Clones DeepSpeed and builds a wheel with its ops pre-built (DS_BUILD_OPS=1)
# and sparse attention disabled (DS_BUILD_SPARSE_ATTN=0). The wheel is written
# to /workspace/DeepSpeed/dist.
FROM base-builder AS deepspeed-builder

# Same architecture list as the flash-attn-builder stage, so both sets of CUDA
# wheels target the same GPUs. An ARG is scoped to the stage that declares it,
# so it has to be declared here as well.
# NOTE(review): assumes DeepSpeed's op build honors TORCH_CUDA_ARCH_LIST like
# other torch CUDA extensions — confirm against the DeepSpeed install docs.
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"

WORKDIR /workspace

# NOTE(review): the clone tracks the default branch HEAD, so the build is not
# reproducible — consider pinning a tag or commit.
RUN git clone https://github.com/microsoft/DeepSpeed.git \
    && cd DeepSpeed \
    && MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 python3 setup.py bdist_wheel

# Stage: final image
# Starts from base-builder, compiles NVIDIA apex from source, installs the
# wheels produced by the deepspeed-builder and flash-attn-builder stages, and
# adds the Hugging Face libraries, awscli and an upgraded pydantic.
FROM base-builder

# Recompile apex with its C++ and CUDA extensions. Uninstall, clone, build and
# removal of the clone happen in one layer so the source/build tree is not
# kept in the image.
# `MAX_JOBS=1` disables parallel building to avoid cpu memory OOM when building
# image on GitHub Action (standard) runners.
RUN python3 -m pip uninstall -y apex \
    && git clone https://github.com/NVIDIA/apex \
    && cd apex \
    && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --global-option="--cuda_ext" --no-cache-dir -v --disable-pip-version-check . \
    && cd .. \
    && rm -rf apex

# Collect the pre-built wheels. The trailing slash marks the destination as a
# directory (required for wildcard sources) and COPY creates it if missing.
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl /workspace/wheels/
COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl /workspace/wheels/
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl /workspace/wheels/
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy_cuda_lib-*.whl /workspace/wheels/
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary_emb-*.whl /workspace/wheels/
COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl /workspace/wheels/

# Install every collected wheel; the wheel files themselves stay in
# /workspace/wheels.
RUN pip3 install --no-cache-dir \
        /workspace/wheels/deepspeed-*.whl \
        /workspace/wheels/flash_attn-*.whl \
        /workspace/wheels/fused_dense_lib-*.whl \
        /workspace/wheels/xentropy_cuda_lib-*.whl \
        /workspace/wheels/rotary_emb-*.whl \
        /workspace/wheels/dropout_layer_norm-*.whl

# Set up the Git LFS filters globally, without touching any repository.
RUN git lfs install --skip-repo

# peft, accelerate and transformers are installed from their `main` branches.
# NOTE(review): these and awscli are unpinned, so rebuilds are not reproducible.
# pydantic is upgraded last because the base image ships with `pydantic==1.8.2`
# which is not working.
RUN pip3 install --no-cache-dir \
        "peft @ git+https://github.com/huggingface/peft.git@main" \
        "accelerate @ git+https://github.com/huggingface/accelerate.git@main" \
        "transformers @ git+https://github.com/huggingface/transformers.git@main" \
    && pip3 install --no-cache-dir awscli \
    && pip3 install -U --no-cache-dir pydantic