From ee63c98be33a977f08d038f714124eba64360189 Mon Sep 17 00:00:00 2001
From: Protonu
Date: Tue, 24 Dec 2024 21:13:09 -0500
Subject: [PATCH] Add support for smem_epilogue when mma output is not cast to half (#3620)

Support non-stmatrix stores from registers to shared memory, followed by a
TMA store, when the output of the mma op is not cast back to half precision,
since stmatrix works with half precision only.
---
 csrc/scheduler/hopper_multi_matmul.cpp | 66 +++++++++++++++++---------
 csrc/scheduler/mma_utils.cpp           | 12 -----
 tests/cpp/test_matmul_scheduler.cpp    | 26 ++++++----
 tests/cpp/test_memory.cpp              | 12 +++--
 tests/cpp/test_mma.cpp                 | 24 +++++++---
 5 files changed, 87 insertions(+), 53 deletions(-)

diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp
index 2fa0a40ab75..b0e4b751c8a 100644
--- a/csrc/scheduler/hopper_multi_matmul.cpp
+++ b/csrc/scheduler/hopper_multi_matmul.cpp
@@ -508,8 +508,12 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set);
       std::vector<TensorView*> tvs_to_schedule{d, d_smem};
-      if (std::find(mma_results_.begin(), mma_results_.end(), dc) ==
-          mma_results_.end()) {
+
+      bool dc_in_mma_results =
+          std::find(mma_results_.begin(), mma_results_.end(), dc) !=
+          mma_results_.end();
+
+      if (!dc_in_mma_results) {
         // Skip scheduling dc if it is an mma_result. This can happen if we are
         // not casting back to half-precision in the output
         tvs_to_schedule.push_back(dc);
       }
@@ -519,14 +523,13 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
       dc->setMemoryType(MemoryType::Local);
       d_smem->setMemoryType(MemoryType::Shared);
 
-      // Set LoadStoreOp
-      // TODO: extend support when mma is not cast to half
-      NVF_CHECK(
-          dataTypeSize(dc->dtype()) == 2,
-          "We support use_smem_epilogue on Hopper only when the output is 16-bit");
+      auto store_with_stmatrix = dataTypeSize(dc->dtype()) == 2;
 
-      d_smem->definition()->as<LoadStoreOp>()->setOpType(
-          LoadStoreOpType::StMatrix);
+      if (store_with_stmatrix) {
+        // Set LoadStoreOp
+        d_smem->definition()->as<LoadStoreOp>()->setOpType(
+            LoadStoreOpType::StMatrix);
+      }
       d->definition()->as<LoadStoreOp>()->setOpType(
           LoadStoreOpType::CpAsyncBulkTensorTile);
 
@@ -539,23 +542,40 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
         transformLikeMmaOutput(tv, /*is_mma_result=*/false);
       }
 
-      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-          dc->getLoopDomain());
-      dc->setLoopDomain(s.as<IterDomain*>());
-      dc->setAllocationDomain(s.as<IterDomain*>(), true);
-
-      scheduler_utils::BoundedDirectionalTransformPropagator::backward(
-          dc,
-          -1,
-          propagate_to,
-          scheduler_utils::BoundedDirectionalTransformPropagator::Options()
-              .propagateParallelType());
+      // Should not propagate if the dc is an mma output as the mma output has
+      // already been scheduled.
+      if (!dc_in_mma_results) {
+        auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+            dc->getLoopDomain());
+        dc->setLoopDomain(s.as<IterDomain*>());
+        dc->setAllocationDomain(s.as<IterDomain*>(), true);
+
+        scheduler_utils::BoundedDirectionalTransformPropagator::backward(
+            dc,
+            -1,
+            propagate_to,
+            scheduler_utils::BoundedDirectionalTransformPropagator::Options()
+                .propagateParallelType());
+      }
 
       MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(d_smem);
 
