diff --git a/docs/PTO_IR_manual.md b/docs/PTO_IR_manual.md index e19bb6229..d66d83a14 100644 --- a/docs/PTO_IR_manual.md +++ b/docs/PTO_IR_manual.md @@ -2451,6 +2451,7 @@ pto.tprelu ins(%a, %slopes, %c : !pto.tile_buf, : , ) + outs( : ) +``` + +**Constraints & Verification:** + +- **Implementation checks (A2A3)** + - `src` and `dst` must use `loc=vec`. + - `src` and `dst` must have the same shape and valid shape. + - `scalar` type must exactly match the element type of `src`. + - `src` element type must be `f16` or `f32`. + - `dst` element type must be `f16` or `f32`. + - Element types must either match, or use the widening form `src=f16`, `dst=f32`. +- **Implementation checks (A5)** + - `src` and `dst` must use `loc=vec`. + - `src` and `dst` must have the same shape and valid shape. + - `scalar` type must exactly match the element type of `src`. + - `src` element type must be `f16`, `bf16`, or `f32`. + - `dst` element type must be `f16`, `bf16`, or `f32`. + - Element types must either match, or use the widening form `src=f16`, `dst=f32`. + +**Hardware Mapping:** + +- Executes on the **Vector pipeline** (`PIPE_V`) +- Operates on data in the **VEC (UB)** memory space + +**Basic Example:** + +```mlir +%alpha = arith.constant 2.0 : f16 +pto.taxpy ins(%src, %alpha : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) +``` + +--- + ##### `pto.tdivs` - Elementwise Division with Scalar **Summary:** Divides every element of a tile buffer by a scalar, or divides a scalar by every element. @@ -3726,10 +3791,13 @@ Reduce along rows or columns of a tile. All execute on the **Vector pipeline** ( |----|----------| | `pto.trowsum` | `dst[i,0] = sum_j src[i,j]` | | `pto.trowmax` | `dst[i,0] = max_j src[i,j]` | +| `pto.trowargmax` | `dst[i,0] = argmax_j src[i,j]` (requires tmp) | | `pto.trowmin` | `dst[i,0] = min_j src[i,j]` (requires tmp) | +| `pto.trowargmin` | `dst[i,0] = argmin_j src[i,j]` (requires tmp) | | `pto.tcolsum` | `dst[0,j] = sum_i src[i,j]` (requires tmp, optional isBinary) | | `pto.tcolmax` | `dst[0,j] = max_i src[i,j]` | | `pto.tcolmin` | `dst[0,j] = min_i src[i,j]` | +| `pto.thistogram` | `dst[i, idx[i,0]] = histogram_update(dst[i, idx[i,0]], src[i,:])` (A5 only) | --- @@ -3874,6 +3942,73 @@ pto.trowmax ins(%src : !pto.tile_buf, : , ) + outs( : ) +``` + +**Constraints & Verification:** + +- **Implementation checks (A2A3)** + - `src`, `tmp`, and `dst` must use `loc=vec`. + - `src` must use ND-style tile layout (`blayout=row_major`, `slayout=none_box`). + - `tmp` must have the same shape, valid shape, and element type as `src`. + - `dst` must use `slayout=none_box` and either: + - a DN-style column vector tile (`blayout=col_major`, `cols=1`), or + - a legacy ND-style tile with `valid column == 1`. + - `src` element type must be `f16` or `f32`. + - `dst` element type must be `i32` or `ui32`. + - Runtime valid checks: + - `src valid row != 0` and `src valid column != 0` + - `src valid row == dst valid row` + - `dst valid column == 1` +- **Implementation checks (A5)** + - Same constraints as A2/A3. + +**Hardware Mapping:** + +- Executes on the **Vector pipeline** (`PIPE_V`) +- Operates on data in the **VEC (UB)** memory space + +**Basic Example:** + +```mlir +pto.trowargmax ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + ##### `pto.trowmin` - Row-wise Min Reduction **Summary:** Reduces each row by taking the minimum across columns. Requires a temporary buffer. @@ -3950,6 +4085,144 @@ pto.trowmin ins(%src, %tmp : !pto.tile_buf, : , ) + outs( : ) +``` + +**Constraints & Verification:** + +- **Implementation checks (A2A3)** + - `src`, `tmp`, and `dst` must use `loc=vec`. + - `src` must use ND-style tile layout (`blayout=row_major`, `slayout=none_box`). + - `tmp` must have the same shape, valid shape, and element type as `src`. + - `dst` must use `slayout=none_box` and either: + - a DN-style column vector tile (`blayout=col_major`, `cols=1`), or + - a legacy ND-style tile with `valid column == 1`. + - `src` element type must be `f16` or `f32`. + - `dst` element type must be `i32` or `ui32`. + - Runtime valid checks: + - `src valid row != 0` and `src valid column != 0` + - `src valid row == dst valid row` + - `dst valid column == 1` +- **Implementation checks (A5)** + - Same constraints as A2/A3. + +**Hardware Mapping:** + +- Executes on the **Vector pipeline** (`PIPE_V`) +- Operates on data in the **VEC (UB)** memory space + +**Basic Example:** + +```mlir +pto.trowargmin ins(%src, %tmp : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) +``` + +--- + +##### `pto.thistogram` - Row-wise Histogram Accumulation + +**Summary:** Updates a 256-bin histogram row in `dst` using each source row and a per-row index selector. This op is only supported on A5. + +**Semantics:** + +``` +For each row i: + bin = idx[i, 0] + dst[i, bin] = histogram_update(dst[i, bin], src[i, :]) +``` + +The exact accumulation performed inside the selected bin is target-defined by the hardware `THISTOGRAM` intrinsic. The `isMSB` attribute selects the intrinsic mode. + +**Arguments:** + +| Name | Type | Description | +|------|------|-------------| +| `src` | `pto.tile_buf` | Source tile buffer, one logical row per histogram update | +| `idx` | `pto.tile_buf` | Per-row bin selector tile | +| `dst` | `pto.tile_buf` | Destination histogram tile | +| `isMSB` | `BoolAttr` (default: `true`) | Selects the `THISTOGRAM<...>` intrinsic mode | + +**Results:** None. Writes into `dst` via DPS pattern. + +**Assembly Format:** + +``` +pto.thistogram ins(, : , ) + outs( : ) + {isMSB = true} +``` + +**Constraints & Verification:** + +- **Implementation checks (A2A3)** + - Not supported. +- **Implementation checks (A5)** + - `src`, `idx`, and `dst` must all be `tile_buf` values in `loc=vec`. + - `src` and `dst` must use `row_major + none_box` layout. + - `idx` must use DN-style layout (`col_major + none_box`). + - `src` element type must be `ui16`. + - `idx` element type must be `ui8`. + - `dst` element type must be `ui32`. + - `src`, `idx`, and `dst` must all be rank-2 tiles. + - `idx` rows and valid rows must match `src`. + - `dst` rows and valid rows must match `src`. + - `idx` must have exactly one column. + - `dst` shape[1] and valid_shape[1] must be at least `256`. + +**Hardware Mapping:** + +- Executes on the **Vector pipeline** (`PIPE_V`) +- Operates on data in the **VEC (UB)** memory space + +**Basic Example:** + +```mlir +pto.thistogram ins(%src, %idx : !pto.tile_buf, + !pto.tile_buf) + outs(%dst : !pto.tile_buf) {isMSB = false} +``` + +--- + ##### `pto.tcolsum` - Column-wise Sum Reduction **Summary:** Reduces each column by summing across rows. Requires a temporary buffer. @@ -6677,6 +6950,61 @@ pto.tsetval ins(%off, %val : index, f16) outs(%dst : !pto.tile_buf : ) + outs( : ) +``` + +**Constraints & Verification:** + +- **Implementation checks (A2A3)** + - Not supported. +- **Implementation checks (A5)** + - `src` must be a valid `pto.tile_buf`. + - `dst` must be a valid `pto.tile_buf` in `loc=scaling`. + - `dst` must have the same rank, shape, and valid shape as `src`. + - `dst` must satisfy the target-specific scaling-tile compatibility rules for `src`. + +**Hardware Mapping:** + +- Executes on the **Scalar pipeline** (`PIPE_S`) + +**Basic Example:** + +```mlir +pto.tget_scale_addr ins(%src : !pto.tile_buf) + outs(%scale : !pto.tile_buf) +``` + +--- + ##### `pto.tmov.fp` - Move/Convert with Scaling Tile **Summary:** Legacy dedicated fp-TMOV op. New code should prefer `pto.tmov` with an `fp` operand, which lowers to the same `TMOV_FP` / fp-parameterized `TMOV` APIs. diff --git a/include/PTO/IR/PTOOps.td b/include/PTO/IR/PTOOps.td index 193fe44c5..1e94c4f55 100644 --- a/include/PTO/IR/PTOOps.td +++ b/include/PTO/IR/PTOOps.td @@ -2248,6 +2248,64 @@ def TGetValOp : PTO_TOp<"tgetval", [ }]; } +def THistogramOp : PTO_TOp<"thistogram", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "THISTOGRAM: accumulate per-row 256-bin histograms."; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$idx, + PTODpsType:$dst, + DefaultValuedOptionalAttr:$isMSB + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `,` $idx `:` qualified(type($src)) `,` qualified(type($idx)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + +def TGetScaleAddrOp : PTO_TOp<"tget_scale_addr", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "Bind a scaling tile to the scaled address of a source tile."; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `:` qualified(type($src)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_S; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + // ---- tile-world TOp version (with 't') ---- // pto.mscatter ins(%src, %idx) outs(%mem) [ ...] @@ -2421,6 +2479,35 @@ def TAddSOp : PTO_TOp<"tadds", [ }]; } +def TAxpyOp : PTO_TOp<"taxpy", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "TAXPY: dst += src * scalar."; + + let arguments = (ins + PTODpsType:$src, + ScalarType:$scalar, + PTODpsType:$dst + ); + + let results = (outs); + + let assemblyFormat = [{ + `ins` `(` $src `,` $scalar `:` qualified(type($src)) `,` type($scalar) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} + def TAddSCOp : PTO_TOp<"taddsc", [ PTO_DpsInitOpInterface, OpPipeInterface, @@ -4666,6 +4753,35 @@ def TRowMaxOp: PTO_TOp<"trowmax", [ ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } }]; } + +def TRowArgMaxOp: PTO_TOp<"trowargmax", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "TROWARGMAX: Reduce each row to the column index of the maximum element."; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `,` $tmp `:` qualified(type($src)) `,` qualified(type($tmp)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} //===----------------------------------------------------------------------===// // PTOOps.td (add TROWMIN TBDPS/tile buffer op) //===----------------------------------------------------------------------===// @@ -4698,6 +4814,35 @@ def TRowMinOp: PTO_TOp<"trowmin", [ ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } }]; } + +def TRowArgMinOp: PTO_TOp<"trowargmin", [ + PTO_DpsInitOpInterface, + OpPipeInterface, + DeclareOpInterfaceMethods +]> { + let summary = "TROWARGMIN: Reduce each row to the column index of the minimum element."; + + let arguments = (ins + PTODpsType:$src, + PTODpsType:$tmp, + PTODpsType:$dst + ); + + let results = (outs); + + let hasVerifier = 1; + + let assemblyFormat = [{ + `ins` `(` $src `,` $tmp `:` qualified(type($src)) `,` qualified(type($tmp)) `)` + `outs` `(` $dst `:` qualified(type($dst) ) `)` + attr-dict + }]; + + let extraClassDeclaration = [{ + ::mlir::pto::PIPE getPipe() { return ::mlir::pto::PIPE::PIPE_V; } + ::mlir::MutableOperandRange getDpsInitsMutable() { return getDstMutable(); } + }]; +} //===----------------------------------------------------------------------===// // PTOOps.td (add TROWSUM TBDPS/tile buffer op) //===----------------------------------------------------------------------===// diff --git a/lib/PTO/IR/PTO.cpp b/lib/PTO/IR/PTO.cpp index a419713ce..1bece6f84 100644 --- a/lib/PTO/IR/PTO.cpp +++ b/lib/PTO/IR/PTO.cpp @@ -2680,6 +2680,63 @@ LogicalResult pto::TAddSOp::verify() { }; return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } + +LogicalResult pto::TAxpyOp::verify() { + auto verifyCommon = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyVecTileCommon(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameValidShape(*this, srcTy, dstTy, "src", "dst"))) + return failure(); + + Type scalarTy = getScalar().getType(); + Type srcElem = getElemTy(srcTy); + if (scalarTy != srcElem) + return emitOpError("expects scalar type to match src element type"); + if (getShapeVec(srcTy) != getShapeVec(dstTy)) + return emitOpError("expects src and dst to have the same shape"); + return success(); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcElem = getElemTy(getSrc().getType()); + Type dstElem = getElemTy(getDst().getType()); + bool sameType = srcElem == dstElem; + bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); + if (!(sameType || widenF16ToF32)) + return emitOpError( + "expects dst/src element types to match, or dst=f32 and src=f16"); + if (!(dstElem.isF16() || dstElem.isF32())) + return emitOpError("expects A2/A3 taxpy dst element type to be f16/f32"); + if (!(srcElem.isF16() || srcElem.isF32())) + return emitOpError("expects A2/A3 taxpy src element type to be f16/f32"); + return success(); + }; + + auto verifyA5 = [&]() -> LogicalResult { + if (failed(verifyCommon())) + return failure(); + Type srcElem = getElemTy(getSrc().getType()); + Type dstElem = getElemTy(getDst().getType()); + bool sameType = srcElem == dstElem; + bool widenF16ToF32 = srcElem.isF16() && dstElem.isF32(); + if (!(sameType || widenF16ToF32)) + return emitOpError( + "expects dst/src element types to match, or dst=f32 and src=f16"); + if (!(dstElem.isF16() || dstElem.isF32() || dstElem.isBF16())) + return emitOpError("expects A5 taxpy dst element type to be f16/bf16/f32"); + if (!(srcElem.isF16() || srcElem.isF32() || srcElem.isBF16())) + return emitOpError("expects A5 taxpy src element type to be f16/bf16/f32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + LogicalResult pto::TAddSCOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) return success(); @@ -5289,6 +5346,104 @@ LogicalResult TGetValOp::verify() { return success(); } +LogicalResult THistogramOp::verify() { + auto isSignlessOrUnsignedInt = [](Type ty, unsigned width) { + auto it = dyn_cast(ty); + return it && it.getWidth() == width && (it.isSignless() || it.isUnsigned()); + }; + + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("thistogram is only supported on A5"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type idxTy = getIdx().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src")) || + failed(verifyTileBufCommon(*this, idxTy, "idx")) || + failed(verifyTileBufCommon(*this, dstTy, "dst"))) + return failure(); + + auto srcSpace = getPTOMemorySpaceEnum(srcTy); + auto idxSpace = getPTOMemorySpaceEnum(idxTy); + auto dstSpace = getPTOMemorySpaceEnum(dstTy); + if (!srcSpace || *srcSpace != pto::AddressSpace::VEC) + return emitOpError("expects src to be in the vec address space"); + if (!idxSpace || *idxSpace != pto::AddressSpace::VEC) + return emitOpError("expects idx to be in the vec address space"); + if (!dstSpace || *dstSpace != pto::AddressSpace::VEC) + return emitOpError("expects dst to be in the vec address space"); + + auto srcTB = dyn_cast(srcTy); + auto idxTB = dyn_cast(idxTy); + auto dstTB = dyn_cast(dstTy); + if (!srcTB || !idxTB || !dstTB) + return emitOpError("expects src, idx, and dst to be tile_buf types"); + + if (srcTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + srcTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects src to use row_major + none_box layout"); + if (dstTB.getBLayoutValueI32() != static_cast(pto::BLayout::RowMajor) || + dstTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError("expects dst to use row_major + none_box layout"); + if (idxTB.getBLayoutValueI32() != static_cast(pto::BLayout::ColMajor) || + idxTB.getSLayoutValueI32() != static_cast(pto::SLayout::NoneBox)) + return emitOpError( + "expects idx to use DN layout (col_major + none_box)"); + + if (!isSignlessOrUnsignedInt(getElemTy(srcTy), 16)) + return emitOpError("expects src element type to be ui16"); + if (!isSignlessOrUnsignedInt(getElemTy(idxTy), 8)) + return emitOpError("expects idx element type to be ui8"); + if (!isSignlessOrUnsignedInt(getElemTy(dstTy), 32)) + return emitOpError("expects dst element type to be ui32"); + + auto srcShape = getShapeVec(srcTy); + auto idxShape = getShapeVec(idxTy); + auto dstShape = getShapeVec(dstTy); + auto srcValid = getValidShapeVec(srcTy); + auto idxValid = getValidShapeVec(idxTy); + auto dstValid = getValidShapeVec(dstTy); + if (srcShape.size() != 2 || idxShape.size() != 2 || dstShape.size() != 2 || + srcValid.size() != 2 || idxValid.size() != 2 || dstValid.size() != 2) + return emitOpError( + "expects src, idx, and dst to have rank-2 shape and valid_shape"); + + if (!hasCompatibleKnownExtent(srcShape[0], idxShape[0]) || + !hasCompatibleKnownExtent(srcValid[0], idxValid[0])) + return emitOpError("expects idx rows and valid rows to match src"); + if (!hasCompatibleKnownExtent(srcShape[0], dstShape[0]) || + !hasCompatibleKnownExtent(srcValid[0], dstValid[0])) + return emitOpError("expects dst rows and valid rows to match src"); + + if (!isKnownUnitExtent(idxShape[1]) || !isKnownUnitExtent(idxValid[1])) + return emitOpError("expects idx to have exactly one column"); + if (dstShape[1] != ShapedType::kDynamic && dstShape[1] < 256) + return emitOpError("expects dst shape[1] to be at least 256"); + if (dstValid[1] != ShapedType::kDynamic && dstValid[1] < 256) + return emitOpError("expects dst valid_shape[1] to be at least 256"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + +LogicalResult TGetScaleAddrOp::verify() { + auto verifyA2A3 = [&]() -> LogicalResult { + return emitOpError("tget_scale_addr is only supported on A5"); + }; + auto verifyA5 = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyTileBufCommon(*this, srcTy, "src"))) + return failure(); + if (failed(verifyScaleTileMatchesOperand(*this, dstTy, srcTy, "dst", "src"))) + return failure(); + return success(); + }; + return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); +} + // ---- MScatterOp ---- LogicalResult MScatterOp::verify() { if (shouldBypassDecodedMemrefVerifier(getOperation())) @@ -7379,6 +7534,35 @@ mlir::LogicalResult mlir::pto::TRowMaxOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +mlir::LogicalResult mlir::pto::TRowArgMaxOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyRowReductionSrcLayout(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (failed(verifyRowReductionValidRegion(*this, srcTy, dstTy))) + return failure(); + + auto srcElem = getElemTy(srcTy).dyn_cast(); + if (!srcElem || (!srcElem.isF16() && !srcElem.isF32())) + return emitOpError("expects src element type to be f16 or f32"); + + auto dstInt = dyn_cast(getElemTy(dstTy)); + if (!dstInt || dstInt.getWidth() != 32 || + (!dstInt.isSignless() && !dstInt.isUnsigned())) + return emitOpError("expects dst element type to be i32 or ui32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + mlir::LogicalResult mlir::pto::TRowMinOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { @@ -7424,6 +7608,35 @@ mlir::LogicalResult mlir::pto::TRowMinOp::verify() { return dispatchVerifierByArch(getOperation(), verifyA2A3, verifyA5); } +mlir::LogicalResult mlir::pto::TRowArgMinOp::verify() { + auto verifyByArch = [&]() -> LogicalResult { + Type srcTy = getSrc().getType(); + Type tmpTy = getTmp().getType(); + Type dstTy = getDst().getType(); + if (failed(verifyRowReductionSrcLayout(*this, srcTy, "src")) || + failed(verifyVecTileCommon(*this, tmpTy, "tmp")) || + failed(verifyRowReductionDstLayout(*this, dstTy, "dst"))) + return failure(); + if (failed(verifyTileBufSameShapeAndElem(*this, srcTy, tmpTy, "src", "tmp")) || + failed(verifyTileBufSameValidShape(*this, srcTy, tmpTy, "src", "tmp"))) + return failure(); + if (failed(verifyRowReductionValidRegion(*this, srcTy, dstTy))) + return failure(); + + auto srcElem = getElemTy(srcTy).dyn_cast(); + if (!srcElem || (!srcElem.isF16() && !srcElem.isF32())) + return emitOpError("expects src element type to be f16 or f32"); + + auto dstInt = dyn_cast(getElemTy(dstTy)); + if (!dstInt || dstInt.getWidth() != 32 || + (!dstInt.isSignless() && !dstInt.isUnsigned())) + return emitOpError("expects dst element type to be i32 or ui32"); + return success(); + }; + + return dispatchVerifierByArch(getOperation(), verifyByArch, verifyByArch); +} + mlir::LogicalResult mlir::pto::TRowSumOp::verify() { auto verifyA2A3 = [&]() -> LogicalResult { @@ -9177,6 +9390,19 @@ void TGetValOp::getEffects( PTO_ADD_READ(getSrcMutable()); } +void THistogramOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getIdxMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + +void TGetScaleAddrOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + // TSETVAL: Write(dst) (single element update) void TSetValOp::getEffects( SmallVectorImpl> &effects) { @@ -9194,6 +9420,12 @@ PTO_DEFINE_BINARY_EFFECTS(TAddOp, getSrc0Mutable(), getSrc1Mutable(), getDstMuta PTO_DEFINE_TERNARY_EFFECTS(TAddCOp, getSrc0Mutable(), getSrc1Mutable(), getSrc2Mutable(), getDstMutable()) PTO_DEFINE_UNARY_EFFECTS(TAddSOp, getSrcMutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TAddSCOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) +void TAxpyOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_READ(getScalarMutable()); + PTO_ADD_WRITE(getDstMutable()); +} PTO_DEFINE_BINARY_EFFECTS(TAndOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) PTO_DEFINE_BINARY_EFFECTS(TConcatOp, getSrc0Mutable(), getSrc1Mutable(), getDstMutable()) @@ -9446,6 +9678,13 @@ void TRowMaxOp::getEffects( PTO_ADD_WRITE(getDstMutable()); } +void TRowArgMaxOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + void TRowMinOp::getEffects( SmallVectorImpl> &effects) { PTO_ADD_READ(getSrcMutable()); @@ -9453,6 +9692,13 @@ void TRowMinOp::getEffects( PTO_ADD_WRITE(getDstMutable()); } +void TRowArgMinOp::getEffects( + SmallVectorImpl> &effects) { + PTO_ADD_READ(getSrcMutable()); + PTO_ADD_WRITE(getTmpMutable()); + PTO_ADD_WRITE(getDstMutable()); +} + void TRowSumOp::getEffects( SmallVectorImpl> &effects) { PTO_ADD_READ(getSrcMutable()); diff --git a/lib/PTO/Transforms/PTOToEmitC.cpp b/lib/PTO/Transforms/PTOToEmitC.cpp index f091e3a9b..3d9337a6c 100644 --- a/lib/PTO/Transforms/PTOToEmitC.cpp +++ b/lib/PTO/Transforms/PTOToEmitC.cpp @@ -2354,14 +2354,13 @@ struct PTOMGatherToMGATHER : public OpConversionPattern { LogicalResult matchAndRewrite(pto::MGatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value mem = peelUnrealized(adaptor.getMem()); + Value idx = peelUnrealized(adaptor.getIdx()); Value dst = peelUnrealized(adaptor.getDst()); - // pto-isa currently has no NPU implementation for MGATHER/MSCATTER. - // Fallback to a smoke-friendly lowering to keep compile/run coverage. rewriter.create( - op.getLoc(), TypeRange{}, "TLOAD", + op.getLoc(), TypeRange{}, "MGATHER", ArrayAttr{}, ArrayAttr{}, - ValueRange{dst, mem}); + ValueRange{dst, mem, idx}); if (op->getNumResults() == 0) { rewriter.eraseOp(op); @@ -4775,14 +4774,13 @@ struct PTOMScatterToMSCATTER : public OpConversionPattern { LogicalResult matchAndRewrite(pto::MScatterOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); Value mem = peelUnrealized(adaptor.getMem()); - // pto-isa currently has no NPU implementation for MGATHER/MSCATTER. - // Fallback to a smoke-friendly lowering to keep compile/run coverage. rewriter.create( - op.getLoc(), TypeRange{}, "TSTORE", + op.getLoc(), TypeRange{}, "MSCATTER", ArrayAttr{}, ArrayAttr{}, - ValueRange{mem, src}); + ValueRange{mem, src, idx}); rewriter.eraseOp(op); return success(); @@ -4836,6 +4834,72 @@ struct PTOGetValToGETVAL : public OpConversionPattern { } }; +struct PTOTAxpyToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TAxpyOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + Value scalar = peelUnrealized(adaptor.getScalar()); + + rewriter.create( + loc, TypeRange{}, "TAXPY", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, scalar}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOHistogramToEmitC : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::THistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto *ctx = rewriter.getContext(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value idx = peelUnrealized(adaptor.getIdx()); + Value dst = peelUnrealized(adaptor.getDst()); + + auto templateArgs = rewriter.getArrayAttr( + {emitc::OpaqueAttr::get(ctx, op.getIsMSB() ? "true" : "false")}); + rewriter.create( + loc, TypeRange{}, "THISTOGRAM", + /*args=*/ArrayAttr{}, /*templateArgs=*/templateArgs, + /*operands=*/ValueRange{dst, src, idx}); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PTOGetScaleAddrToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TGetScaleAddrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TGET_SCALE_ADDR", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src}); + + rewriter.eraseOp(op); + return success(); + } +}; + struct PTOSetValidShapeToEmitC : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -7985,6 +8049,28 @@ struct PTORowMaxToEmitC : public OpConversionPattern { return success(); } }; + +struct PTORowArgMaxToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowArgMaxOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TROWARGMAX", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWMIN DPS/memref op) //===----------------------------------------------------------------------===// @@ -8011,6 +8097,28 @@ struct PTORowMinToEmitC : public OpConversionPattern { } }; +struct PTORowArgMinToEmitC + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite(pto::TRowArgMinOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + + Value src = peelUnrealized(adaptor.getSrc()); + Value tmp = peelUnrealized(adaptor.getTmp()); + Value dst = peelUnrealized(adaptor.getDst()); + + rewriter.create( + loc, TypeRange{}, "TROWARGMIN", + /*args=*/ArrayAttr{}, /*templateArgs=*/ArrayAttr{}, + /*operands=*/ValueRange{dst, src, tmp}); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // PTOConvert.cpp (add lowering + patterns.add for TROWSUM DPS/memref op) //===----------------------------------------------------------------------===// @@ -9576,11 +9684,13 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); + patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); @@ -9629,6 +9739,8 @@ static void populatePTOToEmitCPatterns(RewritePatternSet &patterns, patterns.add(typeConverter, ctx); + patterns.add( + typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); patterns.add(typeConverter, ctx); diff --git a/test/basic/mgather_emitc.pto b/test/basic/mgather_emitc.pto new file mode 100644 index 000000000..41fe201f9 --- /dev/null +++ b/test/basic/mgather_emitc.pto @@ -0,0 +1,22 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @mgather_emitc(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %tv = pto.make_tensor_view %arg0, shape = [%c32, %c32], strides = [%c32, %c1] : !pto.tensor_view + %mem = pto.partition_view %tv, offsets = [%c0, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xi32> + %idx = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.mgather ins(%mem, %idx : !pto.partition_tensor_view<32x32xi32>, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} + +// A5-LABEL: __global__ AICORE void mgather_emitc( +// A5: MGATHER([[DST:v[0-9]+]], [[MEM:v[0-9]+]], [[IDX:v[0-9]+]]); +// A5-NOT: TLOAD( diff --git a/test/basic/mscatter_emitc.pto b/test/basic/mscatter_emitc.pto new file mode 100644 index 000000000..0218b15e4 --- /dev/null +++ b/test/basic/mscatter_emitc.pto @@ -0,0 +1,22 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @mscatter_emitc(%arg0: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + + %tv = pto.make_tensor_view %arg0, shape = [%c32, %c32], strides = [%c32, %c1] : !pto.tensor_view + %mem = pto.partition_view %tv, offsets = [%c0, %c0], sizes = [%c32, %c32] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xi32> + %src = pto.alloc_tile : !pto.tile_buf + %idx = pto.alloc_tile : !pto.tile_buf + + pto.mscatter ins(%src, %idx : !pto.tile_buf, !pto.tile_buf) + outs(%mem : !pto.partition_tensor_view<32x32xi32>) + return + } +} + +// A5-LABEL: __global__ AICORE void mscatter_emitc( +// A5: MSCATTER([[MEM:v[0-9]+]], [[SRC:v[0-9]+]], [[IDX:v[0-9]+]]); +// A5-NOT: TSTORE( diff --git a/test/basic/taxpy_emitc.pto b/test/basic/taxpy_emitc.pto new file mode 100644 index 000000000..22c71f96f --- /dev/null +++ b/test/basic/taxpy_emitc.pto @@ -0,0 +1,19 @@ +// RUN: ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s --check-prefix=A3 +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @taxpy_emitc() { + %cst = arith.constant 2.0 : f16 + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.taxpy ins(%src, %cst : !pto.tile_buf, f16) + outs(%dst : !pto.tile_buf) + return + } +} + +// A3-LABEL: __global__ AICORE void taxpy_emitc() +// A3: TAXPY( +// A5-LABEL: __global__ AICORE void taxpy_emitc() +// A5: TAXPY( diff --git a/test/basic/tget_scale_addr_emitc.pto b/test/basic/tget_scale_addr_emitc.pto new file mode 100644 index 000000000..b90bac618 --- /dev/null +++ b/test/basic/tget_scale_addr_emitc.pto @@ -0,0 +1,15 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +module { + func.func @tget_scale_addr_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tget_scale_addr ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} + +// CHECK-LABEL: __global__ AICORE void tget_scale_addr_emitc() +// CHECK: TGET_SCALE_ADDR( diff --git a/test/basic/tget_scale_addr_verify_invalid.pto b/test/basic/tget_scale_addr_verify_invalid.pto new file mode 100644 index 000000000..2db7cb822 --- /dev/null +++ b/test/basic/tget_scale_addr_verify_invalid.pto @@ -0,0 +1,12 @@ +// RUN: not ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +func.func @tget_scale_addr_verify_invalid() { + %src = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.tget_scale_addr ins(%src : !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return +} + +// CHECK: error: 'pto.tget_scale_addr' op expects dst to be in the scaling address space diff --git a/test/basic/thistogram_emitc.pto b/test/basic/thistogram_emitc.pto new file mode 100644 index 000000000..e86956bac --- /dev/null +++ b/test/basic/thistogram_emitc.pto @@ -0,0 +1,20 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s + +module { + func.func @thistogram_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %idx = pto.alloc_tile : !pto.tile_buf + %dst0 = pto.alloc_tile : !pto.tile_buf + %dst1 = pto.alloc_tile : !pto.tile_buf + + pto.thistogram ins(%src, %idx : !pto.tile_buf, !pto.tile_buf) + outs(%dst0 : !pto.tile_buf) + pto.thistogram ins(%src, %idx : !pto.tile_buf, !pto.tile_buf) + outs(%dst1 : !pto.tile_buf) {isMSB = false} + return + } +} + +// CHECK-LABEL: __global__ AICORE void thistogram_emitc() +// CHECK: THISTOGRAM( +// CHECK: THISTOGRAM( diff --git a/test/basic/thistogram_verify_invalid_a3.pto b/test/basic/thistogram_verify_invalid_a3.pto new file mode 100644 index 000000000..2091ec737 --- /dev/null +++ b/test/basic/thistogram_verify_invalid_a3.pto @@ -0,0 +1,13 @@ +// RUN: not ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s + +func.func @thistogram_verify_invalid_a3() { + %src = pto.alloc_tile : !pto.tile_buf + %idx = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.thistogram ins(%src, %idx : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return +} + +// CHECK: error: 'pto.thistogram' op thistogram is only supported on A5 diff --git a/test/basic/trowargmax_emitc.pto b/test/basic/trowargmax_emitc.pto new file mode 100644 index 000000000..b8c0c8720 --- /dev/null +++ b/test/basic/trowargmax_emitc.pto @@ -0,0 +1,16 @@ +// RUN: ptoas --pto-arch=a3 %s 2>&1 | FileCheck %s --check-prefix=A3 + +module { + func.func @trowargmax_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trowargmax ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} + +// A3-LABEL: __global__ AICORE void trowargmax_emitc() +// A3: TROWARGMAX( diff --git a/test/basic/trowargmin_emitc.pto b/test/basic/trowargmin_emitc.pto new file mode 100644 index 000000000..cb9a9eb95 --- /dev/null +++ b/test/basic/trowargmin_emitc.pto @@ -0,0 +1,16 @@ +// RUN: ptoas --pto-arch=a5 %s 2>&1 | FileCheck %s --check-prefix=A5 + +module { + func.func @trowargmin_emitc() { + %src = pto.alloc_tile : !pto.tile_buf + %tmp = pto.alloc_tile : !pto.tile_buf + %dst = pto.alloc_tile : !pto.tile_buf + + pto.trowargmin ins(%src, %tmp : !pto.tile_buf, !pto.tile_buf) + outs(%dst : !pto.tile_buf) + return + } +} + +// A5-LABEL: __global__ AICORE void trowargmin_emitc() +// A5: TROWARGMIN( diff --git a/tools/ptobc/generated/ptobc_opcodes_v0.h b/tools/ptobc/generated/ptobc_opcodes_v0.h index 49a57474d..86f7fa1ec 100644 --- a/tools/ptobc/generated/ptobc_opcodes_v0.h +++ b/tools/ptobc/generated/ptobc_opcodes_v0.h @@ -155,6 +155,11 @@ inline constexpr OpInfo kOpTable[] = { {0x1073, "pto.trowexpanddiv", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1074, "pto.trowexpandmul", 0, 0x00, 0x02, 0, 0, 0, 0x00}, {0x1075, "pto.tpack", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1076, "pto.taxpy", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1077, "pto.thistogram", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x1078, "pto.tget_scale_addr", 0, 0x00, 0x00, 2, 0, 0, 0x00}, + {0x1079, "pto.trowargmax", 0, 0x00, 0x00, 3, 0, 0, 0x00}, + {0x107A, "pto.trowargmin", 0, 0x00, 0x00, 3, 0, 0, 0x00}, {0x2000, "arith.addi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, {0x2001, "arith.ceildivsi", 0, 0x01, 0x00, 2, 1, 0, 0x00}, {0x2002, "arith.cmpi", 0, 0x01, 0x00, 2, 1, 0, 0x01}, @@ -320,6 +325,12 @@ inline std::optional lookupOpcodeByName(llvm::StringRef name) { .Case("pto.subset", 0x1072) .Case("pto.trowexpanddiv", 0x1073) .Case("pto.trowexpandmul", 0x1074) + .Case("pto.tpack", 0x1075) + .Case("pto.taxpy", 0x1076) + .Case("pto.thistogram", 0x1077) + .Case("pto.tget_scale_addr", 0x1078) + .Case("pto.trowargmax", 0x1079) + .Case("pto.trowargmin", 0x107A) .Case("scf.for", 0x4000) .Case("scf.if", 0x4001) .Case("scf.yield", 0x4002) @@ -471,6 +482,12 @@ inline std::optional lookupOpcodeAndVariantByFullName(llvm::St .Case("pto.subset", OpcodeAndVariant{0x1072, 0, 0}) .Case("pto.trowexpanddiv", OpcodeAndVariant{0x1073, 0, 0}) .Case("pto.trowexpandmul", OpcodeAndVariant{0x1074, 0, 0}) + .Case("pto.tpack", OpcodeAndVariant{0x1075, 0, 0}) + .Case("pto.taxpy", OpcodeAndVariant{0x1076, 0, 0}) + .Case("pto.thistogram", OpcodeAndVariant{0x1077, 0, 0}) + .Case("pto.tget_scale_addr", OpcodeAndVariant{0x1078, 0, 0}) + .Case("pto.trowargmax", OpcodeAndVariant{0x1079, 0, 0}) + .Case("pto.trowargmin", OpcodeAndVariant{0x107A, 0, 0}) .Case("scf.for", OpcodeAndVariant{0x4000, 0, 0}) .Case("scf.if", OpcodeAndVariant{0x4001, 0, 0}) .Case("scf.yield", OpcodeAndVariant{0x4002, 0, 0})