From ef725fa9a588fbee836bcea0c5e769c3792bb3b6 Mon Sep 17 00:00:00 2001 From: Jim Dunn Date: Wed, 18 Jul 2018 15:59:27 -0700 Subject: [PATCH 1/2] remove mod from arity 2 version of load-checkpoint --- .../clojure-package/src/org/apache/clojure_mxnet/module.clj | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj index 22ab761547e2..ab6d345fe91d 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj @@ -309,7 +309,6 @@ (defn load-checkpoint "Create a model from previously saved checkpoint. - - mod module - opts map of - prefix Path prefix of saved model files. You should have prefix-symbol.json, prefix-xxxx.params, and optionally prefix-xxxx.states, @@ -341,7 +340,7 @@ (util/->option (when workload-list (util/vec->indexed-seq workload-list))) (util/->option (when fixed-param-names (util/vec->set fixed-param-names))))) ([prefix epoch] - (load-checkpoint mod {:prefix prefix :epoch epoch}))) + (load-checkpoint {:prefix prefix :epoch epoch}))) (defn load-optimizer-states [mod fname] (.mod load fname)) @@ -670,4 +669,3 @@ (fit-params {:allow-missing true}) (fit-params {})) - From d5e6db97c574627a485548ea5eacc453d6924490 Mon Sep 17 00:00:00 2001 From: Jim Dunn Date: Wed, 1 Aug 2018 09:12:01 -0700 Subject: [PATCH 2/2] load-checkpoint arity 2 test --- .../test/org/apache/clojure_mxnet/module_test.clj | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj index f3d4e75e8c97..0f71b5a850cc 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj @@ -101,13 +101,20 @@ (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})}) (m/update) (m/save-checkpoint {:prefix "test" :epoch 0 :save-opt-states true})) - (let [mod2 (m/load-checkpoint {:prefix "test" :epoch 0 :load-optimizer-states true})] (-> mod2 (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]}) (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})})) - (is (= (-> mod m/symbol sym/to-json) (-> mod2 m/symbol sym/to-json))) - (is (= (-> mod m/params first) (-> mod2 m/params first)))))) + (is (= (-> mod m/symbol sym/to-json) (-> mod2 m/symbol sym/to-json))) + (is (= (-> mod m/params first) (-> mod2 m/params first)))) + ;; arity 2 version of above. `load-optimizer-states` is `false` here by default, + ;; but optimizers states aren't checked here so it's not relevant to the test outcome. + (let [mod3 (m/load-checkpoint "test" 0)] + (-> mod3 + (m/bind {:data-shapes [{:name "data" :shape [10 10] :layout "NT"}]}) + (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.1 :momentum 0.9})})) + (is (= (-> mod m/symbol sym/to-json) (-> mod3 m/symbol sym/to-json))) + (is (= (-> mod m/params first) (-> mod3 m/params first)))))) (deftest test-module-save-load-multi-device (let [s (sym/variable "data") @@ -321,4 +328,3 @@ (comment (m/data-shapes x)) -