third_party/triton/xla_extensions/sparse_dot_passes.patch
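Adds support for the XLA-specific triton::gpu::SparseDotOp (a dot product
with structured-sparsity metadata on operand A) to several Triton passes:
layout assignment in the Triton-to-TritonGPU conversion, MMA layout
selection in AccelerateMatmul, software pipelining of sparse dot loads,
shared-memory staging of the metadata in ReduceDataDuplication, and fence
insertion for MMAv3.
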
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index 4aa2712ec..16a6253d7 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -279,6 +279,89 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
}
};
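+// Converts triton::gpu::SparseDotOp during Triton-to-TritonGPU lowering:
+// the result gets a blocked layout, A/B get dot-operand layouts, and the
+// metadata gets the sparse-dot metadata layout (mirroring TritonDotPattern
+// above).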
+struct TritonSparseDotPattern
+ : public OpConversionPattern<triton::gpu::SparseDotOp> {
+ using OpConversionPattern<triton::gpu::SparseDotOp>::OpConversionPattern;
+
+ LogicalResult matchAndRewrite(
+ triton::gpu::SparseDotOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ RankedTensorType origType = cast<RankedTensorType>(op.getType());
+ auto origShape = origType.getShape();
+ auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
+ int numWarps = typeConverter->getNumWarps();
+ int threadsPerWarp = typeConverter->getThreadsPerWarp();
+ int numCTAs = typeConverter->getNumCTAs();
+
+ auto rank = origShape.size();
+ auto numElements = product<int64_t>(origShape);
+ SmallVector<unsigned> retSizePerThread(rank, 1);
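+    // Heuristic shared with TritonDotPattern: widen each thread's tile to
+    // 2x2, then 4x4, once there are enough elements per thread.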
+ if (numElements / (numWarps * threadsPerWarp) >= 4) {
+ retSizePerThread[rank - 1] = 2;
+ retSizePerThread[rank - 2] = 2;
+ }
+ if (numElements / (numWarps * threadsPerWarp) >= 16) {
+ retSizePerThread[rank - 1] = 4;
+ retSizePerThread[rank - 2] = 4;
+ }
+ SmallVector<unsigned> retOrder(rank);
+ for (unsigned i = 0; i < rank; ++i)
+ retOrder[i] = rank - 1 - i;
+ Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
+ getContext(), origShape, retSizePerThread, retOrder, numWarps,
+ threadsPerWarp, numCTAs);
+ RankedTensorType retType =
+ RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
+
+    // a & b must already carry an encoding; wrap them in dot-operand
+    // layouts if they do not have one yet.
+ auto aType = cast<RankedTensorType>(adaptor.getA().getType());
+ auto bType = cast<RankedTensorType>(adaptor.getB().getType());
+ Type aEltType = aType.getElementType();
+ Type bEltType = bType.getElementType();
+ Attribute aEncoding = aType.getEncoding();
+ Attribute bEncoding = bType.getEncoding();
+ if (!aEncoding || !bEncoding)
+ return failure();
+ Value a = adaptor.getA();
+ Value b = adaptor.getB();
+ Value c = adaptor.getC();
+ if (!isa<triton::gpu::DotOperandEncodingAttr>(aEncoding)) {
+ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+ getContext(), 0, dEncoding, aEltType);
+ auto dstType =
+ RankedTensorType::get(aType.getShape(), aEltType, encoding);
+ a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
+ }
+ if (!isa<triton::gpu::DotOperandEncodingAttr>(bEncoding)) {
+ Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+ getContext(), 1, dEncoding, bEltType);
+ auto dstType =
+ RankedTensorType::get(bType.getShape(), bEltType, encoding);
+ b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
+ }
+ c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
+
+    // aMeta must already carry an encoding; convert it to the sparse
+    // metadata layout if needed.
+ auto aMetaType = cast<RankedTensorType>(adaptor.getAMeta().getType());
+ Attribute aMetaEncoding = aMetaType.getEncoding();
+ if (!aMetaEncoding) return failure();
+ Value aMeta = adaptor.getAMeta();
+ if (!isa<triton::gpu::SparseDotMetaEncodingAttr>(aMetaEncoding)) {
+ Attribute encoding =
+ triton::gpu::SparseDotMetaEncodingAttr::get(getContext(), dEncoding);
+ auto dstType = RankedTensorType::get(
+ aMetaType.getShape(), aMetaType.getElementType(), encoding);
+ aMeta = rewriter.create<triton::gpu::ConvertLayoutOp>(aMeta.getLoc(),
+ dstType, aMeta);
+ }
+
+ addNamedAttrs(rewriter.replaceOpWithNewOp<triton::gpu::SparseDotOp>(
+ op, retType, a, b, c, aMeta),
+ adaptor.getAttributes());
+ return success();
+ }
+};
+
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
using OpConversionPattern::OpConversionPattern;
@@ -553,6 +636,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
context);
+ patterns.insert<TritonSparseDotPattern>(typeConverter, context);
}
//
@@ -794,6 +878,12 @@ public:
mod->setAttr(AttrTargetName,
StringAttr::get(context, this->target.getValue()));
+  // Sparse dot ops whose metadata already has an encoding are legal; only
+  // those with an undefined layout are rewritten.
+ target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
+ [](triton::gpu::SparseDotOp op) {
+ return op.getAMeta().getType().getEncoding() != nullptr;
+ });
+
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 098ee85e4..0516fc56f 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -44,8 +44,9 @@ static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
return 0;
}
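+// Templated over the dot op type so that the heuristics below apply to both
+// tt::DotOp and ttg::SparseDotOp.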
+template <typename DotType>
SmallVector<unsigned>
-warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
+warpsPerTileV2(DotType dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto rank = shape.size();
// Early exit for batched matmul
if (rank == 3)
@@ -58,8 +59,8 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
auto slices = multiRootGetSlice(dotOp, {filter}, {filter});
bool hasChainedDot = false;
for (Operation *op : slices) {
- if (isa<tt::DotOp>(op) && (op != dotOp)) {
- auto chainedDot = cast<tt::DotOp>(op);
+ if (isa<DotType>(op) && (op != dotOp)) {
+ auto chainedDot = cast<DotType>(op);
auto resTy = chainedDot.getResult().getType();
if (resTy.getRank() != rank) {
continue;
@@ -103,12 +104,13 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
return ret;
}
-SmallVector<unsigned, 2>
-warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
- const SmallVector<unsigned, 3> &instrShape) {
+template <typename DotType>
+SmallVector<unsigned, 2> warpsPerTileV3(
+ DotType dotOp, const ArrayRef<int64_t> shape, int numWarps,
+ const SmallVector<unsigned, 3> &instrShape) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
- if (llvm::find_if(slices, [](Operation *op) { return isa<tt::DotOp>(op); }) !=
+ if (llvm::find_if(slices, [](Operation *op) { return isa<DotType>(op); }) !=
slices.end())
return {(unsigned)numWarps, 1};
@@ -178,9 +180,10 @@ public:
: mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context),
computeCapability(computeCapability) {}
- static SmallVector<unsigned, 3>
- getWarpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int version,
- int numWarps, const SmallVector<unsigned, 3> &instrShape) {
+ template <typename DotType>
+ static SmallVector<unsigned, 3> getWarpsPerTile(
+ DotType dotOp, const ArrayRef<int64_t> shape, int version, int numWarps,
+ const SmallVector<unsigned, 3> &instrShape) {
switch (version) {
case 2:
return warpsPerTileV2(dotOp, shape, numWarps);
@@ -335,6 +338,98 @@ public:
return success();
}
};
+
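+// Rewrites a SparseDotOp with a blocked result layout into one with an
+// NVIDIA MMA layout, mirroring BlockedToMMA above.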
+class SparseBlockedToMMA : public mlir::RewritePattern {
+ public:
+ using SparseDotOp = mlir::triton::gpu::SparseDotOp;
+ using SparseDotMetaEncodingAttr =
+ mlir::triton::gpu::SparseDotMetaEncodingAttr;
+
+ SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability)
+ : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context),
+ computeCapability(computeCapability) {}
+
+ mlir::LogicalResult matchAndRewrite(
+ mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
+ auto dotOp = cast<SparseDotOp>(op);
+ auto ctx = op->getContext();
+ Value a = dotOp.getA();
+ Value b = dotOp.getB();
+
+    // Bail out if the result has no layout yet or is already MMA-encoded;
+    // SM compatibility is asserted below.
+ RankedTensorType oldRetType = dotOp.getType();
+ if (!oldRetType.getEncoding() ||
+ isa<ttg::NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
+ return failure();
+
+ assert(computeCapability >= 80 &&
+ "SparseDot is supported on Ampere and higher");
+ int versionMajor = computeCapability < 90 ? 2 : 3;
+
+ // get MMA encoding for the given number of warps
+ auto retShapePerCTA = ttg::getShapePerCTA(oldRetType);
+ auto mod = op->getParentOfType<mlir::ModuleOp>();
+ int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+ auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
+
+ auto instrShape =
+ mmaVersionToInstrShape(versionMajor, retShapePerCTA,
+ cast<TensorOrMemDesc>(a.getType()), numWarps);
+ auto warpsPerTile = BlockedToMMA::getWarpsPerTile(
+ dotOp, retShapePerCTA, versionMajor, numWarps, instrShape);
+ ttg::NvidiaMmaEncodingAttr mmaEnc =
+ ttg::NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0,
+ warpsPerTile, CTALayout, instrShape);
+ auto newRetType = RankedTensorType::get(
+ oldRetType.getShape(), oldRetType.getElementType(), mmaEnc);
+
+ // convert accumulator
+ auto oldAcc = dotOp.getOperand(2);
+ auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
+ newRetType, oldAcc);
+
+ if (versionMajor == 2) {
+ // convert A operand
+ auto oldAType = cast<RankedTensorType>(a.getType());
+ auto newAEncoding = ttg::DotOperandEncodingAttr::get(
+ ctx, 0, mmaEnc, oldAType.getElementType());
+ auto newAType = RankedTensorType::get(
+ oldAType.getShape(), oldAType.getElementType(), newAEncoding);
+ a = rewriter.create<ttg::ConvertLayoutOp>(a.getLoc(), newAType, a);
+
+ // convert B operand
+ auto oldBType = cast<RankedTensorType>(b.getType());
+ auto newBEncoding = ttg::DotOperandEncodingAttr::get(
+ ctx, 1, mmaEnc, oldBType.getElementType());
+ auto newBType = RankedTensorType::get(
+ oldBType.getShape(), oldBType.getElementType(), newBEncoding);
+ b = rewriter.create<ttg::ConvertLayoutOp>(b.getLoc(), newBType, b);
+ } else {
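+      // MMAv3 (wgmma) reads A and B from shared memory.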
+ a = BlockedToMMA::getMMAv3Operand(a, rewriter, 0);
+ b = BlockedToMMA::getMMAv3Operand(b, rewriter, 1);
+ }
+
+ // convert metadata
+ Value meta = dotOp.getAMeta();
+ auto oldMetaType = cast<RankedTensorType>(meta.getType());
+ auto newMetaType = RankedTensorType::get(
+ oldMetaType.getShape(), oldMetaType.getElementType(),
+ SparseDotMetaEncodingAttr::get(ctx, mmaEnc));
+ meta =
+ rewriter.create<ttg::ConvertLayoutOp>(meta.getLoc(), newMetaType, meta);
+
+ // convert dot instruction
+ auto newDot = rewriter.create<SparseDotOp>(dotOp.getLoc(), newRetType, a, b,
+ newAcc, meta);
+
+ rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
+ newDot.getResult());
+ return success();
+ }
+
+ private:
+ int computeCapability;
+};
} // namespace
static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -394,6 +489,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<::BlockedToMMA>(context, computeCapability);
+ patterns.add<::SparseBlockedToMMA>(context, computeCapability);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
signalPassFailure();
}
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index 97ca6a840..f0ef124ff 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -188,6 +188,10 @@ public:
}
};
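+// Matches both regular and sparse dots so the pipeliner treats them
+// uniformly.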
+static bool isDotOp(Operation *op) {
+ return isa<tt::DotOp, ttg::SparseDotOp>(op);
+}
+
static bool isMMAv3Dot(Operation *op) {
auto dot = dyn_cast<tt::DotOp>(op);
if (!dot)
@@ -399,19 +403,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
} else {
if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
return std::nullopt;
- auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
- cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding());
- if (!dotOpEnc)
+ auto enc =
+ cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding();
+      if (auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(enc)) {
+        auto srcTy = cast<TensorOrMemDesc>(val.getType());
+        auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
+        auto order = ttg::getOrder(srcTy.getEncoding());
+        unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+        tempAttr = ttg::SharedEncodingAttr::get(
+            val.getContext(), dotOpEnc, srcTy.getShape(), order, CTALayout,
+            bitWidth, /*needTrans=*/false);
+ } else if (isa<ttg::SparseDotMetaEncodingAttr>(enc)) {
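+        // Metadata is staged through a simple non-swizzled shared layout
+        // (vec = perPhase = maxPhase = 1).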
+ auto srcTy = cast<TensorOrMemDesc>(val.getType());
+ tempAttr = ttg::SharedEncodingAttr::get(
+ val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
+ ttg::getOrder(srcTy.getEncoding()),
+ ttg::getCTALayout(srcTy.getEncoding()));
+ } else {
return std::nullopt;
- auto srcTy = cast<TensorOrMemDesc>(val.getType());
- auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
- auto order = ttg::getOrder(srcTy.getEncoding());
- unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
- tempAttr = ttg::SharedEncodingAttr::get(
- val.getContext(), dotOpEnc, srcTy.getShape(),
- ttg::getOrder(srcTy.getEncoding()),
- ttg::getCTALayout(srcTy.getEncoding()),
- srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
+ }
}
// Check that the shared encodings needed by the users are compatible.
if (!tempAttr || (attr != nullptr && attr != tempAttr))
@@ -518,7 +531,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) {
};
for (Operation &op : forOp.getBody()->without_terminator()) {
- if (!isa<tt::DotOp>(op))
+ if (!isDotOp(&op))
continue;
seen.clear();
dfs(&op, 0, &op);
@@ -595,7 +608,8 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
continue;
}
- if (auto dot = dyn_cast<tt::DotOp>(use)) {
+ if (isDotOp(use)) {
+ auto dot = dyn_cast<tt::DotOp>(use);
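+      // Null for sparse dots: the MMAv3 shared-encoding fallback below only
+      // applies to regular dots.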
loadInfo.usedByDot = true;
if (loadIsMMAv3(op)) {
loadInfo.loadIsMMAV3 = true;
@@ -614,7 +628,7 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
// The codegen bug is caught by an assertion, so if you think you've
// fixed it, feel free to delete this code and see if the assert still
// fails. :)
- if (!loadInfo.sharedEncoding) {
+ if (dot && !loadInfo.sharedEncoding) {
if (auto dotEnc = dyn_cast<ttg::NvidiaMmaEncodingAttr>(
dot.getResult().getType().getEncoding())) {
auto loadTy = cast<RankedTensorType>(op->getResultTypes()[0]);
diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
index 2211df31b..ee5ff44d8 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -37,6 +37,10 @@ public:
auto srcEncoding = srcType.getEncoding();
if (isa<triton::gpu::SharedEncodingAttr>(srcEncoding))
return;
+ if (isa<triton::gpu::SparseDotMetaEncodingAttr>(dstType.getEncoding())) {
+ replaceSparseMetaEncoding(cvtOp);
+ return;
+ }
auto dstDotOp =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
if (!dstDotOp)
@@ -83,6 +87,27 @@ public:
cvtOp.erase();
});
}
+
+ private:
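+  // Replaces a convert into the sparse metadata layout with a round trip
+  // through shared memory: local_alloc into a shared buffer followed by a
+  // local_load with the metadata encoding.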
+ void replaceSparseMetaEncoding(triton::gpu::ConvertLayoutOp cvtOp) {
+ auto srcType = cast<RankedTensorType>(cvtOp.getOperand().getType());
+ auto srcEncoding = srcType.getEncoding();
+    auto sharedLayout = triton::gpu::SharedEncodingAttr::get(
+        cvtOp.getContext(), /*vec=*/8, /*perPhase=*/1, /*maxPhase=*/1,
+        triton::gpu::getOrder(srcEncoding),
+        triton::gpu::getCTALayout(srcEncoding));
+
+ auto dstType = cast<RankedTensorType>(cvtOp.getType());
+ auto tmpType = triton::MemDescType::get(
+ dstType.getShape(), dstType.getElementType(), sharedLayout);
+
+ OpBuilder builder(cvtOp);
+ auto tmp = builder.create<triton::gpu::LocalAllocOp>(
+ cvtOp.getLoc(), tmpType, cvtOp.getSrc());
+ auto newConvert = builder.create<triton::gpu::LocalLoadOp>(
+ cvtOp.getLoc(), dstType, tmp);
+ cvtOp.replaceAllUsesWith(newConvert.getResult());
+ cvtOp.erase();
+ }
};
std::unique_ptr<Pass> mlir::triton::gpu::createReduceDataDuplicationPass() {
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
index f456d36a6..a1dac2b72 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
@@ -45,7 +45,7 @@ public:
return;
ModuleOp mod = getOperation();
mod.walk([&](Operation *op) {
- if (!isa<tt::DotOp, ttng::DotAsyncOp>(op))
+ if (!isa<tt::DotOp, ttng::DotAsyncOp, ttg::SparseDotOp>(op))
return WalkResult::advance();
OpBuilder builder(op);
auto a = op->getOperand(0);
@@ -80,7 +80,7 @@ private:
static DenseSet<std::pair<Operation *, unsigned>> trace;
auto op = operand.getDefiningOp();
// avoid redundant insertion
- if (op && isa<tt::DotOp, ttng::DotAsyncOp>(op))
+ if (op && isa<tt::DotOp, ttng::DotAsyncOp, ttg::SparseDotOp>(op))
return false;
// reach convertlayout
if (op && isa<ttg::LocalAllocOp>(op) &&