
[Bug]: NaN Values in fwd_prepare_wy_repr Output in GatedDeltaNet #99

Open
xffxff opened this issue Dec 26, 2024 · 7 comments
Labels: bug (Something isn't working)

Comments


xffxff commented Dec 26, 2024

Describe the bug

When using fla's GatedDeltaNet, I encountered an issue where the output of fwd_prepare_wy_repr contains NaN values. I have located the problem and can reproduce it with the following code:

Steps to reproduce the bug

from fla.ops.gated_delta_rule.wy_fast import fwd_prepare_wy_repr
from fla.ops.utils import chunk_local_cumsum
import torch


q = torch.load("q.pt", weights_only=True)
g = torch.load("g.pt", weights_only=True)
k = torch.load("k.pt", weights_only=True)
v = torch.load("v.pt", weights_only=True)
beta = torch.load("beta.pt", weights_only=True)

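# compute the chunk-local cumulative sum of the gate values g (chunk size 64, non-head-first layout)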
g = chunk_local_cumsum(g, 64, offsets=None, head_first=False)
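# prepare the WY representation for the chunked gated delta rule; returns w, u and the intra-chunk matrices Aw, Au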
w, u, Aw, Au = fwd_prepare_wy_repr(
    k=k,
    v=v,
    beta=beta,
    g=g,
    offsets=None,
    indices=None,
    head_first=False,
    chunk_size=64
)

assert not torch.isnan(u).any(), "u contains NaN values" 

The output u contains NaN values. The relevant input .pt files are included in the attached debug.tar.gz archive.
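For anyone narrowing this down further, a hypothetical follow-up check (not part of the original report) that locates the offending entries of u from the repro above:

# hypothetical debugging snippet: count and locate NaN entries in u
nan_mask = torch.isnan(u)
print(f"{nan_mask.float().mean().item():.2%} of u is NaN")
print(nan_mask.nonzero()[:10])  # first few offending indices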

Expected behavior

The output u should not contain any NaN values.

Environment info

  1. torch: 2.5.1
  2. triton: 3.1.0
xffxff added the bug label on Dec 26, 2024

prolearner commented Dec 26, 2024

I'm encountering an error which might be related: I get NaNs after hundreds of training steps using the following config, with PyTorch 2.4.1 and Triton 3.0.0:

{ "attn_mode": "chunk", "bos_token_id": 1, "clamp_min": null, "eos_token_id": 2, "expand_k": 1, "expand_v": 1, "fuse_cross_entropy": true, "fuse_norm": true, "hidden_act": "swish", "hidden_ratio": 4, "hidden_size": 1024, "initializer_range": 0.02, "intermediate_size": null, "max_position_embeddings": 2048, "model_type": "gated_deltanet", "num_heads": 8, "head_dim": 128, "num_hidden_layers": 24, "norm_first": false, "norm_eps": 1e-06, "tie_word_embeddings": true, "use_cache": true, "vocab_size": 32000 }

yzhangcs (Collaborator) commented

Thank you for the report, I'll take a look at this issue right now.
@prolearner Did you use varlen inputs as well?

yzhangcs (Collaborator) commented

@xffxff @prolearner Hi, please check out cb36e67

prolearner commented

> @xffxff @prolearner Hi, please check out cb36e67

Thanks for the quick fix!
It seems to work now: training went past 16k steps as I'm writing this.

xffxff (Author) commented Dec 27, 2024

@yzhangcs Thank you for the quick fix! I've tested the updated code locally, and the NaN issue has been resolved with the reproduction code I provided above. I'm now going to test it in my training job.

xffxff (Author) commented Dec 27, 2024

@yzhangcs During training, we encountered NaN values again. I've dumped the new data, and the issue can still be reproduced using the code I shared earlier.

yzhangcs (Collaborator) commented

Thank you! We're checking the numerical stability of the exp gate with high priority.
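For context, here is a generic fp32 illustration (not the fla kernel code, and not necessarily the exact failure mode in this issue) of how exponentiating cumulative log-gates can overflow and surface as NaN downstream:

import torch

# illustration only: exp of a large cumulative log-gate difference overflows fp32,
# and the resulting inf turns into NaN as soon as it meets a zero
g = torch.full((4096,), -0.05)        # per-step log decay (log of a gate < 1)
cum = torch.cumsum(g, dim=0)          # cumulative log gate
ratio = torch.exp(cum[0] - cum[-1])   # exp(~204.75) -> inf in fp32
print(ratio)                          # tensor(inf)
print(ratio * 0.0)                    # tensor(nan)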
