Merge pull request #6 from YerevaNN/fp32_cast
cast lm_head output to fp32
philippguevorguian authored Feb 21, 2024
2 parents 35253f6 + cc5674f commit 5067680
Showing 2 changed files with 24 additions and 1 deletion.
chemlactica/utils/model_utils.py (23 additions, 0 deletions)
@@ -8,6 +8,22 @@
from transformers import BitsAndBytesConfig


def float_casting_decorator(layer_class):
class FloatCastingLayer(layer_class):
def __init__(self, *args, **kwargs):
super(FloatCastingLayer, self).__init__(*args, **kwargs)

def forward(
self,
x,
*args,
**kwargs,
):
return super().forward(x, *args, **kwargs).to(torch.float32)

return FloatCastingLayer


quant_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=False,
@@ -74,6 +90,13 @@ def load_model(
model = OPTForCausalLM.from_pretrained(
from_pretrained, torch_dtype=dtype, attn_implementation=attn_implementation
)
print(type(model.lm_head))
model.lm_head = float_casting_decorator(model.lm_head.__class__)(
in_features=model.lm_head.in_features,
out_features=model.lm_head.out_features,
)
# model.lm_head.forward = cast_to_fp32(OPTForCausalLM.lm_head.forward)

if "mistral" in from_pretrained.lower():
model = MistralForCausalLM.from_pretrained(
from_pretrained,
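For context, a minimal, self-contained sketch of what the new float_casting_decorator does (illustrative only, not part of the commit; the decorator body is restated from the diff above, and the toy layer sizes are arbitrary):

import torch
import torch.nn as nn


def float_casting_decorator(layer_class):
    # Subclass the given layer so its forward() output is cast to fp32.
    class FloatCastingLayer(layer_class):
        def forward(self, x, *args, **kwargs):
            return super().forward(x, *args, **kwargs).to(torch.float32)

    return FloatCastingLayer


# Wrap a plain nn.Linear the same way load_model wraps model.lm_head.
plain_head = nn.Linear(in_features=16, out_features=32, bias=False).half()
fp32_head = float_casting_decorator(nn.Linear)(
    in_features=16, out_features=32, bias=False
).half()

x = torch.randn(2, 16, dtype=torch.float16)
print(plain_head(x).dtype)  # torch.float16
print(fp32_head(x).dtype)   # torch.float32

Note that in load_model the replacement lm_head is freshly constructed, so the change presumably relies on the model's usual weight handling (e.g. OPT's tied embeddings) to populate it; the sketch above only demonstrates the dtype behavior of the wrapper.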
test_status.yaml (1 addition, 1 deletion)
@@ -1 +1 @@
-3fad7942cb631ed45c48c1be3c9b3992aaf312a7: PASS
+fe40dade26e27de4bd050161752291203ce9d39a: PASS
