Skip to content

Commit

Permalink
set flashattention as optional dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
imoneoi committed Aug 1, 2023
1 parent bdbdab1 commit 4f32ca9
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
15 changes: 8 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,20 @@ We will release the evaluation results as soon as they become available, so stay

## <a id="installation"></a> Installation

To use OpenChat, you need to install CUDA and PyTorch, then install FlashAttention 1. After that you can install OpenChat via pip:
To use OpenChat, you need to install CUDA and PyTorch, then you can install OpenChat via pip:

```bash
pip3 install ochat
```

If you want to train models, please also install FlashAttention 1.

```bash
pip3 install packaging ninja
pip3 install --no-build-isolation "flash-attn<2"

pip3 install ochat
```

FlashAttention may have compatibility issues. If you encounter these problems, you can try to create a new `conda` environment following the instructions below.
FlashAttention and vLLM may have compatibility issues. If you encounter these problems, you can try to create a new `conda` environment following the instructions below.

```bash
conda create -y --name openchat
Expand All @@ -146,9 +150,6 @@ pip3 install ochat
git clone https://github.com/imoneoi/openchat
cd openchat

pip3 install packaging ninja
pip3 install --no-build-isolation "flash-attn<2"

pip3 install --upgrade pip # enable PEP 660 support
pip3 install -e .
```
Expand Down
7 changes: 5 additions & 2 deletions ochat/models/unpadded_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
from transformers.utils import logging
from transformers.models.llama.configuration_llama import LlamaConfig

from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from flash_attn.bert_padding import pad_input
try:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func
from flash_attn.bert_padding import pad_input
except ImportError:
print ("FlashAttention not found. Install it if you need to train models.")


logger = logging.get_logger(__name__)
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ dependencies = [
"sentencepiece",
"transformers",
"accelerate",
"flash-attn<2",
"protobuf<3.21",
"fastapi",
"pydantic",
Expand Down

0 comments on commit 4f32ca9

Please sign in to comment.