Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions aiter/ops/topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

# user interface

from typing import Tuple
from typing import Optional, Tuple

import torch
from ..jit.core import (
compile_ops,
)
from ..utility import dtypes

from ..jit.core import compile_ops
from ..jit.utils.chip_info import get_cu_num
from ..utility import dtypes


@compile_ops("module_moe_asm", fc_name="biased_grouped_topk")
Expand Down Expand Up @@ -202,6 +202,7 @@ def top_k_per_row_prefill(
rowStarts: torch.Tensor,
rowEnds: torch.Tensor,
indices: torch.Tensor,
values: Optional[torch.Tensor],
numRows: int,
stride0: int,
stride1: int,
Expand Down
1 change: 1 addition & 0 deletions csrc/include/rocm_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1337,6 +1337,7 @@ namespace py = pybind11;
py::arg("rowStarts"), \
py::arg("rowEnds"), \
py::arg("indices"), \
py::arg("values"), \
py::arg("numRows"), \
py::arg("stride0"), \
py::arg("stride1")); \
Expand Down
1 change: 1 addition & 0 deletions csrc/include/topk_per_row.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ void top_k_per_row_prefill(const torch::Tensor& logits,
const torch::Tensor& rowStarts,
const torch::Tensor& rowEnds,
torch::Tensor& indices,
std::optional<torch::Tensor> values,
int64_t numRows,
int64_t stride0,
int64_t stride1);
Expand Down
Loading