Daniel Liden


Troubleshooting Flash Attention Installation

I have repeatedly run into issues getting flash-attention working correctly with whatever version of PyTorch and CUDA I happen to be working with. I found a working pattern, at least for the platform I tend to be working on (Databricks). This note is a quick summary.

The Problem

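I kept getting an "undefined symbol" error when trying to load a model with flash attention, or even when just importing the flash attention library. This kind of error typically means the compiled flash-attention extension was built against a different version of PyTorch (or CUDA) than the one installed.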
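To reproduce the failure without loading a model, a minimal sketch like the one below can wrap the import and flag the "undefined symbol" case. The function names are hypothetical, and the hint only covers the usual cause (an extension built against a different PyTorch/CUDA than the one installed); it is not an exhaustive diagnosis.

```python
import importlib
from typing import Optional


def diagnose_import_error(message: str) -> Optional[str]:
    """Return a hint for a recognized failure mode, or None if the message isn't recognized."""
    if "undefined symbol" in message:
        return ("compiled extension does not match the installed PyTorch/CUDA; "
                "reinstall flash-attention against the current torch")
    return None


def try_import(module_name: str = "flash_attn") -> bool:
    """Attempt the import; print the error and any hint on failure. Returns True on success."""
    try:
        importlib.import_module(module_name)
        return True
    except ImportError as err:  # also catches ModuleNotFoundError
        print(f"import {module_name} failed: {err}")
        hint = diagnose_import_error(str(err))
        if hint:
            print(f"hint: {hint}")
        return False
```

Running `try_import()` in a notebook cell prints the original error message plus the hint when the symbol mismatch is the problem.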

Solution

The following approach worked.

  1. Verify the CUDA version and install the matching version of PyTorch.
  2. Clone the flash-attention repository and install it from source (rather than just running pip install flash-attn).
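Step 1 can be checked from a notebook cell. The sketch below compares the CUDA toolkit version reported by `nvcc --version` with `torch.version.cuda`, the CUDA version the installed PyTorch was built for. It assumes `nvcc` is on the PATH (not guaranteed on every cluster image) and compares only major versions, since a major-version mismatch is what PyTorch's extension builder rejects outright; minor mismatches generally only produce a warning. The function names are illustrative.

```python
import re
import subprocess
from typing import Optional


def parse_nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract the toolkit version (e.g. "12.1") from `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", nvcc_output)
    return match.group(1) if match else None


def same_major(a: str, b: str) -> bool:
    """True if two version strings such as "12.1" and "12.4" share a major version."""
    return a.split(".")[0] == b.split(".")[0]


def cuda_versions_match() -> bool:
    """Compare the local CUDA toolkit (nvcc) with the CUDA version torch was built for.

    Assumes `nvcc` is on PATH and torch is installed; neither is checked here.
    """
    import torch  # imported lazily so the helpers above work without torch

    out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    toolkit = parse_nvcc_release(out)
    built_for = torch.version.cuda  # None on CPU-only builds of torch
    print(f"nvcc: {toolkit}, torch built for CUDA: {built_for}")
    return toolkit is not None and built_for is not None and same_major(toolkit, built_for)
```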

So in the case of my most recent project:

%pip install --upgrade torch

was fine because the resulting PyTorch build is compiled for CUDA 12.
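When the cluster's CUDA version doesn't match what the default wheel targets, PyTorch also publishes CUDA-specific wheels on its own package index, with URLs of the form `https://download.pytorch.org/whl/cu118`. The hypothetical helper below just builds that pip command from a CUDA version string. Not every CUDA version has a matching index, so it's worth checking the PyTorch install selector before relying on the result.

```python
from typing import List


def torch_index_url(cuda_version: str) -> str:
    """Map a CUDA version like "11.8" (or "12.1.105") to a PyTorch wheel index URL."""
    major, minor = cuda_version.split(".")[:2]
    return f"https://download.pytorch.org/whl/cu{major}{minor}"


def torch_pip_command(cuda_version: str) -> List[str]:
    """Build a pip command that installs torch from the CUDA-specific index."""
    return ["pip", "install", "--upgrade", "torch",
            "--index-url", torch_index_url(cuda_version)]
```

In a Databricks notebook the equivalent magic would be something like `%pip install --upgrade torch --index-url https://download.pytorch.org/whl/cu118`.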

To install flash-attention:

%sh
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install . --no-build-isolation
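Because `--no-build-isolation` compiles against the packages already in the environment instead of a fresh temporary one, the build requirements need to be installed before running the command above. The flash-attention README lists PyTorch, `packaging`, and `ninja`; the sketch below (a hypothetical helper, not part of flash-attention) reports which of them are missing.

```python
import importlib.util
from typing import Iterable, List

# Build-time requirements named in the flash-attention README (assumed current)
BUILD_PREREQS = ("torch", "packaging", "ninja")


def missing_build_prereqs(modules: Iterable[str] = BUILD_PREREQS) -> List[str]:
    """Return the top-level modules that cannot be found in the current environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_build_prereqs()
    print("missing:", missing or "none")
```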

We can then confirm that everything works, without taking the extra time to load a model, like this:

import torch
print(torch.__version__)       # 2.2.2+cu121
print(torch.version.cuda)      # 12.1

import flash_attn
print(flash_attn.__version__)  # 2.5.7
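Importing `flash_attn` loads the compiled extension, which is where the undefined-symbol error shows up, but it doesn't run a kernel. A slightly stronger check, still without loading a model, is a single forward pass through `flash_attn_func` compared against PyTorch's built-in `scaled_dot_product_attention`. This sketch assumes a CUDA GPU that flash-attention supports, fp16 inputs, and an arbitrarily chosen tolerance of 1e-2.

```python
def flash_attn_smoke_test(seqlen: int = 128, nheads: int = 4, headdim: int = 64,
                          atol: float = 1e-2) -> bool:
    """Run one small causal attention pass through flash-attention and compare it with SDPA."""
    # Imported inside the function so this file can be loaded without torch/flash_attn
    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_func

    # flash_attn_func takes (batch, seqlen, nheads, headdim) fp16/bf16 tensors on the GPU
    shape = (1, seqlen, nheads, headdim)
    q, k, v = (torch.randn(shape, device="cuda", dtype=torch.float16) for _ in range(3))
    out = flash_attn_func(q, k, v, causal=True)

    # SDPA expects (batch, nheads, seqlen, headdim), so swap dims 1 and 2 in and out
    ref = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
    ).transpose(1, 2)

    max_err = (out - ref).abs().max().item()
    print(f"max abs difference vs. SDPA: {max_err:.2e}")
    return max_err < atol
```

Calling `flash_attn_smoke_test()` in a notebook cell should print a small difference and return `True` if the kernel runs correctly.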

What didn't work

I wasn't able to get any variant of pip install flash-attn working, whether with or without the --no-build-isolation flag, with pinned versions, and so on.

Date: 2024-04-16 Tue 00:00
