Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -710,16 +710,17 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
const te::Tensor& end, const te::Tensor& strides,
std::string name = "T_strided_slice_dynamic",
std::string tag = topi::kInjective) {
DataType index_dtype = begin->shape[0]->dtype;
const int64_t num_dynamic_axes = begin->shape[0].as<IntImmNode>()->value;
ICHECK_EQ(end->shape[0].as<IntImmNode>()->value, num_dynamic_axes);
ICHECK_EQ(strides->shape[0].as<IntImmNode>()->value, num_dynamic_axes);

Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (int64_t i = 0; i < num_dynamic_axes; ++i) {
auto i64_ind = IntImm(DataType::Int(64), i);
begin_expr.push_back(begin(i64_ind));
end_expr.push_back(end(i64_ind));
strides_expr.push_back(strides(i64_ind));
auto ind = make_const(index_dtype, i);
begin_expr.push_back(begin(ind));
end_expr.push_back(end(ind));
strides_expr.push_back(strides(ind));
}
return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag);
}
Expand Down Expand Up @@ -822,9 +823,10 @@ inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const
Array<Integer> end_full(end);
Array<Integer> strides_full(strides);

const IntImm one = IntImm(DataType::Int(64), 1);
const IntImm zero = IntImm(DataType::Int(64), 0);
const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());
DataType index_dtype = begin.size() > 0 ? begin[0]->dtype : DataType::Int(64);
const IntImm one = IntImm(index_dtype, 1);
const IntImm zero = IntImm(index_dtype, 0);
const IntImm max_range = Downcast<IntImm>(max_value(index_dtype));

for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
strides_full.push_back(one);
Expand Down