-      // Schedule shared memory cache; Output from StMatrix
-      mma_utils::scheduleStMatrixForMmaOutput(
-          d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
+      // [M, N] -> [128(TIDx), N/8, m(2), n(2)]
+      auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+          d_smem->getLoopDomain());
+      if (swizzle != MmaInputSmemSwizzle::None) {
+        // Create tma store allocation domain with swizzle
+        mma_utils::scheduleTMAStoreForMmaOutput(d_smem, swizzle);
+      }
+      d_smem->setLoopDomain(s.as<IterDomain*>());
+
+      if (store_with_stmatrix) {
+        // Schedule shared memory cache; Output from StMatrix
+        mma_utils::scheduleStMatrixForMmaOutput(
+            d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n);
+      }
+
+      d_smem->axis(-1)->parallelize(ParallelType::Vectorize);
 
       // Schedule global memory output; Output from TMA Store
       mma_utils::scheduleTMAStoreForMmaOutput(d, swizzle);
diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp
index e52a044b509..3afc8d43a97 100644
--- a/csrc/scheduler/mma_utils.cpp
+++ b/csrc/scheduler/mma_utils.cpp
@@ -1315,17 +1315,6 @@ void scheduleStMatrixForMmaOutput(
       dataTypeSize(tv->dtype()) == 2,
       "we only support 16-bit types in stmatrix");
 
