diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index f55355a64326..141177b033fd 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -79,15 +79,19 @@ def randint(low, high=None, size=None, dtype=None, ctx=None, out=None): [3, 2, 2, 0]]) """ if dtype is None: - dtype = 'int' + dtype = 'int64' + elif not isinstance(dtype, str): + dtype = np.dtype(dtype).name if ctx is None: - ctx = current_context() + ctx = str(current_context()) + else: + ctx = str(ctx) if size is None: size = () if high is None: high = low low = 0 - return _npi.random_randint(low, high, shape=size, dtype=dtype, ctx=ctx, out=out) + return _api_internal.randint(low, high, size, dtype, ctx, out) def uniform(low=0.0, high=1.0, size=None, dtype=None, ctx=None, out=None): diff --git a/src/api/operator/random/np_randint_op.cc b/src/api/operator/random/np_randint_op.cc new file mode 100644 index 000000000000..8e05822fa907 --- /dev/null +++ b/src/api/operator/random/np_randint_op.cc @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_randint_op.cc + * \brief Implementation of the API of functions in src/operator/numpy/random/np_randint_op.cc + */ +#include +#include +#include +#include "../utils.h" +#include "../../../operator/random/sample_op.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.randint") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_random_randint"); + nnvm::NodeAttrs attrs; + op::SampleRandIntParam param; + int num_inputs = 0; + param.low = args[0].operator int(); + param.high = args[1].operator int(); + if (args[2].type_code() == kDLInt) { + param.shape = TShape(1, args[2].operator int64_t()); + } else { + param.shape = TShape(args[2].operator ObjectRef()); + } + if (args[3].type_code() == kNull) { + param.dtype = mxnet::common::GetDefaultDtype(); + } else { + param.dtype = String2MXNetTypeWithBool(args[3].operator std::string()); + } + attrs.parsed = param; + attrs.op = op; + if (args[4].type_code() != kNull) { + attrs.dict["ctx"] = args[4].operator std::string(); + } + NDArray* out = args[5].operator mxnet::NDArray*(); + NDArray** outputs = out == nullptr ? nullptr : &out; + int num_outputs = out != nullptr; + SetAttrDict(&attrs); + auto ndoutputs = Invoke(op, &attrs, num_inputs, nullptr, &num_outputs, outputs); + if (out) { + *ret = PythonArg(5); + } else { + *ret = reinterpret_cast(ndoutputs[0]); + } +}); + +} // namespace mxnet diff --git a/src/operator/random/sample_op.h b/src/operator/random/sample_op.h index 03ca89ef4e7b..c5681a70353c 100644 --- a/src/operator/random/sample_op.h +++ b/src/operator/random/sample_op.h @@ -277,6 +277,17 @@ struct SampleRandIntParam : public dmlc::Parameter, .describe("DType of the output in case this can't be inferred. " "Defaults to int32 if not defined (dtype=None)."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream low_s, high_s, dtype_s, shape_s; + low_s << low; + high_s << high; + dtype_s << dtype; + shape_s << shape; + (*dict)["low"] = low_s.str(); + (*dict)["high"] = high_s.str(); + (*dict)["dtype"] = MXNetTypeWithBool2String(dtype); + (*dict)["shape"] = shape_s.str(); + } }; struct SampleUniformLikeParam : public dmlc::Parameter,