Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
328 changes: 328 additions & 0 deletions docs/PTO_IR_manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -2451,6 +2451,7 @@ pto.tprelu ins(%a, %slopes, %c : !pto.tile_buf<loc=vec, dtype=f16, rows=16, cols
| `pto.tadds` | `dst[i,j] = src[i,j] + scalar` |
| `pto.tsubs` | `dst[i,j] = src[i,j] - scalar` |
| `pto.tmuls` | `dst[i,j] = src[i,j] * scalar` |
| `pto.taxpy` | `dst[i,j] = dst[i,j] + src[i,j] * scalar` |
| `pto.tdivs` | `dst[i,j] = src[i,j] / scalar` (or `scalar / src[i,j]`) |
| `pto.tmaxs` | `dst[i,j] = max(src[i,j], scalar)` |
| `pto.tmins` | `dst[i,j] = min(src[i,j], scalar)` |
Expand Down Expand Up @@ -2636,6 +2637,70 @@ pto.tmuls ins(%a, %s : !pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=16,

---

##### `pto.taxpy` - Multiply-Add Into Destination Tile

**Summary:** Updates the destination tile in place with `dst += src * scalar`.

**Semantics:**

```
For each element (i, j):
dst[i, j] = dst[i, j] + src[i, j] * scalar
```

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile buffer |
| `scalar` | `ScalarType` (signless integer / float) | Scalar multiplier |
| `dst` | `pto.tile_buf` | Destination tile buffer, updated in place |

**Results:** None. Writes into `dst` via DPS pattern.

**Assembly Format:**

```
pto.taxpy ins(<src>, <scalar> : <src_type>, <scalar_type>)
outs(<dst> : <dst_type>)
```

**Constraints & Verification:**

- **Implementation checks (A2A3)**
- `src` and `dst` must use `loc=vec`.
- `src` and `dst` must have the same shape and valid shape.
- `scalar` type must exactly match the element type of `src`.
- `src` element type must be `f16` or `f32`.
- `dst` element type must be `f16` or `f32`.
- Element types must either match, or use the widening form `src=f16`, `dst=f32`.
- **Implementation checks (A5)**
- `src` and `dst` must use `loc=vec`.
- `src` and `dst` must have the same shape and valid shape.
- `scalar` type must exactly match the element type of `src`.
- `src` element type must be `f16`, `bf16`, or `f32`.
- `dst` element type must be `f16`, `bf16`, or `f32`.
- Element types must either match, or use the widening form `src=f16`, `dst=f32`.

**Hardware Mapping:**

- Executes on the **Vector pipeline** (`PIPE_V`)
- Operates on data in the **VEC (UB)** memory space

**Basic Example:**

```mlir
%alpha = arith.constant 2.0 : f16
pto.taxpy ins(%src, %alpha : !pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=16,
v_row=16, v_col=16, blayout=row_major, slayout=none_box,
fractal=512, pad=0>, f16)
outs(%dst : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=16,
v_row=16, v_col=16, blayout=row_major, slayout=none_box,
fractal=512, pad=0>)
```

---

##### `pto.tdivs` - Elementwise Division with Scalar

**Summary:** Divides every element of a tile buffer by a scalar, or divides a scalar by every element.
Expand Down Expand Up @@ -3726,10 +3791,13 @@ Reduce along rows or columns of a tile. All execute on the **Vector pipeline** (
|----|----------|
| `pto.trowsum` | `dst[i,0] = sum_j src[i,j]` |
| `pto.trowmax` | `dst[i,0] = max_j src[i,j]` |
| `pto.trowargmax` | `dst[i,0] = argmax_j src[i,j]` (requires tmp) |
| `pto.trowmin` | `dst[i,0] = min_j src[i,j]` (requires tmp) |
| `pto.trowargmin` | `dst[i,0] = argmin_j src[i,j]` (requires tmp) |
| `pto.tcolsum` | `dst[0,j] = sum_i src[i,j]` (requires tmp, optional isBinary) |
| `pto.tcolmax` | `dst[0,j] = max_i src[i,j]` |
| `pto.tcolmin` | `dst[0,j] = min_i src[i,j]` |
| `pto.thistogram` | `dst[i, idx[i,0]] = histogram_update(dst[i, idx[i,0]], src[i,:])` (A5 only) |

---

Expand Down Expand Up @@ -3874,6 +3942,73 @@ pto.trowmax ins(%src : !pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=16,

---

##### `pto.trowargmax` - Row-wise ArgMax Reduction

**Summary:** Reduces each row to the column index of its maximum element. Requires a temporary buffer.

**Semantics:**

```
For each row i:
dst[i, 0] = argmax over j of src[i, j]
```

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile buffer |
| `tmp` | `pto.tile_buf` | Temporary buffer with the same shape/type as `src` |
| `dst` | `pto.tile_buf` | Destination tile buffer containing row-wise indices |

**Results:** None. Writes into `dst` via DPS pattern.

**Assembly Format:**

```
pto.trowargmax ins(<src>, <tmp> : <src_type>, <tmp_type>)
outs(<dst> : <dst_type>)
```

**Constraints & Verification:**

- **Implementation checks (A2A3)**
- `src`, `tmp`, and `dst` must use `loc=vec`.
- `src` must use ND-style tile layout (`blayout=row_major`, `slayout=none_box`).
- `tmp` must have the same shape, valid shape, and element type as `src`.
- `dst` must use `slayout=none_box` and either:
- a DN-style column vector tile (`blayout=col_major`, `cols=1`), or
- a legacy ND-style tile with `valid column == 1`.
- `src` element type must be `f16` or `f32`.
- `dst` element type must be `i32` or `ui32`.
- Runtime valid checks:
- `src valid row != 0` and `src valid column != 0`
- `src valid row == dst valid row`
- `dst valid column == 1`
- **Implementation checks (A5)**
- Same constraints as A2/A3.

**Hardware Mapping:**

- Executes on the **Vector pipeline** (`PIPE_V`)
- Operates on data in the **VEC (UB)** memory space

**Basic Example:**

```mlir
pto.trowargmax ins(%src, %tmp : !pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=32,
v_row=16, v_col=32, blayout=row_major, slayout=none_box,
fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=32,
v_row=16, v_col=32, blayout=row_major, slayout=none_box,
fractal=512, pad=0>)
outs(%dst : !pto.tile_buf<loc=vec, dtype=ui32, rows=16, cols=1,
v_row=16, v_col=1, blayout=col_major, slayout=none_box,
fractal=512, pad=0>)
```

---

##### `pto.trowmin` - Row-wise Min Reduction

**Summary:** Reduces each row by taking the minimum across columns. Requires a temporary buffer.
Expand Down Expand Up @@ -3950,6 +4085,144 @@ pto.trowmin ins(%src, %tmp : !pto.tile_buf<loc=vec, dtype=f16, rows=16, cols=16,

---

##### `pto.trowargmin` - Row-wise ArgMin Reduction

**Summary:** Reduces each row to the column index of its minimum element. Requires a temporary buffer.

**Semantics:**

```
For each row i:
dst[i, 0] = argmin over j of src[i, j]
```

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile buffer |
| `tmp` | `pto.tile_buf` | Temporary buffer with the same shape/type as `src` |
| `dst` | `pto.tile_buf` | Destination tile buffer containing row-wise indices |

**Results:** None. Writes into `dst` via DPS pattern.

**Assembly Format:**

```
pto.trowargmin ins(<src>, <tmp> : <src_type>, <tmp_type>)
outs(<dst> : <dst_type>)
```

**Constraints & Verification:**

- **Implementation checks (A2A3)**
- `src`, `tmp`, and `dst` must use `loc=vec`.
- `src` must use ND-style tile layout (`blayout=row_major`, `slayout=none_box`).
- `tmp` must have the same shape, valid shape, and element type as `src`.
- `dst` must use `slayout=none_box` and either:
- a DN-style column vector tile (`blayout=col_major`, `cols=1`), or
- a legacy ND-style tile with `valid column == 1`.
- `src` element type must be `f16` or `f32`.
- `dst` element type must be `i32` or `ui32`.
- Runtime valid checks:
- `src valid row != 0` and `src valid column != 0`
- `src valid row == dst valid row`
- `dst valid column == 1`
- **Implementation checks (A5)**
- Same constraints as A2/A3.

**Hardware Mapping:**

- Executes on the **Vector pipeline** (`PIPE_V`)
- Operates on data in the **VEC (UB)** memory space

**Basic Example:**

```mlir
pto.trowargmin ins(%src, %tmp : !pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=32,
v_row=16, v_col=32, blayout=row_major, slayout=none_box,
fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=f32, rows=16, cols=32,
v_row=16, v_col=32, blayout=row_major, slayout=none_box,
fractal=512, pad=0>)
outs(%dst : !pto.tile_buf<loc=vec, dtype=i32, rows=16, cols=1,
v_row=16, v_col=1, blayout=col_major, slayout=none_box,
fractal=512, pad=0>)
```

---

##### `pto.thistogram` - Row-wise Histogram Accumulation

**Summary:** Updates a 256-bin histogram row in `dst` using each source row and a per-row index selector. This op is only supported on A5.

**Semantics:**

```
For each row i:
bin = idx[i, 0]
dst[i, bin] = histogram_update(dst[i, bin], src[i, :])
```

The exact accumulation performed inside the selected bin is target-defined by the hardware `THISTOGRAM` intrinsic. The `isMSB` attribute selects the intrinsic mode.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile buffer, one logical row per histogram update |
| `idx` | `pto.tile_buf` | Per-row bin selector tile |
| `dst` | `pto.tile_buf` | Destination histogram tile |
| `isMSB` | `BoolAttr` (default: `true`) | Selects the `THISTOGRAM<...>` intrinsic mode |

**Results:** None. Writes into `dst` via DPS pattern.

**Assembly Format:**

```
pto.thistogram ins(<src>, <idx> : <src_type>, <idx_type>)
outs(<dst> : <dst_type>)
{isMSB = true}
```

**Constraints & Verification:**

- **Implementation checks (A2A3)**
- Not supported.
- **Implementation checks (A5)**
- `src`, `idx`, and `dst` must all be `tile_buf` values in `loc=vec`.
- `src` and `dst` must use `row_major + none_box` layout.
- `idx` must use DN-style layout (`col_major + none_box`).
- `src` element type must be `ui16`.
- `idx` element type must be `ui8`.
- `dst` element type must be `ui32`.
- `src`, `idx`, and `dst` must all be rank-2 tiles.
- `idx` rows and valid rows must match `src`.
- `dst` rows and valid rows must match `src`.
- `idx` must have exactly one column.
- `dst` shape[1] and valid_shape[1] must be at least `256`.

**Hardware Mapping:**

- Executes on the **Vector pipeline** (`PIPE_V`)
- Operates on data in the **VEC (UB)** memory space

**Basic Example:**

```mlir
pto.thistogram ins(%src, %idx : !pto.tile_buf<loc=vec, dtype=ui16, rows=8, cols=32,
v_row=8, v_col=32, blayout=row_major, slayout=none_box,
fractal=512, pad=0>,
!pto.tile_buf<loc=vec, dtype=ui8, rows=8, cols=1,
v_row=8, v_col=1, blayout=col_major, slayout=none_box,
fractal=512, pad=0>)
outs(%dst : !pto.tile_buf<loc=vec, dtype=ui32, rows=8, cols=256,
v_row=8, v_col=256, blayout=row_major, slayout=none_box,
fractal=512, pad=0>) {isMSB = false}
```

---

##### `pto.tcolsum` - Column-wise Sum Reduction

**Summary:** Reduces each column by summing across rows. Requires a temporary buffer.
Expand Down Expand Up @@ -6677,6 +6950,61 @@ pto.tsetval ins(%off, %val : index, f16) outs(%dst : !pto.tile_buf<loc=vec, dtyp

### 4.15 MX Quantized Operations

##### `pto.tget_scale_addr` - Bind Scaling Tile View

**Summary:** Binds a scaling tile to the scale-address associated with a source tile. No elementwise computation or data movement is performed.

**Semantics:**

```
dst = scale_view_of(src)
```

`dst` becomes a scaling-tile view compatible with `src`, so it can be consumed by later MX / scaling-aware ops.

**Arguments:**

| Name | Type | Description |
|------|------|-------------|
| `src` | `pto.tile_buf` | Source tile that owns or references scale storage |
| `dst` | `pto.tile_buf` | Destination scaling tile view |

**Results:** None. Initializes or rebinds `dst` via DPS pattern.

**Assembly Format:**

```
pto.tget_scale_addr ins(<src> : <src_type>)
outs(<dst> : <dst_type>)
```

**Constraints & Verification:**

- **Implementation checks (A2A3)**
- Not supported.
- **Implementation checks (A5)**
- `src` must be a valid `pto.tile_buf`.
- `dst` must be a valid `pto.tile_buf` in `loc=scaling`.
- `dst` must have the same rank, shape, and valid shape as `src`.
- `dst` must satisfy the target-specific scaling-tile compatibility rules for `src`.

**Hardware Mapping:**

- Executes on the **Scalar pipeline** (`PIPE_S`)

**Basic Example:**

```mlir
pto.tget_scale_addr ins(%src : !pto.tile_buf<loc=left, dtype=f8E4M3, rows=1, cols=128,
v_row=1, v_col=128, blayout=col_major, slayout=row_major,
fractal=512, pad=0>)
outs(%scale : !pto.tile_buf<loc=scaling, dtype=f16, rows=1, cols=128,
v_row=1, v_col=128, blayout=row_major, slayout=row_major,
fractal=512, pad=0>)
```

---

##### `pto.tmov.fp` - Move/Convert with Scaling Tile

**Summary:** Legacy dedicated fp-TMOV op. New code should prefer `pto.tmov` with an `fp` operand, which lowers to the same `TMOV_FP` / fp-parameterized `TMOV` APIs.
Expand Down
Loading
Loading