tensorflow/tensorflow

View on GitHub
third_party/triton/xla_extensions/sparse_dot_passes.patch

Summary

Maintainability
Test Coverage
diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
index 4aa2712ec..16a6253d7 100644
--- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
+++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp
@@ -279,6 +279,89 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
   }
 };
 
+// Assigns TritonGPU encodings to the result and operands of a SparseDotOp.
+struct TritonSparseDotPattern
+    : public OpConversionPattern<triton::gpu::SparseDotOp> {
+  using OpConversionPattern<triton::gpu::SparseDotOp>::OpConversionPattern;
+
+  LogicalResult matchAndRewrite(
+      triton::gpu::SparseDotOp op, OpAdaptor adaptor,
+      ConversionPatternRewriter &rewriter) const override {
+    // Validate every operand encoding before any op is created.
+    Value a = adaptor.getA();
+    Value b = adaptor.getB();
+    Value c = adaptor.getC();
+    Value aMeta = adaptor.getAMeta();
+    auto aType = cast<RankedTensorType>(a.getType());
+    auto bType = cast<RankedTensorType>(b.getType());
+    auto aMetaType = cast<RankedTensorType>(aMeta.getType());
+    Attribute aEncoding = aType.getEncoding();
+    Attribute bEncoding = bType.getEncoding();
+    Attribute aMetaEncoding = aMetaType.getEncoding();
+    if (!aEncoding || !bEncoding || !aMetaEncoding)
+      return failure();
+
+    RankedTensorType origType = cast<RankedTensorType>(op.getType());
+    auto origShape = origType.getShape();
+    auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
+    int numWarps = typeConverter->getNumWarps();
+    int threadsPerWarp = typeConverter->getThreadsPerWarp();
+    int numCTAs = typeConverter->getNumCTAs();
+
+    // The per-thread tile over the last two dimensions is 1x1, 2x2 or 4x4,
+    // chosen from the element count divided by numWarps * threadsPerWarp.
+    auto rank = origShape.size();
+    auto elemsPerThread =
+        product<int64_t>(origShape) / (numWarps * threadsPerWarp);
+    unsigned tile = elemsPerThread >= 16 ? 4 : (elemsPerThread >= 4 ? 2 : 1);
+    SmallVector<unsigned> retSizePerThread(rank, 1);
+    if (tile > 1)
+      retSizePerThread[rank - 1] = retSizePerThread[rank - 2] = tile;
+    SmallVector<unsigned> retOrder(rank);
+    for (unsigned i = 0; i < rank; ++i)
+      retOrder[i] = rank - 1 - i;
+    Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
+        getContext(), origShape, retSizePerThread, retOrder, numWarps,
+        threadsPerWarp, numCTAs);
+    RankedTensorType retType =
+        RankedTensorType::get(origShape, origType.getElementType(), dEncoding);
+
+    // a and b are converted unless they already have a dot-operand encoding.
+    if (!isa<triton::gpu::DotOperandEncodingAttr>(aEncoding)) {
+      Type aEltType = aType.getElementType();
+      Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+          getContext(), 0, dEncoding, aEltType);
+      auto dstType =
+          RankedTensorType::get(aType.getShape(), aEltType, encoding);
+      a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
+    }
+    if (!isa<triton::gpu::DotOperandEncodingAttr>(bEncoding)) {
+      Type bEltType = bType.getElementType();
+      Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
+          getContext(), 1, dEncoding, bEltType);
+      auto dstType =
+          RankedTensorType::get(bType.getShape(), bEltType, encoding);
+      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
+    }
+    c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
+
+    // aMeta is converted unless it already has a sparse-dot-meta encoding.
+    if (!isa<triton::gpu::SparseDotMetaEncodingAttr>(aMetaEncoding)) {
+      Attribute encoding =
+          triton::gpu::SparseDotMetaEncodingAttr::get(getContext(), dEncoding);
+      auto dstType = RankedTensorType::get(
+          aMetaType.getShape(), aMetaType.getElementType(), encoding);
+      aMeta = rewriter.create<triton::gpu::ConvertLayoutOp>(aMeta.getLoc(),
+                                                            dstType, aMeta);
+    }
+
+    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::gpu::SparseDotOp>(
+                      op, retType, a, b, c, aMeta),
+                  adaptor.getAttributes());
+    return success();
+  }
+};
+
 struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
   using OpConversionPattern::OpConversionPattern;
 
