From ddce0c1334536dd04c523ccce08928f3611d2627 Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Sun, 30 Jun 2024 22:06:46 -0700 Subject: [PATCH] Fix MambaChunkScanCombinedFn returning incorrect number of gradients --- mamba_ssm/__init__.py | 2 +- mamba_ssm/ops/triton/ssd_combined.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index 6e126d2d..b1b96a32 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.2.0" +__version__ = "2.2.1" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba diff --git a/mamba_ssm/ops/triton/ssd_combined.py b/mamba_ssm/ops/triton/ssd_combined.py index 1305cfb4..77d20715 100644 --- a/mamba_ssm/ops/triton/ssd_combined.py +++ b/mamba_ssm/ops/triton/ssd_combined.py @@ -556,7 +556,7 @@ def backward(ctx, dout, *args): assert not ctx.return_varlen_states, "return_varlen_states is not supported in backward" dfinal_states = args[0] if ctx.return_final_states else None dx, ddt, dA, dB, dC, dD, dz, ddt_bias, dinitial_states = _mamba_chunk_scan_combined_bwd(dout, x, dt, A, B, C, out, ctx.chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, dfinal_states=dfinal_states, seq_idx=seq_idx, dt_softplus=ctx.dt_softplus, dt_limit=ctx.dt_limit) - return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None + return dx, ddt, dA, dB, dC, None, dD, dz, ddt_bias, dinitial_states, None, None, None, None, None, None def mamba_chunk_scan_combined(x, dt, A, B, C, chunk_size, D=None, z=None, dt_bias=None, initial_states=None, seq_idx=None, cu_seqlens=None, dt_softplus=False, dt_limit=(0.0, float("inf")), return_final_states=False, return_varlen_states=False):