diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 6ca3a1ea44d8..d5cd1dd949eb 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -9,5 +9,5 @@ Description: MXNet is a deep learning framework designed for both efficiency and License: Apache-2.0 URL: https://github.com/dmlc/mxnet BugReports: https://github.com/dmlc/mxnet/issues -Imports: Rcpp (>= 0.11.1) +Imports: methods, Rcpp (>= 0.11.1) LinkingTo: Rcpp diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 8eb5798245ee..01b58a27fdc5 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -1,2 +1,3 @@ exportPattern("^[[:alpha:]]+") import(Rcpp) +import(methods) diff --git a/R-package/R/ndarray.R b/R-package/R/ndarray.R index 04f7ac50c788..098035c5ad94 100644 --- a/R-package/R/ndarray.R +++ b/R-package/R/ndarray.R @@ -1,8 +1,21 @@ +mx.nd.load <- function(filename) { + filename <- path.expand(filename) + mx.nd.internal.load(filename) +} + +mx.nd.save <- function(ndarray, filename) { + filename <- path.expand(filename) + mx.nd.internal.save(ndarray, filename) +} + +is.MXNDArray <- function(x) + inherits(x, "Rcpp_MXNDArray") + + #' NDArray #' #' Additional NDArray related operations init.ndarray.methods <- function() { - require(methods) setMethod("+", signature(e1 = "Rcpp_MXNDArray", e2 = "numeric"), function(e1, e2) { mx.nd.internal.plus.scalar(e1, e2) }) @@ -43,4 +56,3 @@ init.ndarray.methods <- function() { x$as.array() }) } - diff --git a/R-package/R/base.R b/R-package/R/zzz.R similarity index 95% rename from R-package/R/base.R rename to R-package/R/zzz.R index 6c051f3d4209..7da79078d272 100644 --- a/R-package/R/base.R +++ b/R-package/R/zzz.R @@ -1,4 +1,3 @@ -require(methods) .onLoad <- function(libname, pkgname) { library.dynam("libmxnet", pkgname, libname, local=FALSE) diff --git a/R-package/demo/basic_ndarray.R b/R-package/demo/basic_ndarray.R index ded9cbe0213e..cd850d09670f 100644 --- a/R-package/demo/basic_ndarray.R +++ b/R-package/demo/basic_ndarray.R @@ -1,6 +1,4 @@ require(mxnet) -require(methods) - x = as.array(c(1,2,3)) mat = mx.nd.array(x, mx.cpu(0)) @@ -8,7 +6,7 @@ mat = mat + 1.0 mat = mat + mat mat = mat - 5 mat = 10 / mat -mat = 7*mat +mat = 7 * mat mat = 1 - mat + (2 * mat)/(mat + 0.5) as.array(mat) diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index b894937de231..ec6bdd4cd69e 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -157,7 +157,7 @@ void NDArray::Save(const Rcpp::RObject &sxptr, MX_CALL(MXNDArraySave(filename.c_str(), num_args, dmlc::BeginPtr(handles), dmlc::BeginPtr(keys))); - } else if (TYPEOF(sxptr) == EXTPTRSXP) { + } else if (TYPEOF(sxptr) == EXTPTRSXP) { // TODO this line is wrong?? MX_CALL(MXNDArraySave(filename.c_str(), 1, &(NDArray::XPtr(sxptr)->handle_), nullptr)); } else { @@ -220,8 +220,9 @@ void NDArray::InitRcppModule() { using namespace Rcpp; // NOLINT(*) class_("MXNDArray") .method("as.array", &NDArray::AsNumericVector); - function("mx.nd.load", &NDArray::Load); - function("mx.nd.save", &NDArray::Save); + // don't call load/save directly, let R provides the completed file path first + function("mx.nd.internal.load", &NDArray::Load); + function("mx.nd.internal.save", &NDArray::Save); function("mx.nd.array", &NDArray::Array); } diff --git a/R-package/src/ndarray.h b/R-package/src/ndarray.h index 21371f8ba57b..5dcd0ae79514 100644 --- a/R-package/src/ndarray.h +++ b/R-package/src/ndarray.h @@ -170,7 +170,7 @@ class NDArrayFunction : public ::Rcpp::CppFunction { } // namespace mxnet -RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::NDArray); +RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::NDArray) namespace mxnet { namespace R {