@@ -553,6 +636,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
       GenericOpPattern<triton::ExperimentalDescriptorStoreOp>,
       GenericOpPattern<triton::CallOp>, TritonFuncOpPattern>(typeConverter,
                                                              context);
+  patterns.insert<TritonSparseDotPattern>(typeConverter, context);
 }
 
 //
@@ -794,6 +878,12 @@ public:
     mod->setAttr(AttrTargetName,
                  StringAttr::get(context, this->target.getValue()));
 
+    // Only transform sparse dot op with undefined layout.
+    target.addDynamicallyLegalOp<triton::gpu::SparseDotOp>(
+        [](triton::gpu::SparseDotOp op) {
+          return op.getAMeta().getType().getEncoding() != nullptr;
+        });
+
     if (failed(applyPartialConversion(mod, target, std::move(patterns))))
       return signalPassFailure();
 
diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
index 098ee85e4..0516fc56f 100644
--- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
@@ -44,8 +44,9 @@ static int getMMAVersionSafe(int computeCapability, tt::DotOp op) {
   return 0;
 }
 
+template <typename DotType>
 SmallVector<unsigned>
-warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
+warpsPerTileV2(DotType dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto rank = shape.size();
   // Early exit for batched matmul
   if (rank == 3)
@@ -58,8 +59,8 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   auto slices = multiRootGetSlice(dotOp, {filter}, {filter});
   bool hasChainedDot = false;
   for (Operation *op : slices) {
-    if (isa<tt::DotOp>(op) && (op != dotOp)) {
-      auto chainedDot = cast<tt::DotOp>(op);
+    if (isa<DotType>(op) && (op != dotOp)) {
+      auto chainedDot = cast<DotType>(op);
       auto resTy = chainedDot.getResult().getType();
       if (resTy.getRank() != rank) {
         continue;
@@ -103,12 +104,13 @@ warpsPerTileV2(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps) {
   return ret;
 }
 
-SmallVector<unsigned, 2>
-warpsPerTileV3(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
-               const SmallVector<unsigned, 3> &instrShape) {
+template <typename DotType>
+SmallVector<unsigned, 2> warpsPerTileV3(
+    DotType dotOp, const ArrayRef<int64_t> shape, int numWarps,
+    const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
   mlir::getForwardSlice(dotOp.getResult(), &slices);
-  if (llvm::find_if(slices, [](Operation *op) { return isa<tt::DotOp>(op); }) !=
+  if (llvm::find_if(slices, [](Operation *op) { return isa<DotType>(op); }) !=
       slices.end())
     return {(unsigned)numWarps, 1};
 
@@ -178,9 +180,10 @@ public:
       : mlir::RewritePattern(tt::DotOp::getOperationName(), 2, context),
         computeCapability(computeCapability) {}
 
-  static SmallVector<unsigned, 3>
-  getWarpsPerTile(tt::DotOp dotOp, const ArrayRef<int64_t> shape, int version,
-                  int numWarps, const SmallVector<unsigned, 3> &instrShape) {
+  template <typename DotType>
+  static SmallVector<unsigned, 3> getWarpsPerTile(
+      DotType dotOp, const ArrayRef<int64_t> shape, int version, int numWarps,
+      const SmallVector<unsigned, 3> &instrShape) {
     switch (version) {
     case 2:
       return warpsPerTileV2(dotOp, shape, numWarps);
@@ -335,6 +338,98 @@ public:
     return success();
   }
 };
+
+// Rewrites a SparseDotOp so that its result uses an NVIDIA MMA layout.
+class SparseBlockedToMMA : public mlir::RewritePattern {
+ public:
+  using SparseDotOp = mlir::triton::gpu::SparseDotOp;
+  using SparseDotMetaEncodingAttr =
+      mlir::triton::gpu::SparseDotMetaEncodingAttr;
+
+  SparseBlockedToMMA(mlir::MLIRContext *context, int computeCapability)
+      : mlir::RewritePattern(SparseDotOp::getOperationName(), 2, context),
+        computeCapability(computeCapability) {}
+
+  mlir::LogicalResult matchAndRewrite(
+      mlir::Operation *op, mlir::PatternRewriter &rewriter) const override {
+    auto dotOp = cast<SparseDotOp>(op);
+    auto ctx = op->getContext();
+    Value a = dotOp.getA();
+    Value b = dotOp.getB();
+
+    // Skip ops without a result layout or whose result is already MMA.
+    RankedTensorType oldRetType = dotOp.getType();
+    if (!oldRetType.getEncoding() ||
+        isa<ttg::NvidiaMmaEncodingAttr>(oldRetType.getEncoding()))
+      return failure();
+
+    // Reported as a diagnostic because an assert is compiled out in NDEBUG
+    // builds, which would let unsupported targets fall through to MMA v2.
+    if (computeCapability < 80)
+      return op->emitError("SparseDot is supported on Ampere and higher");
+    int versionMajor = computeCapability < 90 ? 2 : 3;
+
+    // get MMA encoding for the given number of warps
+    auto retShapePerCTA = ttg::getShapePerCTA(oldRetType);
+    auto mod = op->getParentOfType<mlir::ModuleOp>();
+    int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    auto CTALayout = ttg::getCTALayout(oldRetType.getEncoding());
+
+    auto instrShape =
+        mmaVersionToInstrShape(versionMajor, retShapePerCTA,
+                               cast<TensorOrMemDesc>(a.getType()), numWarps);
+    auto warpsPerTile = BlockedToMMA::getWarpsPerTile(
+        dotOp, retShapePerCTA, versionMajor, numWarps, instrShape);
+    ttg::NvidiaMmaEncodingAttr mmaEnc =
+        ttg::NvidiaMmaEncodingAttr::get(ctx, versionMajor, /*versionMinor=*/0,
+                                        warpsPerTile, CTALayout, instrShape);
+    auto newRetType = RankedTensorType::get(
+        oldRetType.getShape(), oldRetType.getElementType(), mmaEnc);
+
+    // convert accumulator
+    auto oldAcc = dotOp.getC();
+    auto newAcc = rewriter.create<ttg::ConvertLayoutOp>(oldAcc.getLoc(),
+                                                        newRetType, oldAcc);
+
+    if (versionMajor == 2) {
+      // convert A and B to dot-operand encodings built from the MMA layout
+      auto toDotOperand = [&](Value operand, unsigned opIdx) -> Value {
+        auto oldType = cast<RankedTensorType>(operand.getType());
+        auto newEncoding = ttg::DotOperandEncodingAttr::get(
+            ctx, opIdx, mmaEnc, oldType.getElementType());
+        auto newType = RankedTensorType::get(
+            oldType.getShape(), oldType.getElementType(), newEncoding);
+        return rewriter.create<ttg::ConvertLayoutOp>(operand.getLoc(),
+                                                     newType, operand);
+      };
+      a = toDotOperand(a, 0);
+      b = toDotOperand(b, 1);
+    } else {
+      a = BlockedToMMA::getMMAv3Operand(a, rewriter, 0);
+      b = BlockedToMMA::getMMAv3Operand(b, rewriter, 1);
+    }
+
+    // convert metadata
+    Value meta = dotOp.getAMeta();
+    auto oldMetaType = cast<RankedTensorType>(meta.getType());
+    auto newMetaType = RankedTensorType::get(
+        oldMetaType.getShape(), oldMetaType.getElementType(),
+        SparseDotMetaEncodingAttr::get(ctx, mmaEnc));
+    meta =
+        rewriter.create<ttg::ConvertLayoutOp>(meta.getLoc(), newMetaType, meta);
+
+    // convert dot instruction
+    auto newDot = rewriter.create<SparseDotOp>(dotOp.getLoc(), newRetType, a, b,
+                                               newAcc, meta);
+
+    rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(op, oldRetType,
+                                                      newDot.getResult());
+    return success();
+  }
+
+ private:
+  int computeCapability;
+};
 } // namespace
 
 static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
@@ -394,6 +489,7 @@ public:
 
     mlir::RewritePatternSet patterns(context);
     patterns.add<::BlockedToMMA>(context, computeCapability);
+    patterns.add<::SparseBlockedToMMA>(context, computeCapability);
     if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) {
       signalPassFailure();
     }
diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
index 97ca6a840..f0ef124ff 100644
--- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -188,6 +188,10 @@ public:
   }
 };
 
+static bool isDotOp(Operation *op) {  // tt::DotOp or ttg::SparseDotOp
+  return isa<tt::DotOp>(op) || isa<ttg::SparseDotOp>(op);
+}
+
 static bool isMMAv3Dot(Operation *op) {
   auto dot = dyn_cast<tt::DotOp>(op);
   if (!dot)
@@ -399,19 +403,28 @@ getSharedEncIfAllUsersAreDotEnc(Value val) {
     } else {
       if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
         return std::nullopt;
-      auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
-          cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding());
-      if (!dotOpEnc)
+      auto enc =
+          cast<TensorOrMemDesc>(user->getResult(0).getType()).getEncoding();
+      if (isa<ttg::DotOperandEncodingAttr>(enc)) {
+        auto srcTy = cast<TensorOrMemDesc>(val.getType());
+        auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
+        auto order = ttg::getOrder(srcTy.getEncoding());
+        unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
+        tempAttr = ttg::SharedEncodingAttr::get(
+            val.getContext(), cast<ttg::DotOperandEncodingAttr>(enc),
+            srcTy.getShape(), ttg::getOrder(srcTy.getEncoding()),
+            ttg::getCTALayout(srcTy.getEncoding()),
+            srcTy.getElementType().getIntOrFloatBitWidth(),
+            /*needTrans=*/false);
+      } else if (isa<ttg::SparseDotMetaEncodingAttr>(enc)) {
+        auto srcTy = cast<TensorOrMemDesc>(val.getType());
+        tempAttr = ttg::SharedEncodingAttr::get(
+            val.getContext(), /*vec=*/1, /*perPhase=*/1, /*maxPhase=*/1,
+            ttg::getOrder(srcTy.getEncoding()),
+            ttg::getCTALayout(srcTy.getEncoding()));
+      } else {
         return std::nullopt;
-      auto srcTy = cast<TensorOrMemDesc>(val.getType());
-      auto CTALayout = ttg::getCTALayout(srcTy.getEncoding());
-      auto order = ttg::getOrder(srcTy.getEncoding());
-      unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth();
-      tempAttr = ttg::SharedEncodingAttr::get(
-          val.getContext(), dotOpEnc, srcTy.getShape(),
-          ttg::getOrder(srcTy.getEncoding()),
-          ttg::getCTALayout(srcTy.getEncoding()),
-          srcTy.getElementType().getIntOrFloatBitWidth(), /*needTrans=*/false);
+      }
     }
     // Check that the shared encodings needed by the users are compatible.
     if (!tempAttr || (attr != nullptr && attr != tempAttr))
@@ -518,7 +531,7 @@ loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) {
       };
 
   for (Operation &op : forOp.getBody()->without_terminator()) {
-    if (!isa<tt::DotOp>(op))
+    if (!isDotOp(&op))
       continue;
     seen.clear();
     dfs(&op, 0, &op);
@@ -595,7 +608,8 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
         continue;
     }
 
-    if (auto dot = dyn_cast<tt::DotOp>(use)) {
+    if (isDotOp(use)) {
+      auto dot = dyn_cast<tt::DotOp>(use);
       loadInfo.usedByDot = true;
       if (loadIsMMAv3(op)) {
         loadInfo.loadIsMMAV3 = true;
@@ -614,7 +628,7 @@ assignMemoryLayouts(llvm::SmallVector<std::tuple<Operation *, int, Operation *>>
         // The codegen bug is caught by an assertion, so if you think you've
         // fixed it, feel free to delete this code and see if the assert still
         // fails.  :)
-        if (!loadInfo.sharedEncoding) {
+        if (dot && !loadInfo.sharedEncoding) {
           if (auto dotEnc = dyn_cast<ttg::NvidiaMmaEncodingAttr>(
                   dot.getResult().getType().getEncoding())) {
             auto loadTy = cast<RankedTensorType>(op->getResultTypes()[0]);
diff --git a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
index 2211df31b..ee5ff44d8 100644
--- a/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
+++ b/lib/Dialect/TritonGPU/Transforms/ReduceDataDuplication.cpp
@@ -37,6 +37,10 @@ public:
       auto srcEncoding = srcType.getEncoding();
       if (isa<triton::gpu::SharedEncodingAttr>(srcEncoding))
         return;
+      if (isa<triton::gpu::SparseDotMetaEncodingAttr>(dstType.getEncoding())) {
+        replaceSparseMetaEncoding(cvtOp);
+        return;
+      }
       auto dstDotOp =
           dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
       if (!dstDotOp)
@@ -83,6 +87,27 @@ public:
       cvtOp.erase();
     });
   }
+
+ private:
+  void replaceSparseMetaEncoding(triton::gpu::ConvertLayoutOp cvtOp) {
+    // Replace the conversion by a LocalAllocOp followed by a LocalLoadOp.
+    auto src = cvtOp.getSrc();
+    auto loc = cvtOp.getLoc();
+    auto resultTy = cast<RankedTensorType>(cvtOp.getType());
+    auto srcLayout = cast<RankedTensorType>(src.getType()).getEncoding();
+    auto bufferLayout = triton::gpu::SharedEncodingAttr::get(
+        cvtOp.getContext(), /*vec=*/8, /*perPhase=*/1, /*maxPhase=*/1,
+        triton::gpu::getOrder(srcLayout),
+        triton::gpu::getCTALayout(srcLayout));
+    auto bufferTy = triton::MemDescType::get(
+        resultTy.getShape(), resultTy.getElementType(), bufferLayout);
+    OpBuilder builder(cvtOp);
+    auto buffer = builder.create<triton::gpu::LocalAllocOp>(loc, bufferTy, src);
+    auto loaded =
+        builder.create<triton::gpu::LocalLoadOp>(loc, resultTy, buffer);
+    cvtOp.replaceAllUsesWith(loaded.getResult());
+    cvtOp.erase();
+  }
 };
 
 std::unique_ptr<Pass> mlir::triton::gpu::createReduceDataDuplicationPass() {
diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
index f456d36a6..a1dac2b72 100644
--- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
+++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp
@@ -45,7 +45,7 @@ public:
       return;
     ModuleOp mod = getOperation();
     mod.walk([&](Operation *op) {
-      if (!isa<tt::DotOp, ttng::DotAsyncOp>(op))
+      if (!isa<tt::DotOp, ttng::DotAsyncOp, ttg::SparseDotOp>(op))
         return WalkResult::advance();
       OpBuilder builder(op);
       auto a = op->getOperand(0);
@@ -80,7 +80,7 @@ private:
     static DenseSet<std::pair<Operation *, unsigned>> trace;
     auto op = operand.getDefiningOp();
     // avoid redundant insertion
-    if (op && isa<tt::DotOp, ttng::DotAsyncOp>(op))
+    if (op && isa<tt::DotOp, ttng::DotAsyncOp, ttg::SparseDotOp>(op))
       return false;
     // reach convertlayout
     if (op && isa<ttg::LocalAllocOp>(op) &&