Skip to content
This repository was archived by the owner on Aug 11, 2020. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions mshadow/extension/range.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ namespace expr {
template<typename DType>
struct RangeExp:
public Exp<RangeExp<DType>, DType, type::kMapper> {
const float start_;
const float stop_;
const float step_;
const DType start_;
const DType stop_;
const DType step_;
const int repeat_;
/*! \brief constructor */
RangeExp(DType start, DType stop, DType step, int repeat)
Expand Down Expand Up @@ -67,18 +67,26 @@ MakePlan(const RangeExp<DType> &exp) {
return Plan<RangeExp<DType>, DType>(exp);
}

inline int RangeOutSize(float start, float stop, float step, int repeat) {
return repeat * static_cast<int>(ceil((stop - start) / step));

template<typename DType>
inline int RangeOutSize(DType start, DType stop, DType step, int repeat) {
return repeat * ((stop - start - 1) / step + 1);
}

inline int RangeOutSize(double start, double stop, float step, int repeat) {
return repeat * static_cast<int>(ceil((stop - start) / step));
template<>
inline int RangeOutSize<float>(float start, float stop, float step, int repeat) {
double d_start = static_cast<double>(start);
double d_stop = static_cast<double>(stop);
double d_step = static_cast<double>(step);
return repeat * static_cast<int>(ceil((d_stop - d_start) / d_step));
}

inline int RangeOutSize(int start, int stop, int step, int repeat) {
return repeat * ((stop - start - 1) / step + 1);
template<>
inline int RangeOutSize<double>(double start, double stop, double step, int repeat) {
return repeat * static_cast<int>(ceil((stop - start) / step));
}


template<int dim, typename DType>
struct ShapeCheck<dim, RangeExp<DType> > {
inline static Shape<dim>
Expand All @@ -96,7 +104,7 @@ struct ShapeCheck<dim, RangeExp<DType> > {
CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= "
<< "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")";
}
return Shape1(RangeOutSize(t.start_, t.stop_, t.step_, t.repeat_));
return Shape1(RangeOutSize<DType>(t.start_, t.stop_, t.step_, t.repeat_));
}
};

Expand Down