
Keep rope at float32 precision #1497

Merged (5 commits) on Mar 13, 2024

Conversation

grasskin (Member) commented Mar 7, 2024

No description provided.

tirthasheshpatel (Contributor) left a comment

I think we also need to cast back to the compute_dtype before returning. Only the computation part needs to happen in float32.
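In other words, something like this (a minimal sketch of the cast-back pattern, not the actual diff; `rope_angles` is a hypothetical helper):

    from keras import ops

    def rope_angles(positions, timescale, compute_dtype):
        # Compute the angles and their sin/cos in float32...
        radians = ops.cast(positions, "float32")[..., None] / timescale
        sin, cos = ops.sin(radians), ops.cos(radians)
        # ...then cast back to the layer's compute dtype before returning.
        return ops.cast(sin, compute_dtype), ops.cast(cos, compute_dtype)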

@@ -95,7 +95,7 @@ def _apply_rope(self, x, positions):
         max_wavelength = 10000
         x_shape = ops.shape(x)
         freq_exponents = (2.0 / x_shape[-1]) * ops.cast(
-            ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype
+            ops.arange(x_shape[-1] // 2, dtype="float32"), "float32"
Contributor (inline comment on the diff):

We can remove the ops.cast call here; everything should be float32.

tirthasheshpatel (Contributor) commented Mar 7, 2024

@mattdangerw Gemma still downcasts the tensors to compute_dtype since it uses its own implementation of RoPE. I can submit a follow-up PR to use this layer instead.

mattdangerw (Member) commented Mar 8, 2024

> I think we also need to cast back to the compute_dtype before returning. Only the computation part needs to happen in float32.

Yeah looks like this is causing test failures, probably due to this issue?

danielhanchen commented Mar 9, 2024

Hi :) I'm assuming this came about from my Twitter thread https://twitter.com/danielhanchen/status/1765446273661075609 :)

I added a fix to transformers 4.38.2 here: huggingface/transformers#29285. Using mixed_bfloat16 causes torch.autocast to cast all ops to bfloat16, and I know that even explicitly forcing float32 can be overridden by autocast. I don't normally use Keras, so I'm unsure whether its operations are affected the same way.

Also another problematic line is https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/gemma_attention.py#L159

        seq_len = ops.shape(x)[1]
        start_index = cache_update_index
        positions = ops.cast(
            ops.arange(seq_len, dtype="float32"), >>>>> self.compute_dtype <<<<
        )
        positions = positions + ops.cast(start_index, self.compute_dtype)

This is wrong: if someone did RoPE scaling with float16, the maximum representable value would be 65504, which in turn causes overflow, i.e. infinities. bfloat16 loses precision but can represent larger numbers.
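To illustrate the overflow concern (a minimal NumPy sketch, not from the PR):

    import numpy as np

    # float16's largest finite value is 65504, so position indices past that
    # overflow to inf if the positions tensor is downcast to float16.
    positions = np.arange(70000, dtype="float32")
    print(positions.astype("float16")[-1])  # inf
    print(positions[-1])                    # 69999.0, fine in float32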

grasskin (Member, Author) commented

Hi @danielhanchen, I enjoyed reading the blog post; it was a great in-depth dive!

Switched all of RoPE to happen in "float32" and added downcasting before returning. This likely works until we replace the call with normal Keras RoPE?

tirthasheshpatel (Contributor) commented

> Switched all of RoPE to happen in "float32" and added downcasting before returning. This likely works until we replace the call with normal Keras RoPE?

Yeah, this should be good enough for now. We can merge this and I can rebase my PR on top of your changes.

@@ -92,10 +92,13 @@ def build(self, inputs_shape):
     def _apply_rope(self, x, positions):
         """Rope rotate q or k."""
         # TODO: refactor to use RotaryEmbedding layer?
+        x = ops.cast(
Member review comment on the diff:

Do we really want x in float32? That feels a little awkward, given that we are going to downcast immediately. It seems reasonable to keep the sin and cos lines at full precision (though I have no idea if that's important or not).

But this line seems like it should probably be done in compute_dtype: ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

So:

sin = ops.cast(ops.sin(radians), self.compute_dtype)
cos = ops.cast(ops.cos(radians), self.compute_dtype)
x1, x2 = ops.split(x, 2, axis=-1)
return ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

Incidentally, if we ever did cache the rotary vectors, we would be caching them at float32 and applying them at the correct dtype for the model.

wdyt?

grasskin (Member, Author) replied:

I agree, casting x is a bit awkward. I guess as long as radians/positions/timescale are in float32 initially, we lose the numerical inaccuracy.

Dug more into the other submitted fix for sin/cos in float32 (huggingface/transformers#29285 (comment)), and it seems like this is the right call here (downcasting before we stack/multiply).

Member replied:

Sweet! This LGTM! Will pull it in.
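For reference, a sketch of what `_apply_rope` roughly looks like after this review. Only the float32 freq_exponents cast and the sin/cos downcast are quoted from the snippets above; the timescale/radians lines, their shapes, and the surrounding context are assumptions for illustration, not the exact merged code.

    # Assumes `from keras import ops` and that this is a method on the
    # attention layer (so self.compute_dtype is the mixed-precision dtype).
    def _apply_rope(self, x, positions):
        """Rope rotate q or k, computing the angles in float32."""
        max_wavelength = 10000
        x_shape = ops.shape(x)
        freq_exponents = (2.0 / x_shape[-1]) * ops.cast(
            ops.arange(x_shape[-1] // 2, dtype="float32"), "float32"
        )
        timescale = max_wavelength**freq_exponents    # float32 (assumed)
        radians = positions[..., None] / timescale    # (batch, seq, dim // 2), assumed
        radians = radians[..., None, :]               # broadcast over heads, assumed
        # Trig at full precision, then downcast before the stack/multiply.
        sin = ops.cast(ops.sin(radians), self.compute_dtype)
        cos = ops.cast(ops.cos(radians), self.compute_dtype)
        x1, x2 = ops.split(x, 2, axis=-1)
        # The actual layer may reshape the result back to x_shape; omitted here.
        return ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)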

mattdangerw merged commit 09d2fdd into keras-team:master on Mar 13, 2024
10 checks passed
grasskin mentioned this pull request on Mar 21, 2024
abuelnasr0 pushed a commit to abuelnasr0/keras-nlp that referenced this pull request Apr 2, 2024
* Keep rope at float32 precision

* Carry out all of RoPE in float32

* Formatting

* Cleanup

* Do not cast x