Skip to content

Commit

Permalink
add flash attention
Browse files Browse the repository at this point in the history
  • Loading branch information
gayanechilingar committed Jan 26, 2024
1 parent d72434a commit 7fe60e1
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 82 deletions.
187 changes: 111 additions & 76 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -1,140 +1,175 @@
name: cl11.7
name: cl11.8_t4.37
channels:
- pytorch
- nvidia
- nvidia/label/cuda-11.8.0
- defaults
- conda-forge
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.08.22=h06a4308_0
- cuda-cudart=11.7.99=0
- cuda-cupti=11.7.101=0
- cuda-libraries=11.7.1=0
- cuda-nvrtc=11.7.99=0
- cuda-nvtx=11.7.91=0
- cuda-runtime=11.7.1=0
- ca-certificates=2023.12.12=h06a4308_0
- cuda=11.8.0=0
- cuda-cccl=11.8.89=0
- cuda-command-line-tools=11.8.0=0
- cuda-compiler=11.8.0=0
- cuda-cudart=11.8.89=0
- cuda-cudart-dev=11.8.89=0
- cuda-cuobjdump=11.8.86=0
- cuda-cupti=11.8.87=0
- cuda-cuxxfilt=11.8.86=0
- cuda-demo-suite=11.8.86=0
- cuda-documentation=11.8.86=0
- cuda-driver-dev=11.8.89=0
- cuda-gdb=11.8.86=0
- cuda-libraries=11.8.0=0
- cuda-libraries-dev=11.8.0=0
- cuda-memcheck=11.8.86=0
- cuda-nsight=11.8.86=0
- cuda-nsight-compute=11.8.0=0
- cuda-nvcc=11.8.89=0
- cuda-nvdisasm=11.8.86=0
- cuda-nvml-dev=11.8.86=0
- cuda-nvprof=11.8.87=0
- cuda-nvprune=11.8.86=0
- cuda-nvrtc=11.8.89=0
- cuda-nvrtc-dev=11.8.89=0
- cuda-nvtx=11.8.86=0
- cuda-nvvp=11.8.87=0
- cuda-profiler-api=11.8.86=0
- cuda-runtime=11.8.0=0
- cuda-sanitizer-api=11.8.86=0
- cuda-toolkit=11.8.0=0
- cuda-tools=11.8.0=0
- cuda-visual-tools=11.8.0=0
- filelock=3.13.1=py310h06a4308_0
- gds-tools=1.4.0.31=0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py310heeb90bb_0
- intel-openmp=2023.1.0=hdb19cb5_46305
- intel-openmp=2023.1.0=hdb19cb5_46306
- jinja2=3.1.2=py310h06a4308_0
- ld_impl_linux-64=2.38=h1181459_1
- libcublas=11.10.3.66=0
- libcufft=10.7.2.124=h4fbf590_0
- libcufile=1.7.2.10=0
- libcurand=10.3.3.141=0
- libcusolver=11.4.0.1=0
- libcusparse=11.7.4.91=0
- libcublas=11.11.3.6=0
- libcublas-dev=11.11.3.6=0
- libcufft=10.9.0.58=0
- libcufft-dev=10.9.0.58=0
- libcufile=1.8.1.2=0
- libcufile-dev=1.4.0.31=0
- libcurand=10.3.4.107=0
- libcurand-dev=10.3.0.86=0
- libcusolver=11.4.1.48=0
- libcusolver-dev=11.4.1.48=0
- libcusparse=11.7.5.86=0
- libcusparse-dev=11.7.5.86=0
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libnpp=11.7.4.75=0
- libnvjpeg=11.8.0.2=0
- libnpp=11.8.0.86=0
- libnpp-dev=11.8.0.86=0
- libnvjpeg=11.9.0.86=0
- libnvjpeg-dev=11.9.0.86=0
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- markupsafe=2.1.1=py310h7f8727e_0
- mkl=2023.1.0=h213fc3f_46343
- llvm-openmp=14.0.6=h9e868ea_0
- markupsafe=2.1.3=py310h5eee18b_0
- mkl=2023.1.0=h213fc3f_46344
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py310h06a4308_0
- ncurses=6.4=h6a678d5_0
- networkx=3.1=py310h06a4308_0
- openssl=3.0.11=h7f8727e_2
- pip=23.2.1=py310h06a4308_0
- nsight-compute=2022.3.0.22=0
- openssl=3.0.12=h7f8727e_0
- pip=23.3.1=py310h06a4308_0
- python=3.10.13=h955ad1f_0
- pytorch=2.0.1=py3.10_cuda11.7_cudnn8.5.0_0
- pytorch-cuda=11.7=h778d358_5
- pytorch=2.1.0=py3.10_cuda11.8_cudnn8.7.0_0
- pytorch-cuda=11.8=h7e8668a_5
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.1=py310h5eee18b_0
- readline=8.2=h5eee18b_0
- setuptools=68.0.0=py310h06a4308_0
- setuptools=68.2.2=py310h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- sympy=1.11.1=py310h06a4308_0
- sympy=1.12=py310h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- tk=8.6.12=h1ccaba5_0
- torchtriton=2.0.0=py310
- typing_extensions=4.7.1=py310h06a4308_0
- torchtriton=2.1.0=py310
- typing_extensions=4.9.0=py310h06a4308_0
- wheel=0.41.2=py310h06a4308_0
- xz=5.4.2=h5eee18b_0
- xz=5.4.5=h5eee18b_0
- yaml=0.2.5=h7b6447c_0
- zlib=1.2.13=h5eee18b_0
- pip:
- accelerate==0.21.0
- accelerate==0.26.1
- aim==3.17.5
- aim-ui==3.17.5
- aimrecords==0.0.7
- aimrocks==0.4.0
- aiofiles==23.2.1
- aiohttp==3.8.6
- aiohttp==3.9.1
- aiosignal==1.3.1
- alembic==1.12.0
- alembic==1.13.1
- annotated-types==0.6.0
- anyio==3.7.1
- anyio==4.2.0
- async-timeout==4.0.3
- attrs==23.1.0
- attrs==23.2.0
- backoff==2.2.1
- base58==2.0.1
- cachetools==5.3.1
- certifi==2023.7.22
- cachetools==5.3.2
- certifi==2023.11.17
- cffi==1.16.0
- cfgv==3.4.0
- charset-normalizer==3.3.0
- charset-normalizer==3.3.2
- click==8.1.7
- cryptography==41.0.4
- datasets==2.14.4
- cryptography==41.0.7
- datasets==2.14.7
- dill==0.3.7
- distlib==0.3.7
- einops==0.7.0
- exceptiongroup==1.1.3
- fastapi==0.103.2
- filelock==3.13.1
- flash-attn==2.3.2
- frozenlist==1.4.0
- fsspec==2023.9.2
- greenlet==3.0.0
- grpcio==1.59.0
- exceptiongroup==1.2.0
- fastapi==0.109.0
- flash-attn==2.4.2
- frozenlist==1.4.1
- fsspec==2023.10.0
- greenlet==3.0.3
- grpcio==1.60.0
- h11==0.14.0
- huggingface-hub==0.17.3
- identify==2.5.31
- idna==3.4
- mako==1.2.4
- idna==3.6
- mako==1.3.0
- monotonic==1.6
- multidict==6.0.4
- multiprocess==0.70.15
- ninja==1.11.1.1
- nodeenv==1.8.0
- numpy==1.26.0
- numpy==1.26.3
- packaging==23.2
- pandas==2.1.1
- pillow==10.0.1
- platformdirs==3.11.0
- pre-commit==3.5.0
- protobuf==4.24.4
- psutil==5.9.5
- pandas==2.1.4
- pillow==10.2.0
- protobuf==4.25.2
- psutil==5.9.7
- py3nvml==0.2.7
- pyarrow==13.0.0
- pyarrow==14.0.2
- pyarrow-hotfix==0.6
- pycparser==2.21
- pydantic==2.4.2
- pydantic-core==2.10.1
- pydantic==2.5.3
- pydantic-core==2.14.6
- python-dateutil==2.8.2
- pytz==2023.3.post1
- pyyaml==6.0.1
- regex==2023.10.3
- regex==2023.12.25
- requests==2.31.0
- restrictedpython==6.2
- safetensors==0.4.0
- restrictedpython==7.0
- safetensors==0.4.1
- segment-analytics-python==2.2.3
- six==1.16.0
- sniffio==1.3.0
- sqlalchemy==1.4.49
- starlette==0.27.0
- tokenizers==0.13.3
- sqlalchemy==1.4.51
- starlette==0.35.1
- tokenizers==0.14.1
- tqdm==4.66.1
- transformers==4.31.0
- tzdata==2023.3
- urllib3==2.0.6
- uvicorn==0.23.2
- virtualenv==20.24.6
- transformers==4.35.0
- tzdata==2023.4
- urllib3==2.1.0
- uvicorn==0.25.0
- xmltodict==0.13.0
- xxhash==3.4.1
- yarl==1.9.2
prefix: /home/philipp/miniconda3/envs/cl11.7
- yarl==1.9.4
prefix: /auto/home/gayane/miniforge3/envs/cl11.8_t4.37
3 changes: 1 addition & 2 deletions src/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_backward_prefetch_policy: BACKWARD_PRE
fsdp_backward_prefetch: BACKWARD_PRE
fsdp_offload_params: false
fsdp_sharding_strategy: 1
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_transformer_layer_cls_to_wrap: CustomOPTForCausalLM
machine_rank: 0
main_training_function: main
mixed_precision: bf16
Expand Down
7 changes: 3 additions & 4 deletions src/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import bitsandbytes as bnb
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
import torch
from custom_modeling_opt import CustomOPTForCausalLM
from transformers import BitsAndBytesConfig
from custom_modeling_opt import OPTForCausalLM


quant_config = BitsAndBytesConfig(
Expand Down Expand Up @@ -58,7 +58,6 @@ def load_model(
)
)
if "galactica" in from_pretrained.lower():
model = CustomOPTForCausalLM.from_pretrained(
from_pretrained, use_flash_attn=use_flash_attn, torch_dtype=dtype
)
model = OPTForCausalLM.from_pretrained(
from_pretrained, use_flash_attn=use_flash_attn, torch_dtype=dtype, attn_implementation="flash_attention_2")
return model

0 comments on commit 7fe60e1

Please sign in to comment.