The main motivation for this is to enable online softmax. In online softmax, for an input of dimension `MxV`, we use `V` both as a reduction dimension (in the first loop, to compute the max and sum) and as a parallel dimension (to apply the broadcast multiply and division by the computed max/sum to the original source).
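For reference, the computation described above can be sketched in plain Python. This is a hedged illustration of the algorithm, not the TKW kernel; the function name `online_softmax` and the `tile` parameter are hypothetical. The first loop treats `V` as a reduction dimension, walking it tile by tile while keeping a running max and a rescaled running sum. The second loop treats `V` as a parallel dimension, broadcasting the per-row max and sum back over the original input.

```python
import math
from typing import List


def online_softmax(x: List[List[float]], tile: int = 4) -> List[List[float]]:
    """Row-wise softmax of an M x V matrix (hypothetical reference sketch)."""
    out: List[List[float]] = []
    for row in x:
        # Loop 1: V is the reduction dimension, processed one tile at a time.
        running_max = -math.inf
        running_sum = 0.0
        for start in range(0, len(row), tile):
            chunk = row[start:start + tile]
            new_max = max(running_max, max(chunk))
            # Rescale the partial sum to the new max before adding this tile.
            running_sum = running_sum * math.exp(running_max - new_max) + sum(
                math.exp(v - new_max) for v in chunk
            )
            running_max = new_max
        # Loop 2: V is a parallel dimension; the per-row max and sum are
        # broadcast over every element of the original row.
        out.append([math.exp(v - running_max) / running_sum for v in row])
    return out
```

The same `V` index space is traversed twice: once sequentially as a reduction, once elementwise in parallel, which is exactly the dual role that the expansion needs to support.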
A simpler kernel that does the same could look like this (a runnable form can be found here):
One way to solve this is to use a different `dim_scaling` and `dim_tile_size` for `Reduction`. For example, we can set the `dim_scaling` for `V` and for a newly generated `V_tile` differently, and use `V_tile` instead of `V` inside the reduction graph. A partial solution is implemented here, although it is still missing one piece: getting the context of `[(reduction, expanded on V_tile)]` together with the expansion on the `V` dim, as seen in this error message:
KeyError: (Reduction(graph=<torch.fx.graph.Graph object at 0x723e571a2c90>, fx_node=reduction, tkw_op_name='reduction', _tracing_function=<bound method define_op.<locals>.decorator.<locals>.new_function of ...>, axis=N, init_args=[register_0_0_0], subgraph_name='region_0', implicit_captures=[a]), ((M, 0),))
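To make the failure mode concrete, here is a hypothetical toy model. It is not the real TKW expansion code; `expand`, `ctx`, and the key layout are assumptions inferred only from the shape of the `KeyError` key above. Each expanded copy of a node is stored under `(node, ((dim, index), ...))`. If the reduction is expanded on `M` and `V_tile`, a lookup keyed only by `((M, 0),)` finds nothing.

```python
from itertools import product
from typing import Dict, Tuple

# Hypothetical toy model of an expansion context, NOT the real TKW code.
# Each expanded copy of a node is keyed by (node_name, ((dim, index), ...)).
DimIndex = Tuple[Tuple[str, int], ...]


def expand(node: str, dim_scaling: Dict[str, int]) -> Dict[Tuple[str, DimIndex], str]:
    """Register one expanded copy of `node` for every combination of per-dim indices."""
    dims = sorted(dim_scaling)
    ctx: Dict[Tuple[str, DimIndex], str] = {}
    for idx in product(*(range(dim_scaling[d]) for d in dims)):
        key = tuple(zip(dims, idx))
        ctx[(node, key)] = node + "_" + "_".join(str(i) for i in idx)
    return ctx


# The reduction is expanded on M and on the reduction-only tile dim V_tile.
ctx = expand("reduction", {"M": 1, "V_tile": 2})

# A consumer in the outer graph asks for the reduction result by M alone,
# mirroring the ((M, 0),) key in the error message above.
lookup = ("reduction", (("M", 0),))
print(lookup in ctx)  # False: every stored key also carries a V_tile index
```

In this toy, closing the gap would mean translating between the reduction-internal `V_tile` index and the outer `V` expansion when resolving the lookup, which corresponds to the missing piece described above.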