Keep rope at float32 precision #1497
Conversation
I think we also need to cast back to the compute_dtype before returning. Only the computation part needs to happen in float32.
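In other words, the shape of the fix is roughly the following (a minimal sketch of the pattern, not the PR's code; rope_in_float32 and rotate_fn are illustrative stand-ins for the actual rotary math):

from keras import ops

def rope_in_float32(x, rotate_fn, compute_dtype):
    # rotate_fn: placeholder for the rotary computation (hypothetical).
    # Upcast, do the rotation math entirely in float32, then cast the
    # result back to the model's compute dtype at the boundary.
    x = ops.cast(x, "float32")
    out = rotate_fn(x)
    return ops.cast(out, compute_dtype)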
@@ -95,7 +95,7 @@ def _apply_rope(self, x, positions):
         max_wavelength = 10000
         x_shape = ops.shape(x)
         freq_exponents = (2.0 / x_shape[-1]) * ops.cast(
-            ops.arange(x_shape[-1] // 2, dtype="float32"), self.compute_dtype
+            ops.arange(x_shape[-1] // 2, dtype="float32"), "float32"
We can remove the ops.cast call here; everything should be float32.
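With the cast dropped, that line would read roughly like this (sketch only, reusing x_shape and ops from the diff above):

freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
    x_shape[-1] // 2, dtype="float32"
)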
@mattdangerw Gemma still downcasts the tensors to
Yeah, looks like this is causing test failures, probably due to this issue?
Hi :) I'm assuming this came about from my Twitter thread https://twitter.com/danielhanchen/status/1765446273661075609 :) I added a fix into transformers 4.38.2 here: huggingface/transformers#29285.

Also, another problematic line is https://github.com/keras-team/keras-nlp/blob/v0.8.2/keras_nlp/models/gemma/gemma_attention.py#L159:

seq_len = ops.shape(x)[1]
start_index = cache_update_index
positions = ops.cast(
    ops.arange(seq_len, dtype="float32"), >>>>> self.compute_dtype <<<<
)
positions = positions + ops.cast(start_index, self.compute_dtype)

which is wrong: if someone did RoPE scaling with float16, the maximum representable value would be 65504, which in turn causes overflow, i.e. infinities. bfloat16 loses precision, but can represent larger numbers.
Hi @danielhanchen, I enjoyed reading the blog post, it was a great in-depth dive! I switched all of RoPE to happen in "float32" and added downcasting before returning. This should work until we replace the call with the normal Keras RoPE layer?
Yeah, this should be good enough for now. We can merge this and I can rebase my PR on top of your changes.
@@ -92,10 +92,13 @@ def build(self, inputs_shape):
     def _apply_rope(self, x, positions):
         """Rope rotate q or k."""
         # TODO: refactor to use RotaryEmbedding layer?
+        x = ops.cast(
Do we really want x in float32? That feels a little awkward, given that we are going to downcast immediately. It seems reasonable to keep the sin and cos lines at full precision (though I have no idea whether that's important or not), but this line seems like it should probably be done in compute_dtype:

ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

So:

sin = ops.cast(ops.sin(radians), self.compute_dtype)
cos = ops.cast(ops.cos(radians), self.compute_dtype)
x1, x2 = ops.split(x, 2, axis=-1)
return ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

Incidentally, that would mean that if we ever did cache the rotary vectors, we would be caching them at float32 and applying them at the correct dtype for the model. wdyt?
I agree, casting x is a bit awkward. I guess as long as radians/positions/timescale are in float32 initially, we avoid the numerical inaccuracy.
Dug more into the other submitted fix for sin/cos in float32 (huggingface/transformers#29285 (comment)), and it seems like downcasting before we stack/multiply is the right call here.
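Putting the agreed pieces together, the rotation ends up looking roughly like this (a sketch assembled from the snippets quoted in this thread, not the merged keras-nlp source; the standalone function form, its name, and its signature are illustrative):

from keras import ops

def apply_rope_sketch(x, positions, compute_dtype, max_wavelength=10000):
    # positions, timescale and radians stay in float32 so large position
    # values neither overflow in float16 nor lose precision in bfloat16.
    x_shape = ops.shape(x)
    freq_exponents = (2.0 / x_shape[-1]) * ops.arange(
        x_shape[-1] // 2, dtype="float32"
    )
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None] / timescale[None, None, :]
    radians = radians[..., None, :]
    # sin/cos are computed in float32 and downcast just before the multiply,
    # so x itself is never upcast.
    sin = ops.cast(ops.sin(radians), compute_dtype)
    cos = ops.cast(ops.cos(radians), compute_dtype)
    x1, x2 = ops.split(x, 2, axis=-1)
    # Same return line as quoted in the review above.
    return ops.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

Here positions is assumed to already be a float32 tensor (e.g. built with ops.arange(seq_len, dtype="float32"), per the earlier comment).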
Sweet! This LGTM! Will pull it in.
* Keep rope at float32 precision
* Carry out all of RoPE in float32
* Formatting
* Cleanup
* Do not cast x
No description provided.