Add support for smem_epilogue when mma output is not cast to half (#3620)

Support non-stmatrix stores from registers to shared memory, followed by a TMA
store, when the output of the mma op is not cast back to half precision;
stmatrix works with half precision only.
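
In broad strokes, the new epilogue selection looks like the condensed sketch
below. It only restates, for orientation, the logic added to
HopperMultipleMatmulScheduler::scheduleEpilogue() in the first diff; it assumes
that function's surrounding state (dc, d, mma_results_) and is not a standalone
program.

    // dc is the register-resident epilogue tensor, d the global-memory output.
    TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

    // stmatrix handles 16-bit elements only; when the mma output stays in
    // float, keep a plain Set copy from registers to shared memory instead.
    bool store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;
    if (store_with_stmatrix) {
      d_smem->definition()->as<LoadStoreOp>()->setOpType(
          LoadStoreOpType::StMatrix);
    }

    // Either way, the shared-to-global copy is a TMA store.
    d->definition()->as<LoadStoreOp>()->setOpType(
        LoadStoreOpType::CpAsyncBulkTensorTile);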
protonu authored Dec 25, 2024
1 parent db3576c commit ee63c98
Showing 5 changed files with 87 additions and 53 deletions.
66 changes: 43 additions & 23 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -508,8 +508,12 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);

std::vector<TensorView*> tvs_to_schedule{d, d_smem};
if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
mma_results_.end()) {

bool dc_in_mma_results =
std::find(mma_results_.begin(), mma_results_.end(), dc) !=
mma_results_.end();

if (!dc_in_mma_results) {
// Skip scheduling dc if it is an mma_result. This can happen if we are
// not casting back to half-precision in the output
tvs_to_schedule.push_back(dc);
@@ -519,14 +523,13 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
dc->setMemoryType(MemoryType::Local);
d_smem->setMemoryType(MemoryType::Shared);

// Set LoadStoreOp
// TODO: extend support when mma is not cast to half
NVF_CHECK(
dataTypeSize(dc->dtype()) == 2,
"We support use_smem_epilogue on Hopper only when the output is 16-bit");
auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;

d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
if (store_with_stmatrix) {
// Set LoadStoreOp
d_smem->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::StMatrix);
}
d->definition()->as<LoadStoreOp>()->setOpType(
LoadStoreOpType::CpAsyncBulkTensorTile);

@@ -539,23 +542,40 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
transformLikeMmaOutput(tv, /*is_mma_result=*/false);
}

auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
// Should not propagate if the dc is a mma output as the mma output has
// already been scheduled.
if (!dc_in_mma_results) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
dc->getLoopDomain());
dc->setLoopDomain(s.as<IterDomain*>());
dc->setAllocationDomain(s.as<IterDomain*>(), true);

scheduler_utils::BoundedDirectionalTransformPropagator::backward(
dc,
-1,
propagate_to,
scheduler_utils::BoundedDirectionalTransformPropagator::Options()
.propagateParallelType());
}

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);

// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
// [M, N] -> [128(TIDx), N/8 , m(2) , n(2)]
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
d_smem->getLoopDomain());
if (swizzle != MmaInputSmemSwizzle::None) {
// Create tma store allocation domain with swizzle
mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
}
d_smem->setLoopDomain(s.as<IterDomain*>());

if (store_with_stmatrix) {
// Schedule shared memory cache; Output from StMatrix
mma_utils::scheduleStMatrixForMmaOutput(
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
}

d_smem->axis(-1)->parallelize(ParallelType::Vectorize);

// Schedule global memory output; Output from TMA Store
mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
12 changes: 0 additions & 12 deletions csrc/scheduler/mma_utils.cpp
@@ -1315,17 +1315,6 @@ void scheduleStMatrixForMmaOutput(
dataTypeSize(tv->dtype()) == 2,
"we only support 16-bit types in stmatrix");

// [M, N] -> [128(TIDx), N/8 , 2 , 2]
auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());

if (swizzle != MmaInputSmemSwizzle::None) {
// Create tma store allocation domain with swizzle
mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle);
}

tv->setLoopDomain(s.as<IterDomain*>());

if (tile_m == 16 && tile_n == 16) {
// Let [M, N] be [64, 32]
// After scheduleMmaOutputAllocation: [128(TIDx), 4, 2, 2]
@@ -1344,7 +1333,6 @@
// [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)]
tv->merge(-2);
}
tv->axis(-1)->parallelize(ParallelType::Vectorize);
}

MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion) {
26 changes: 18 additions & 8 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -2660,7 +2660,7 @@ TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) {
testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__);
}

TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
TEST_F(MatmulSchedulerTest, PreBroadcastMmaBiasNeg) {
// TODO: fix up params or switch to FusionExecutorCache when ready, then
// enable Ampere
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
@@ -2671,12 +2671,20 @@ TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
// A - tv0, B - tv1
auto tv0 = makeContigConcreteTensor({-1, 1, -1}, DataType::Half);
auto tv1 = makeContigConcreteTensor({1, -1, -1}, DataType::Half);
TensorView* tv2 = makeContigConcreteTensor({-1}, DataType::Half);
fusion->addInput(tv0);
fusion->addInput(tv1);
fusion->addInput(tv2);

auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
auto tv3 = fusedMultiplySum(tv0, tv1, {-1});
// We add these computations to test
// scheduling (with epilogue) when the output of mma is not
// cast to half.
auto tv4 = maybeCastOp(DataType::Float, tv2);
auto tv5 = biasEpilogue(tv3, tv4);
auto tv6 = neg(tv5);

fusion->addOutput(tv2);
fusion->addOutput(tv6);

NVF_CHECK(
1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
@@ -2689,10 +2697,14 @@
auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a = at::randn({M, K}, options);
auto b = at::randn({N, K}, options);
auto c = at::randn({M}, options);
auto t0 = a.unsqueeze(1);
auto t1 = b.unsqueeze(0);
auto tref = at::matmul(a.to(at::kFloat), b.to(at::kFloat).t());
std::vector<c10::IValue> inputs{t0, t1};
auto tref =
atBiasEpilogue(
at::matmul(a.to(at::kFloat), b.to(at::kFloat).t()), c.to(at::kFloat))
.neg_();
std::vector<c10::IValue> inputs{t0, t1, c};

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 4};
@@ -2705,9 +2717,7 @@
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = true;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
// TODO: Currently we use stmatrix whenever this is true. We cannot do that
// when the dtype is not 16 bits.
mparams.use_smem_epilogue = false;
mparams.use_smem_epilogue = true;
mparams.promote_prologue_smem_reuse = false;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
12 changes: 8 additions & 4 deletions tests/cpp/test_memory.cpp
@@ -2860,14 +2860,18 @@ TEST_P(StMatrixTest, Regular) {
tv0->split(0, 32);
tv0->axis(1)->parallelize(ParallelType::TIDx);

auto s =
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv1->getLoopDomain());
tv1->setLoopDomain(s.as<IterDomain*>());
tv1->setAllocationDomain(s.as<IterDomain*>(), true);
for (auto tv : {tv1, tv2}) {
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv->getLoopDomain());
tv->setLoopDomain(s.as<IterDomain*>());
}
tv1->setAllocationDomain(tv1->getLoopDomain(), true);

mma_utils::scheduleStMatrixForMmaOutput(
tv2, /*swizzle=*/MmaInputSmemSwizzle::None, tile_m, tile_n);

tv2->axis(-1)->parallelize(ParallelType::Vectorize);

tv3->merge(0);
tv3->split(0, 32);
tv3->axis(1)->parallelize(ParallelType::TIDx);
24 changes: 18 additions & 6 deletions tests/cpp/test_mma.cpp
@@ -515,12 +515,6 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
EXPECT_TRUE(tv3->getMemoryType() == MemoryType::Shared);
EXPECT_TRUE(tv4->getMemoryType() == MemoryType::Global);

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv2->getLoopDomain());
@@ -531,8 +525,26 @@
tv2->axis(-3)->parallelize(ParallelType::Mma);
}

{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3c->getLoopDomain());
tv3c->setLoopDomain(s.as<IterDomain*>());
tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
}

MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3);
{
auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
tv3->getLoopDomain());

if (swizzle != MmaInputSmemSwizzle::None) {
mma_utils::scheduleTMAStoreForMmaOutput(tv3, swizzle);
}

tv3->setLoopDomain(s.as<IterDomain*>());
}
mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n);
tv3->axis(-1)->parallelize(ParallelType::Vectorize);

mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle);

Expand Down
