Skip to content

Commit

Permalink
Address comments #1
Browse files Browse the repository at this point in the history
Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod committed Aug 29, 2024
1 parent f923e01 commit 9f4232a
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions shark_turbine/kernel/wave/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ def compute_access_pattern_using_vector_shapes(

def apply(
self,
mma_index: int,
constraint_index: int,
dim: IndexSymbol,
elements_per_thread: int | IndexSymbol,
) -> IndexSequence:
if self.vector_shapes is not None:
return self.compute_access_pattern_using_vector_shapes(
dim, mma_index, elements_per_thread
dim, constraint_index, elements_per_thread
)
lane = self.linearized_thread_id
match self.mma_type:
Expand All @@ -146,7 +146,9 @@ def apply(
1, # K
]
return IndexSequence(
offset[mma_index], size[mma_index], stride[mma_index]
offset[constraint_index],
size[constraint_index],
stride[constraint_index],
)


Expand Down

0 comments on commit 9f4232a

Please sign in to comment.