From fe40dade26e27de4bd050161752291203ce9d39a Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 21 Feb 2024 18:58:21 +0400
Subject: [PATCH 1/3] cast decoder forward method logits to fp32

---
 chemlactica/utils/model_utils.py | 4 +++-
 chemlactica/utils/utils.py       | 8 ++++++++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/chemlactica/utils/model_utils.py b/chemlactica/utils/model_utils.py
index 72891e7..3c55048 100644
--- a/chemlactica/utils/model_utils.py
+++ b/chemlactica/utils/model_utils.py
@@ -1,5 +1,5 @@
 from transformers import OPTForCausalLM, OPTConfig, MistralForCausalLM
-from .utils import get_tokenizer_special_tokens
+from .utils import get_tokenizer_special_tokens, cast_to_fp32
 
 import bitsandbytes as bnb
 
@@ -74,6 +74,8 @@ def load_model(
         model = OPTForCausalLM.from_pretrained(
             from_pretrained, torch_dtype=dtype, attn_implementation=attn_implementation
         )
+        model.model.decoder.forward = cast_to_fp32(model.model.decoder.forward)
+
     if "mistral" in from_pretrained.lower():
         model = MistralForCausalLM.from_pretrained(
             from_pretrained,
diff --git a/chemlactica/utils/utils.py b/chemlactica/utils/utils.py
index 67b2ddc..fb4c0a6 100644
--- a/chemlactica/utils/utils.py
+++ b/chemlactica/utils/utils.py
@@ -1,4 +1,5 @@
 import os
+import torch
 import json
 from transformers import AutoTokenizer
 from functools import cache
@@ -86,6 +87,13 @@ def remove_extraneous_args(args):
         delattr(args, "accelerate_eval_config_file")
 
 
+def cast_to_fp32(func):
+    def wrapper(*args, **kwargs):
+        return func(*args, **kwargs).to(torch.float32)
+
+    return wrapper
+
+
 if __name__ == "__main__":
     # import sys
     import glob

From bcbd55d41adb2cfcf068b3c62208e6d35cd1b031 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 21 Feb 2024 21:41:40 +0400
Subject: [PATCH 2/3] cast lm head to fp32

---
 chemlactica/utils/model_utils.py | 25 +++++++++++++++++++++++--
 chemlactica/utils/utils.py       |  8 --------
 2 files changed, 23 insertions(+), 10 deletions(-)

diff --git a/chemlactica/utils/model_utils.py b/chemlactica/utils/model_utils.py
index 3c55048..5937666 100644
--- a/chemlactica/utils/model_utils.py
+++ b/chemlactica/utils/model_utils.py
@@ -1,5 +1,5 @@
 from transformers import OPTForCausalLM, OPTConfig, MistralForCausalLM
-from .utils import get_tokenizer_special_tokens, cast_to_fp32
+from .utils import get_tokenizer_special_tokens
 
 import bitsandbytes as bnb
 
@@ -8,6 +8,22 @@
 from transformers import BitsAndBytesConfig
 
 
+def float_casting_decorator(layer_class):
+    class FloatCastingLayer(layer_class):
+        def __init__(self, *args, **kwargs):
+            super(FloatCastingLayer, self).__init__(*args, **kwargs)
+
+        def forward(
+            self,
+            x,
+            *args,
+            **kwargs,
+        ):
+            return super().forward(x, *args, **kwargs).to(torch.float32)
+
+    return FloatCastingLayer
+
+
 quant_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=False,
@@ -74,7 +90,12 @@ def load_model(
         model = OPTForCausalLM.from_pretrained(
             from_pretrained, torch_dtype=dtype, attn_implementation=attn_implementation
         )
-        model.model.decoder.forward = cast_to_fp32(model.model.decoder.forward)
+        print(type(model.lm_head))
+        model.lm_head = float_casting_decorator(model.lm_head.__class__)(
+            in_features=model.lm_head.in_features,
+            out_features=model.lm_head.out_features,
+        )
+        # model.lm_head.forward = cast_to_fp32(OPTForCausalLM.lm_head.forward)
 
     if "mistral" in from_pretrained.lower():
         model = MistralForCausalLM.from_pretrained(
             from_pretrained,
diff --git a/chemlactica/utils/utils.py b/chemlactica/utils/utils.py
index fb4c0a6..67b2ddc 100644
--- a/chemlactica/utils/utils.py
+++ b/chemlactica/utils/utils.py
@@ -1,5 +1,4 @@
 import os
-import torch
 import json
 from transformers import AutoTokenizer
 from functools import cache
@@ -87,13 +86,6 @@ def remove_extraneous_args(args):
         delattr(args, "accelerate_eval_config_file")
 
 
-def cast_to_fp32(func):
-    def wrapper(*args, **kwargs):
-        return func(*args, **kwargs).to(torch.float32)
-
-    return wrapper
-
-
 if __name__ == "__main__":
     # import sys
     import glob

From cc5674fb1c9017f4f6b797fae30d8673f3c8b754 Mon Sep 17 00:00:00 2001
From: Philipp Guevorguian
Date: Wed, 21 Feb 2024 21:45:36 +0400
Subject: [PATCH 3/3] add test_status

---
 test_status.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test_status.yaml b/test_status.yaml
index 02df6a0..c9925f9 100644
--- a/test_status.yaml
+++ b/test_status.yaml
@@ -1 +1 @@
-3fad7942cb631ed45c48c1be3c9b3992aaf312a7: PASS
+fe40dade26e27de4bd050161752291203ce9d39a: PASS
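
Note on the pattern introduced in [PATCH 2/3]: float_casting_decorator rebuilds the
LM head as a subclass whose forward output is upcast to fp32, so the projection
matmul stays in the model's native dtype while downstream softmax/loss sees
full-precision logits. Below is a minimal standalone sketch of the same pattern,
assuming a plain torch.nn.Linear as a stand-in for model.lm_head; the dimensions,
the bf16 dtype, and the explicit load_state_dict step are illustrative assumptions
(rebuilding the layer re-initializes its weights, so the pretrained parameters
are copied back by hand here).

import torch
import torch.nn as nn


def float_casting_decorator(layer_class):
    # Subclass layer_class so its forward output is cast to fp32.
    class FloatCastingLayer(layer_class):
        def forward(self, x, *args, **kwargs):
            # Compute in the layer's native dtype, then upcast the result.
            return super().forward(x, *args, **kwargs).to(torch.float32)

    return FloatCastingLayer


# Hypothetical stand-in for model.lm_head (OPT's head has no bias).
lm_head = nn.Linear(16, 32, bias=False).to(torch.bfloat16)

# Rebuild the head with the casting subclass, mirroring load_model().
cast_head = float_casting_decorator(type(lm_head))(
    in_features=lm_head.in_features,
    out_features=lm_head.out_features,
    bias=False,
).to(torch.bfloat16)

# Assumption: the rebuilt layer starts from random weights, so the
# original parameters must be copied over explicitly.
cast_head.load_state_dict(lm_head.state_dict())

x = torch.randn(2, 16, dtype=torch.bfloat16)
print(cast_head(x).dtype)  # torch.float32, while cast_head.weight stays bf16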