[TKW] Allow separate distribution/expansion of tiled in reduction dimension and parallel dimension #169

Open
raikonenfnu opened this issue Sep 26, 2024 · 0 comments

raikonenfnu commented Sep 26, 2024

The main motivation for this is to enable online softmax. In online softmax, for an input of shape MxV, V is used both as a reduction dimension (in the first loop, to compute the running max and sum) and as a parallel dimension (to apply the broadcast multiply and divide of the computed max/sum back onto the original source).
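As a point of reference, the reduce-then-broadcast structure can be sketched in plain NumPy (illustrative only; this is not a tkw kernel, and the function name is made up):

    import numpy as np

    def softmax_rows(x):
        # x has shape [M, V].
        # First pass: V acts as the reduction dimension.
        row_max = x.max(axis=1, keepdims=True)
        row_sum = np.exp(x - row_max).sum(axis=1, keepdims=True)
        # Second pass: V acts as a parallel dimension; the reduced max/sum
        # are broadcast back over the original source.
        return np.exp(x - row_max) / row_sum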

A simpler kernel that exercises the same pattern could look like:


    # M, N are symbolic dimensions; BLOCK_M, BLOCK_N, ADDRESS_SPACE and
    # ELEMS_PER_THREAD are symbolic hyperparameters, declared via tkl.sym as
    # in other tkw kernels (assumed here; the linked runnable form defines them).
    M = tkl.sym.M
    N = tkl.sym.N
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD

    constraints: list[tkw.Constraint] = [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(1, 1, 1),
            vector_shapes={M: 1, N: BLOCK_N},
        )
    ]
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
    constraints += [tkw.WorkgroupConstraint(N, N, 0)]
    constraints += [tkw.TilingConstraint(N, BLOCK_N)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M)]
    constraints += [tkw.WaveConstraint(N, BLOCK_N)]

    @tkw.wave(constraints)
    def test(
        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
        c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32],
    ):
        init_max = tkl.Register[M, tkl.f32](-1e6)

        # First use of N: reduction dimension, tiled by BLOCK_N.
        @tkw.reduction(N, init_args=[init_max])
        def repeat(
            partial_max: tkl.Register[M, tkl.f32],
        ) -> tkl.Register[M, tkl.f32]:
            lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
            partial_max = tkw.max(lhs, partial_max, dim=N)
            return partial_max

        # Second use of N: parallel dimension, broadcasting the reduced max
        # back over the original source.
        lhs = tkw.read(a, elements_per_thread=ELEMS_PER_THREAD)
        result = lhs + repeat
        tkw.write(result, c, elements_per_thread=1)

A runnable form can be found here.

One way to solve this is to give Reduction its own dim_scaling and dim_tile_size. For example, we can set dim_scaling differently for V and a newly generated V_tile, and inside the reduction graph use V_tile instead of V. A partial solution is done here, although that implementation is still missing one piece: getting the context of [(reduction, expanded on v_tile)] together with the expansion on the V dim.
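As a purely hypothetical sketch of that idea (using the kernel's N for the V dimension in the prose above; none of these names or dicts are real compiler APIs):

    # Hypothetical illustration only; not the actual expansion implementation.
    N_tile = tkl.sym.N_tile  # generated tile symbol for the reduction dim

    # Expansion scaling for ops outside the reduction, where N is parallel.
    dim_scaling_parallel = {M: 1, N: BLOCK_N // ELEMS_PER_THREAD}

    # Expansion scaling for ops inside the reduction subgraph, keyed on N_tile
    # instead of N so the two uses of the dimension can be expanded independently.
    dim_scaling_reduction = {M: 1, N_tile: BLOCK_N // ELEMS_PER_THREAD}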

As seen in this error message:

KeyError: (Reduction(graph=<torch.fx.graph.Graph object at 0x723e571a2c90>, fx_node=reduction, tkw_op_name='reduction', _tracing_function=<bound method define_op.<locals>.decorator.<locals>.new_function of ...>, axis=N, init_args=[register_0_0_0], subgraph_name='region_0', implicit_captures=[a]), ((M, 0),))
