diff --git a/chemlactica/utils/model_utils.py b/chemlactica/utils/model_utils.py
index 72891e7..5937666 100644
--- a/chemlactica/utils/model_utils.py
+++ b/chemlactica/utils/model_utils.py
@@ -8,6 +8,22 @@ from transformers import BitsAndBytesConfig
 
 
 
+def float_casting_decorator(layer_class):
+    class FloatCastingLayer(layer_class):
+        def __init__(self, *args, **kwargs):
+            super(FloatCastingLayer, self).__init__(*args, **kwargs)
+
+        def forward(
+            self,
+            x,
+            *args,
+            **kwargs,
+        ):
+            return super().forward(x, *args, **kwargs).to(torch.float32)
+
+    return FloatCastingLayer
+
+
 quant_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_use_double_quant=False,
@@ -74,6 +90,13 @@ def load_model(
         model = OPTForCausalLM.from_pretrained(
             from_pretrained, torch_dtype=dtype, attn_implementation=attn_implementation
         )
+        print(type(model.lm_head))
+        model.lm_head = float_casting_decorator(model.lm_head.__class__)(
+            in_features=model.lm_head.in_features,
+            out_features=model.lm_head.out_features,
+        )
+        # model.lm_head.forward = cast_to_fp32(OPTForCausalLM.lm_head.forward)
+
     if "mistral" in from_pretrained.lower():
         model = MistralForCausalLM.from_pretrained(
             from_pretrained,
diff --git a/test_status.yaml b/test_status.yaml
index 02df6a0..c9925f9 100644
--- a/test_status.yaml
+++ b/test_status.yaml
@@ -1 +1 @@
-3fad7942cb631ed45c48c1be3c9b3992aaf312a7: PASS
+fe40dade26e27de4bd050161752291203ce9d39a: PASS
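
For context, a minimal standalone sketch of what the added `float_casting_decorator` does, applied here to a plain `torch.nn.Linear` with hypothetical sizes; this snippet is illustrative only and not part of the patch:

```python
import torch
import torch.nn as nn


def float_casting_decorator(layer_class):
    # Build a subclass of layer_class whose forward output is cast to fp32.
    class FloatCastingLayer(layer_class):
        def forward(self, x, *args, **kwargs):
            return super().forward(x, *args, **kwargs).to(torch.float32)

    return FloatCastingLayer


# Hypothetical usage: a bf16 projection whose outputs come back in fp32,
# mirroring how the patch wraps model.lm_head after loading the OPT model.
head = float_casting_decorator(nn.Linear)(in_features=16, out_features=8)
head = head.to(torch.bfloat16)
x = torch.randn(2, 16, dtype=torch.bfloat16)
print(head(x).dtype)  # torch.float32
```

Presumably the point of wrapping only the `lm_head` is to keep the logits (and any loss computed from them) in fp32 while the rest of the model runs in the lower-precision `dtype` passed to `from_pretrained`.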