From 79ba32e47082eed6ceaf738817164ca4d7567de8 Mon Sep 17 00:00:00 2001 From: Qiang Kou Date: Wed, 7 Oct 2015 18:57:13 -0400 Subject: [PATCH 1/2] cleanup load/save works well --- R-package/DESCRIPTION | 2 +- R-package/NAMESPACE | 2 ++ R-package/R/ndarray.R | 38 +++++++++++++++++++++++++---------- R-package/R/{base.R => zzz.R} | 2 -- R-package/src/ndarray.cc | 7 ++++--- R-package/src/ndarray.h | 2 +- 6 files changed, 35 insertions(+), 18 deletions(-) rename R-package/R/{base.R => zzz.R} (88%) diff --git a/R-package/DESCRIPTION b/R-package/DESCRIPTION index 6ca3a1ea44d8..d5cd1dd949eb 100644 --- a/R-package/DESCRIPTION +++ b/R-package/DESCRIPTION @@ -9,5 +9,5 @@ Description: MXNet is a deep learning framework designed for both efficiency and License: Apache-2.0 URL: https://github.com/dmlc/mxnet BugReports: https://github.com/dmlc/mxnet/issues -Imports: Rcpp (>= 0.11.1) +Imports: methods, Rcpp (>= 0.11.1) LinkingTo: Rcpp diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index 8eb5798245ee..aaa717375ee4 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -1,2 +1,4 @@ exportPattern("^[[:alpha:]]+") import(Rcpp) +import(methods) +export(`+.Rcpp_MXNDArray`) \ No newline at end of file diff --git a/R-package/R/ndarray.R b/R-package/R/ndarray.R index b6c744030b91..329af1484fe4 100644 --- a/R-package/R/ndarray.R +++ b/R-package/R/ndarray.R @@ -1,12 +1,28 @@ -#' NDArray -#' -#' Additional NDArray related operations -init.ndarray.methods <-function() { - require(methods) - setMethod("+", signature(e1="Rcpp_MXNDArray", e2="numeric"), function(e1, e2) { - mx.nd.internal.plus.scalar(e1, e2) - }) - setMethod("+", signature(e1="Rcpp_MXNDArray", e2="Rcpp_MXNDArray"), function(e1, e2) { - mx.nd.internal.plus(e1, e2) - }) +mx.nd.load <- function(filename) { + filename <- path.expand(filename) + mx.nd.internal.load(filename) +} + +mx.nd.save <- function(ndarray, filename) { + filename <- path.expand(filename) + mx.nd.internal.save(ndarray, filename) } + +is.MXNDArray <- function(x) + inherits(x, "Rcpp_MXNDArray") + +as.array.Rcpp_MXNDArray <- function(x){ + return(x$as.array()) +} + +`+.Rcpp_MXNDArray` <- function(e1, e2) { + if(is.MXNDArray(e1)&&is.MXNDArray(e2)) { + mx.nd.internal.plus(e1, e2) + } else if (is.MXNDArray(e1)&&is.numeric(e2)) { + mx.nd.internal.plus.scalar(e1, e2) + } else if (is.MXNDArray(e2)&&is.numeric(e1)) { + mx.nd.internal.plus.scalar(e2, e1) + } else { + stop("unsupport type found.") + } +} \ No newline at end of file diff --git a/R-package/R/base.R b/R-package/R/zzz.R similarity index 88% rename from R-package/R/base.R rename to R-package/R/zzz.R index 6c051f3d4209..902e7fea809f 100644 --- a/R-package/R/base.R +++ b/R-package/R/zzz.R @@ -1,10 +1,8 @@ -require(methods) .onLoad <- function(libname, pkgname) { library.dynam("libmxnet", pkgname, libname, local=FALSE) library.dynam("mxnet", pkgname, libname) loadModule("mxnet", TRUE) - init.ndarray.methods() } .onUnload <- function(libpath) { diff --git a/R-package/src/ndarray.cc b/R-package/src/ndarray.cc index b894937de231..ec6bdd4cd69e 100644 --- a/R-package/src/ndarray.cc +++ b/R-package/src/ndarray.cc @@ -157,7 +157,7 @@ void NDArray::Save(const Rcpp::RObject &sxptr, MX_CALL(MXNDArraySave(filename.c_str(), num_args, dmlc::BeginPtr(handles), dmlc::BeginPtr(keys))); - } else if (TYPEOF(sxptr) == EXTPTRSXP) { + } else if (TYPEOF(sxptr) == EXTPTRSXP) { // TODO this line is wrong?? MX_CALL(MXNDArraySave(filename.c_str(), 1, &(NDArray::XPtr(sxptr)->handle_), nullptr)); } else { @@ -220,8 +220,9 @@ void NDArray::InitRcppModule() { using namespace Rcpp; // NOLINT(*) class_("MXNDArray") .method("as.array", &NDArray::AsNumericVector); - function("mx.nd.load", &NDArray::Load); - function("mx.nd.save", &NDArray::Save); + // don't call load/save directly, let R provides the completed file path first + function("mx.nd.internal.load", &NDArray::Load); + function("mx.nd.internal.save", &NDArray::Save); function("mx.nd.array", &NDArray::Array); } diff --git a/R-package/src/ndarray.h b/R-package/src/ndarray.h index 21371f8ba57b..5dcd0ae79514 100644 --- a/R-package/src/ndarray.h +++ b/R-package/src/ndarray.h @@ -170,7 +170,7 @@ class NDArrayFunction : public ::Rcpp::CppFunction { } // namespace mxnet -RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::NDArray); +RCPP_EXPOSED_CLASS_NODECL(::mxnet::R::NDArray) namespace mxnet { namespace R { From 145ca5acbfe2d4f2530d559ea9005deafbdf7825 Mon Sep 17 00:00:00 2001 From: Qiang Kou Date: Wed, 7 Oct 2015 19:55:31 -0400 Subject: [PATCH 2/2] reorganize R script --- R-package/NAMESPACE | 1 - R-package/R/ndarray.R | 3 ++- R-package/R/zzz.R | 1 + R-package/demo/basic_ndarray.R | 4 +--- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/R-package/NAMESPACE b/R-package/NAMESPACE index aaa717375ee4..01b58a27fdc5 100644 --- a/R-package/NAMESPACE +++ b/R-package/NAMESPACE @@ -1,4 +1,3 @@ exportPattern("^[[:alpha:]]+") import(Rcpp) import(methods) -export(`+.Rcpp_MXNDArray`) \ No newline at end of file diff --git a/R-package/R/ndarray.R b/R-package/R/ndarray.R index 3ec0d74fc373..098035c5ad94 100644 --- a/R-package/R/ndarray.R +++ b/R-package/R/ndarray.R @@ -15,6 +15,7 @@ is.MXNDArray <- function(x) #' NDArray #' #' Additional NDArray related operations +init.ndarray.methods <- function() { setMethod("+", signature(e1 = "Rcpp_MXNDArray", e2 = "numeric"), function(e1, e2) { mx.nd.internal.plus.scalar(e1, e2) }) @@ -54,4 +55,4 @@ is.MXNDArray <- function(x) setMethod("as.array", signature(x = "Rcpp_MXNDArray"), function(x) { x$as.array() }) - +} diff --git a/R-package/R/zzz.R b/R-package/R/zzz.R index 902e7fea809f..7da79078d272 100644 --- a/R-package/R/zzz.R +++ b/R-package/R/zzz.R @@ -3,6 +3,7 @@ library.dynam("libmxnet", pkgname, libname, local=FALSE) library.dynam("mxnet", pkgname, libname) loadModule("mxnet", TRUE) + init.ndarray.methods() } .onUnload <- function(libpath) { diff --git a/R-package/demo/basic_ndarray.R b/R-package/demo/basic_ndarray.R index ded9cbe0213e..cd850d09670f 100644 --- a/R-package/demo/basic_ndarray.R +++ b/R-package/demo/basic_ndarray.R @@ -1,6 +1,4 @@ require(mxnet) -require(methods) - x = as.array(c(1,2,3)) mat = mx.nd.array(x, mx.cpu(0)) @@ -8,7 +6,7 @@ mat = mat + 1.0 mat = mat + mat mat = mat - 5 mat = 10 / mat -mat = 7*mat +mat = 7 * mat mat = 1 - mat + (2 * mat)/(mat + 0.5) as.array(mat)