From 156a59c8e4aded8745e8258bf72cef64ea513721 Mon Sep 17 00:00:00 2001 From: sxjscience Date: Fri, 13 Oct 2017 23:17:07 +0800 Subject: [PATCH] fix range --- mshadow/extension/range.h | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/mshadow/extension/range.h b/mshadow/extension/range.h index eb73f441..ab49b6e3 100644 --- a/mshadow/extension/range.h +++ b/mshadow/extension/range.h @@ -23,9 +23,9 @@ namespace expr { template struct RangeExp: public Exp, DType, type::kMapper> { - const float start_; - const float stop_; - const float step_; + const DType start_; + const DType stop_; + const DType step_; const int repeat_; /*! \brief constructor */ RangeExp(DType start, DType stop, DType step, int repeat) @@ -67,18 +67,26 @@ MakePlan(const RangeExp &exp) { return Plan, DType>(exp); } -inline int RangeOutSize(float start, float stop, float step, int repeat) { - return repeat * static_cast(ceil((stop - start) / step)); + +template +inline int RangeOutSize(DType start, DType stop, DType step, int repeat) { + return repeat * ((stop - start - 1) / step + 1); } -inline int RangeOutSize(double start, double stop, float step, int repeat) { - return repeat * static_cast(ceil((stop - start) / step)); +template<> +inline int RangeOutSize(float start, float stop, float step, int repeat) { + double d_start = static_cast(start); + double d_stop = static_cast(stop); + double d_step = static_cast(step); + return repeat * static_cast(ceil((d_stop - d_start) / d_step)); } -inline int RangeOutSize(int start, int stop, int step, int repeat) { - return repeat * ((stop - start - 1) / step + 1); +template<> +inline int RangeOutSize(double start, double stop, double step, int repeat) { + return repeat * static_cast(ceil((stop - start) / step)); } + template struct ShapeCheck > { inline static Shape @@ -96,7 +104,7 @@ struct ShapeCheck > { CHECK(t.start_ > t.stop_) << "RangeExp does not support (start, stop, step)= " << "(" << t.start_ << "," << t.stop_ << "," << t.step_ << ")"; } - return Shape1(RangeOutSize(t.start_, t.stop_, t.step_, t.repeat_)); + return Shape1(RangeOutSize(t.start_, t.stop_, t.step_, t.repeat_)); } };