Optimization to put LHS operand in registers for WGMMA before elementwise ops #17

Closed
wants to merge 28 commits into from
28 commits
30787ef
preliminary changes for hoisting DotOpEnc for MMAV3
ggengnv Aug 26, 2024
896e128
improve DotOpEnc hoisting to use LocalLoad
ggengnv Aug 28, 2024
07552f4
Don't pipeline LHS operand for now
ggengnv Aug 28, 2024
60b08ca
Add placeholders for SharedToDotOpV3 and fix LLVM lowering issues
ggengnv Aug 29, 2024
ec5a726
Hack to make SharedEncodingAttr work for different width types
ggengnv Aug 30, 2024
d4ae449
dot lower to shared working for Hopper; small refactors
ggengnv Aug 30, 2024
fefdabc
Fix dot op ordering for WGMMA and allow ConstantOp
ggengnv Sep 4, 2024
5767cf3
Revert DecomposeUnsupportedConversions since we use v2 SharedEnc for …
ggengnv Sep 4, 2024
a72548f
revert ordering changes
ggengnv Sep 5, 2024
0c3b0a8
i8 -> f16 working
ggengnv Sep 6, 2024
a8875b0
fix lit test regressions
ggengnv Sep 6, 2024
38d6f45
Add comments for isHopperWidthChange
ggengnv Sep 6, 2024
9c36d4e
fix regression for Hopper MMA > DotOp
ggengnv Sep 6, 2024
0568fd8
Disable hoisting thru downcasts
ggengnv Sep 6, 2024
a3d65d1
fix test regression with WGMMA.cpp
ggengnv Sep 9, 2024
bbb0668
Rewrite OptimizeDotOperands logic to fix for general case
ggengnv Sep 9, 2024
00418d6
fix another regression and add minor comments
ggengnv Sep 9, 2024
93ef960
Remove redundant LocalAlloc+LocalLoad ops added in OptimizeDotOperands
ggengnv Sep 10, 2024
3c65087
fix bad rebase
ggengnv Sep 19, 2024
e7904ee
Pipelining
ggengnv Sep 19, 2024
d67b207
Add pipeline test
ggengnv Sep 19, 2024
bbea8d6
Add chained test for dot operand hoisting
ggengnv Sep 19, 2024
d1b64ee
Fix hoisting bug and add comments
ggengnv Sep 19, 2024
1115cd2
delete draft code
ggengnv Sep 19, 2024
15e9cf6
Improve comments
ggengnv Sep 20, 2024
47f1e84
Fix coalescing
ggengnv Sep 20, 2024
a2dea0a
fix typo
ggengnv Sep 20, 2024
c3460a3
More minor comments
ggengnv Sep 20, 2024
17 changes: 7 additions & 10 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
@@ -361,8 +361,8 @@ compared to 1*64 when the hasLeadingOffset is false.
return get(context, vec, perPhase, maxPhase, order, CTALayout);
}

// ---- begin Ampere ----
if (mmaEnc.isAmpere()) {
// ---- begin Ampere & Hopper ----
if (mmaEnc.isAmpere() || mmaEnc.isHopper()) {
int perPhase = 128 / (shapePerCTA[order[0]] * 4 / dotOpEnc.getKWidth());
perPhase = std::max<int>(perPhase, 1);
std::vector<size_t> matShape = {8, 8, 4 * dotOpEnc.getKWidth()};
@@ -397,13 +397,6 @@ compared to 1*64 when the hasLeadingOffset is false.
llvm_unreachable("invalid operand index");
}

// ---- begin version 3 ----
if (mmaEnc.isHopper()) {
llvm_unreachable("SharedEncodingAttr builder when the MMAEncodingAttr"
" is Hopper has not been implemented yet");
return $_get(context, 1, 1, 1, order, CTALayout, true);
}

// ---- not implemented ----
llvm_unreachable("unsupported swizzling for provided MMA version");
}]>,
@@ -1317,6 +1310,10 @@ The parent field is the layout of d.
kWidth defines number of consecutive elements stored by one thread along k dimension.
Some layouts do not use this parameter, either because they have a fixed number of
elements along the K dim, or they use all elements of the tensor along the K dim.

We require kWidth to be provided for Hopper because the dtype at loading might be
different from the dtype at WGMMA, due to casting. The kWidth is determined by the
dtype at WGMMA.
}];

let parameters = (
@@ -1332,7 +1329,7 @@ elements along the K dim, or they use all elements of the tensor along the K dim
"Attribute":$parent,
"Type":$eltTy), [{
NvidiaMmaEncodingAttr parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent);
if (!parentAttr || !parentAttr.isAmpere())
if (!parentAttr || (!parentAttr.isAmpere() && !parentAttr.isHopper()))
return $_get(context, opIdx, parent, 0);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
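
The kWidth note added to the layout documentation above says the value must follow the dtype the WGMMA consumes, not the dtype that was loaded. A minimal sketch of that rule, using the same 32-bit-register convention as the builder in this diff (chooseHopperKWidth is a hypothetical name, not something this PR adds):

#include "mlir/IR/Types.h"

// Illustrative sketch only: kWidth is derived from the dtype the WGMMA
// instruction actually consumes. An elementwise cast (e.g. i8 -> f16) may sit
// between the shared-memory load and the dot, so the two dtypes can differ.
unsigned chooseHopperKWidth(mlir::Type loadEltTy, mlir::Type wgmmaEltTy) {
  (void)loadEltTy; // intentionally unused: the load dtype does not decide kWidth
  // Same convention the builder above uses for MMAv2: kWidth = 32 / bitwidth.
  unsigned bitwidth = wgmmaEltTy.getIntOrFloatBitWidth();
  return 32 / bitwidth; // f16 -> 2
}

Under these assumptions, an i8 load that is upcast to f16 before the dot yields kWidth = 2, because f16 is what the WGMMA instruction actually sees.
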
14 changes: 11 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
@@ -33,7 +33,7 @@ SmallVector<Value> reorderValues(const SmallVector<Value> &values, Type inType,
// If the parent of the dot operand is in block encoding, we don't need to
// reorder elements
auto parentEncoding = dyn_cast<NvidiaMmaEncodingAttr>(ouEncoding.getParent());
if (!parentEncoding)
if (!parentEncoding || parentEncoding.isHopper())
return values;
size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
@@ -87,8 +87,12 @@ SmallVector<Value> unpackI32(const SmallVector<Value> &inValues, Type srcTy,
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
if (!encoding)
return inValues;
auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
if (!parentEnc || parentEnc.isHopper())
return inValues;

SmallVector<Value> outValues;
for (auto v : inValues) {
// cast i32 to appropriate eltType vector and extract elements
@@ -109,8 +113,12 @@ SmallVector<Value> packI32(const SmallVector<Value> &inValues, Type srcTy,
if (!tensorTy)
return inValues;
auto encoding = dyn_cast<DotOperandEncodingAttr>(tensorTy.getEncoding());
if (!(encoding && isa<NvidiaMmaEncodingAttr>(encoding.getParent())))
if (!encoding)
return inValues;
auto parentEnc = dyn_cast<NvidiaMmaEncodingAttr>(encoding.getParent());
if (!parentEnc || parentEnc.isHopper())
return inValues;

SmallVector<Value> outValues;
auto eltType = typeConverter->convertType(tensorTy.getElementType());
int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
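
All three hunks above add the same guard: a dot operand whose parent encoding is Hopper is returned unchanged, so the i32 pack/unpack path now only applies to MMAv2-style operands. A sketch of how that shared predicate could be written, assuming the headers and using-directives already present in ElementwiseOpToLLVM.cpp (needsI32Repack is a hypothetical name, not part of this PR):

// Hypothetical helper, shown only to summarize the condition repeated in
// reorderValues, unpackI32, and packI32 above.
static bool needsI32Repack(RankedTensorType tensorTy) {
  if (!tensorTy)
    return false;
  auto enc = dyn_cast_or_null<DotOperandEncodingAttr>(tensorTy.getEncoding());
  if (!enc)
    return false;
  auto parent = dyn_cast<NvidiaMmaEncodingAttr>(enc.getParent());
  // With this change, only non-Hopper (MMAv2-style) dot operands go through
  // the i32 pack/unpack path; Hopper operands are left untouched.
  return parent && !parent.isHopper();
}
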
22 changes: 16 additions & 6 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1022,13 +1022,17 @@ LogicalResult DotOperandEncodingAttr::verify(
return emitError() << "triton_gpu.dot_op parent paramenter cannot be null";
}
if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
if (kWidth != 0 && !parentAttr.isAmpere())
if (kWidth != 0 && !(parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError() << "triton_gpu.dot_op kWidth parameter can only be "
"non-zero for Ampere MMA parent";
if (kWidth == 0 && parentAttr.isAmpere())
"non-zero for Ampere or Hopper MMA parent";
if (kWidth == 0 && (parentAttr.isAmpere() || parentAttr.isHopper()))
return emitError()
<< "triton_gpu.dot_op kWidth parameter is mandatory for "
"Ampere MMA parent";
"Ampere or Hopper MMA parent";
if (opIdx != 0 && parentAttr.isHopper())
return emitError()
<< "triton_gpu.dot_op opIdx parameter must be 0 for "
"Hopper MMA parent";
return success();
}

@@ -1957,14 +1961,14 @@ int NvidiaMmaEncodingAttr::getMMAv1Vec(int opIdx) const {
SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
int bitwidth,
int opIdx) const {
assert(isAmpere() || isHopper());
auto rank = shape.size();
auto warpsPerCTA = getWarpsPerCTA();
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
int numRepBatch =
rank == 3
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
: 1;
assert(isAmpere());

if (opIdx == 0)
return {numRepBatch,
@@ -1979,14 +1983,20 @@ SmallVector<int64_t> NvidiaMmaEncodingAttr::getMMAv2Rep(ArrayRef<int64_t> shape,
warpsPerCTA[rank - 1]))};
}
}

unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperands(
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
auto shapePerCTA = getShapePerCTA(*this, shape);
int warpsPerCTAM = getWarpsPerCTA()[0];
int warpsPerCTAN = getWarpsPerCTA()[1];
// H100
if (isHopper()) {
return getTotalElemsPerThread(shape, eltTy);
assert(opIdx == 0);
auto instrMNK = getInstrShape();
auto wpt = getWarpsPerCTA();
int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * wpt[0]);
int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
return 4 * kWidth * repM * repK;
}
// A100
if (isAmpere()) {
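
A worked example for the new Hopper branch of getTotalElemsPerThreadForOperands above; every input below is an assumed value chosen for illustration, not taken from a real kernel:

#include <cstdio>

int main() {
  // Assumed values for illustration only.
  int shapeM = 128, shapeK = 64; // shapePerCTA of the A operand
  int instrM = 64, instrK = 16;  // assumed WGMMA instruction shape along M and K
  int wptM = 2;                  // assumed warps along M
  int kWidth = 2;                // f16 at the WGMMA: 32 / 16 = 2
  int repM = (shapeM + instrM * wptM - 1) / (instrM * wptM); // ceil(128, 128) = 1
  int repK = (shapeK + instrK - 1) / instrK;                 // ceil(64, 16)   = 4
  printf("%d\n", 4 * kWidth * repM * repK);                  // 4 * 2 * 1 * 4  = 32
  return 0;
}

Under those assumptions each thread holds 32 elements of the A operand in registers.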