-  // [M, N] -> [128(TIDx), N/8, 2, 2]
-  auto s =
-      mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());
-
-  if (swizzle != MmaInputSmemSwizzle::None) {
-    // Create tma store allocation domain with swizzle
-    mma_utils::scheduleTMAStoreForMmaOutput(tv, swizzle);
-  }
-
-  tv->setLoopDomain(s.as<IterDomain*>());
-
   if (tile_m == 16 && tile_n == 16) {
     // Let [M, N] be [64, 32]
     // After scheduleMmaOutputAllocation: [128(TIDx), 4, 2, 2]
@@ -1344,7 +1333,6 @@ void scheduleStMatrixForMmaOutput(
     // [2, 128(TIDx), 2, 2] -> [2, 128(TIDx), 4(vectorize)]
     tv->merge(-2);
   }
-  tv->axis(-1)->parallelize(ParallelType::Vectorize);
 }
 
 MatmulOperandInnerDimsOpt getOperandInnerDims(Fusion* fusion) {
diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp
index c860b908a92..8182464cf40 100644
--- a/tests/cpp/test_matmul_scheduler.cpp
+++ b/tests/cpp/test_matmul_scheduler.cpp
@@ -2660,7 +2660,7 @@ TEST_F(MatmulSchedulerTest, SegmentMatmulOpUnsupportedDtype) {
   testValidate(executor_cache.fusion(), outputs, {t0, t1}, __LINE__, __FILE__);
 }
 
-TEST_F(MatmulSchedulerTest, PreBroadcastGEMM) {
+TEST_F(MatmulSchedulerTest, PreBroadcastMmaBiasNeg) {
   // TODO: fix up params or switch to FusionExecutorCache when ready, then
   // enable Ampere
   NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
@@ -2671,12 +2671,20 @@
   // A - tv0, B - tv1
   auto tv0 = makeContigConcreteTensor({-1, 1, -1}, DataType::Half);
   auto tv1 = makeContigConcreteTensor({1, -1, -1}, DataType::Half);
+  TensorView* tv2 = makeContigConcreteTensor({-1}, DataType::Half);
   fusion->addInput(tv0);
   fusion->addInput(tv1);
+  fusion->addInput(tv2);
 
-  auto tv2 = fusedMultiplySum(tv0, tv1, {-1});
+  auto tv3 = fusedMultiplySum(tv0, tv1, {-1});
+  // We add these computations to test
+  // scheduling (with epilogue) when the output of mma is not
+  // cast to half.
+  auto tv4 = maybeCastOp(DataType::Float, tv2);
+  auto tv5 = biasEpilogue(tv3, tv4);
+  auto tv6 = neg(tv5);
 
-  fusion->addOutput(tv2);
+  fusion->addOutput(tv6);
 
   NVF_CHECK(
       1 == ir_utils::getOpsOfType<MmaOp>(fusion.get()).size(),
@@ -2689,10 +2697,14 @@
   auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
   auto a = at::randn({M, K}, options);
   auto b = at::randn({N, K}, options);
+  auto c = at::randn({M}, options);
   auto t0 = a.unsqueeze(1);
   auto t1 = b.unsqueeze(0);
-  auto tref = at::matmul(a.to(at::kFloat), b.to(at::kFloat).t());
-  std::vector<c10::IValue> inputs{t0, t1};
+  auto tref =
+      atBiasEpilogue(
+          at::matmul(a.to(at::kFloat), b.to(at::kFloat).t()), c.to(at::kFloat))
+          .neg_();
+  std::vector<c10::IValue> inputs{t0, t1, c};
 
   MatmulParams mparams;
   mparams.supported_vec_size = {8, 8, 4};
@@ -2705,9 +2717,7 @@
   mparams.circular_buffer_options.circular_buffer_smem_write = true;
   mparams.circular_buffer_options.circular_buffer_smem_read = true;
   mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
-  // TODO: Currently we use stmatrix whenever this is true. We cannot do that
-  // when the dtype is not 16 bits.
-  mparams.use_smem_epilogue = false;
+  mparams.use_smem_epilogue = true;
   mparams.promote_prologue_smem_reuse = false;
 
   SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp
index 92d85fdc947..991fe732b72 100644
--- a/tests/cpp/test_memory.cpp
+++ b/tests/cpp/test_memory.cpp
@@ -2860,14 +2860,18 @@ TEST_P(StMatrixTest, Regular) {
   tv0->split(0, 32);
   tv0->axis(1)->parallelize(ParallelType::TIDx);
 
-  auto s =
-      mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv1->getLoopDomain());
-  tv1->setLoopDomain(s.as<IterDomain*>());
-  tv1->setAllocationDomain(s.as<IterDomain*>(), true);
+  for (auto tv : {tv1, tv2}) {
+    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+        tv->getLoopDomain());
+    tv->setLoopDomain(s.as<IterDomain*>());
+  }
+  tv1->setAllocationDomain(tv1->getLoopDomain(), true);
 
   mma_utils::scheduleStMatrixForMmaOutput(
       tv2, /*swizzle=*/MmaInputSmemSwizzle::None, tile_m, tile_n);
 
+  tv2->axis(-1)->parallelize(ParallelType::Vectorize);
+
   tv3->merge(0);
   tv3->split(0, 32);
   tv3->axis(1)->parallelize(ParallelType::TIDx);
diff --git a/tests/cpp/test_mma.cpp b/tests/cpp/test_mma.cpp
index 9835d36d6c3..7aafcafb8ab 100644
--- a/tests/cpp/test_mma.cpp
+++ b/tests/cpp/test_mma.cpp
@@ -515,12 +515,6 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
   EXPECT_TRUE(tv3->getMemoryType() == MemoryType::Shared);
   EXPECT_TRUE(tv4->getMemoryType() == MemoryType::Global);
 
-  {
-    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
-        tv3c->getLoopDomain());
-    tv3c->setLoopDomain(s.as<IterDomain*>());
-    tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
-  }
   {
     auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
         tv2->getLoopDomain());
@@ -531,8 +525,26 @@ TEST_P(HopperRSStmatrix, SingleTileWithTMALoadStoreStMatrix) {
     tv2->axis(-3)->parallelize(ParallelType::Mma);
   }
 
+  {
+    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+        tv3c->getLoopDomain());
+    tv3c->setLoopDomain(s.as<IterDomain*>());
+    tv3c->setAllocationDomain(s.as<IterDomain*>(), true);
+  }
+
   MmaInputSmemSwizzle swizzle = mma_utils::tmaSwizzleSharedMemory(tv3);
+  {
+    auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(
+        tv3->getLoopDomain());
+
+    if (swizzle != MmaInputSmemSwizzle::None) {
+      mma_utils::scheduleTMAStoreForMmaOutput(tv3, swizzle);
+    }
+
+    tv3->setLoopDomain(s.as<IterDomain*>());
+  }
 
   mma_utils::scheduleStMatrixForMmaOutput(tv3, swizzle, tile_m, tile_n);
+  tv3->axis(-1)->parallelize(ParallelType::Vectorize);
 
   mma_utils::scheduleTMAStoreForMmaOutput(tv4, swizzle);