You are attempting to use Flash Attention 2.0 without specifying a torch dtype.

#23
by chrislevy - opened

I am specifying the torch_dtype.

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
# default processor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

But during inference it logs:

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour

Inference works; I'm just not sure whether everything is actually running properly under the hood.

I have installed flash-attn in my container (a Modal image):

import os

import modal

cuda_version = "12.4.0"  # should be no greater than host CUDA version
flavor = "devel"  #  includes full CUDA toolkit
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"
image = (
    modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.11")
    .apt_install("git")
    .pip_install(
        "ninja",  # required to build flash-attn
        "packaging",  # required to build flash-attn
        "wheel",  # required to build flash-attn
    )
    .run_commands(
        "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124",
        "pip install git+https://github.com/huggingface/transformers",
        "pip install accelerate",
        "pip install qwen-vl-utils",
        "pip install python-dotenv",
        f'huggingface-cli login --token {os.environ["HUGGING_FACE_ACCESS_TOKEN"]}',
    )
    .run_commands("pip install flash-attn --no-build-isolation")
)
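As a quick sanity check inside the running container (a minimal sketch, nothing beyond standard attributes assumed), you can confirm that flash-attn actually built and imports against the installed torch:

import torch
import flash_attn

# If this imports cleanly, the wheel built against a compatible torch/CUDA.
print("torch:", torch.__version__, "CUDA:", torch.version.cuda)
print("flash-attn:", flash_attn.__version__)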

It's probably because you didn't set torch_dtype in the vision config.

See this line here

If you run model.config.vision_config, what is the value for torch_dtype?
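For reference, a quick way to check (a minimal sketch, reusing the model object from the snippet above):

# Inspect the vision sub-config; torch_dtype may be unset (None) here
# even though the weights themselves were loaded in bfloat16.
print(model.config.vision_config)
print(getattr(model.config.vision_config, "torch_dtype", None))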

You can probably fix this by doing

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
    vision_config={"torch_dtype": torch.bfloat16}
)

Thanks for the suggestion! It is still logging that warning to the screen. Not sure if it can be safely ignored or if it really means flash attention isn't being utilized.

[Screenshot of the warning appearing during inference]

The warning can be ignored. FlashAttention-2 only works with bfloat16 or float16, and it would raise an error if you were not using one of those dtypes. All of the weights in the repo are stored in bfloat16, so they stay that way when you call from_pretrained. Still, it's annoying that the warning comes up. There is ongoing work to improve how attn_implementation is handled for composite models like Qwen2-VL, so hopefully this will be addressed.
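For a quick programmatic check (a minimal sketch, using the model you already loaded; note that _attn_implementation is an internal attribute and may change between transformers versions):

# The parameters should report bfloat16, and the resolved attention
# backend should be "flash_attention_2" if it was actually picked up.
print(next(model.parameters()).dtype)
print(model.config._attn_implementation)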

If you want to be absolutely sure, you can pass attn_implementation="eager" and check whether memory consumption is higher and generation is slower.
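A rough way to do that comparison (a sketch only; the inputs variable is assumed to be a batch already prepared with the processor as in the model card example):

import time

import torch
from transformers import Qwen2VLForConditionalGeneration

def benchmark(attn_impl, inputs):
    # Reload the model with the given attention backend and time one generate() call.
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct",
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_impl,
        device_map="auto",
    )
    torch.cuda.reset_peak_memory_stats()
    start = time.time()
    model.generate(**inputs, max_new_tokens=128)
    torch.cuda.synchronize()
    elapsed = time.time() - start
    peak_gb = torch.cuda.max_memory_allocated() / 1e9
    print(f"{attn_impl}: {elapsed:.1f} s, peak memory {peak_gb:.1f} GB")

# benchmark("eager", inputs)
# benchmark("flash_attention_2", inputs)

If flash attention is really active, the second run should show noticeably lower peak memory (and usually faster generation) than eager, especially with long multi-image or video inputs.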

Okay thanks for the quick replies!

chrislevy changed discussion status to closed
