diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml index 134e0a59da1f..5095da0e8609 100644 --- a/scala-package/core/pom.xml +++ b/scala-package/core/pom.xml @@ -104,5 +104,10 @@ 1.3.0-SNAPSHOT provided + + commons-io + commons-io + 2.1 + diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala new file mode 100644 index 000000000000..c1ff10c6c8a2 --- /dev/null +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/Util.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mxnetexamples + +import java.io.File +import java.net.URL + +import org.apache.commons.io.FileUtils + +object Util { + + def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = { + val tmpFile = new File(filePath) + var retry = maxRetry.getOrElse(3) + var success = false + if (!tmpFile.exists()) { + while (retry > 0 && !success) { + try { + FileUtils.copyURLToFile(new URL(url), tmpFile) + success = true + } catch { + case e: Exception => retry -= 1 + } + } + } else { + success = true + } + if (!success) throw new Exception(s"$url Download failed!") + } +} diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala index 1270af3c45b4..9df2bcc0566d 100644 --- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala +++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala @@ -24,9 +24,7 @@ import org.kohsuke.args4j.{CmdLineParser, Option} import org.slf4j.LoggerFactory import scala.collection.JavaConverters._ - import org.apache.commons.io.FileUtils - import org.apache.mxnet.Symbol import org.apache.mxnet.DataIter import org.apache.mxnet.DataBatch @@ -37,13 +35,13 @@ import org.apache.mxnet.Context import org.apache.mxnet.Xavier import org.apache.mxnet.optimizer.RMSProp import org.apache.mxnet.Executor +import org.apache.mxnetexamples.Util import scala.collection.immutable.ListMap import scala.sys.process.Process /** * Example of multi-task - * @author Depeng Liang */ object ExampleMultiTask { private val logger = LoggerFactory.getLogger(classOf[ExampleMultiTask]) @@ -204,11 +202,8 @@ object ExampleMultiTask { val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci" val tempDirPath = System.getProperty("java.io.tmpdir") val modelDirPath = tempDirPath + File.separator + "multitask/" - val tmpFile = new File(tempDirPath + "/multitask/mnist.zip") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"), - tmpFile) - } + Util.downloadUrl(baseUrl + "/mnist/mnist.zip", + tempDirPath + "/multitask/mnist.zip") // TODO: Need to confirm with Windows Process("unzip " + tempDirPath + "/multitask/mnist.zip -d " diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala index f7d1332792fb..95c9823e3b28 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/cnntextclassification/CNNClassifierExampleSuite.scala @@ -22,6 +22,7 @@ import java.net.URL import org.apache.commons.io.FileUtils import org.apache.mxnet.Context +import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -46,22 +47,13 @@ class CNNClassifierExampleSuite extends FunSuite with BeforeAndAfterAll { logger.info("Downloading CNN text...") val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala" - var tmpFile = new File(tempDirPath + "/CNN/rt-polarity.pos") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/scala-example-ci/CNN/rt-polarity.pos"), - tmpFile) - } - tmpFile = new File(tempDirPath + "/CNN/rt-polarity.neg") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/scala-example-ci/CNN/rt-polarity.neg"), - tmpFile) - } + Util.downloadUrl(baseUrl + "/scala-example-ci/CNN/rt-polarity.pos", + tempDirPath + "/CNN/rt-polarity.pos") + Util.downloadUrl(baseUrl + "/scala-example-ci/CNN/rt-polarity.neg", + tempDirPath + "/CNN/rt-polarity.neg") logger.info("Downloading pretrianed Word2Vec Model, may take a while") - tmpFile = new File(tempDirPath + "/CNN/" + w2vModelName) - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/scala-example-ci/CNN/" + w2vModelName), - tmpFile) - } + Util.downloadUrl(baseUrl + "/scala-example-ci/CNN/" + w2vModelName, + tempDirPath + "/CNN/" + w2vModelName) val modelDirPath = tempDirPath + File.separator + "CNN" diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala index 4ba0e1bb87cb..6385e062a260 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/customop/CustomOpExampleSuite.scala @@ -21,6 +21,7 @@ import java.net.URL import org.apache.commons.io.FileUtils import org.apache.mxnet.Context +import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -64,11 +65,8 @@ class CustomOpExampleSuite extends FunSuite with BeforeAndAfterAll { val baseUrl = "https://s3.us-east-2.amazonaws.com/mxnet-scala/scala-example-ci" val tempDirPath = System.getProperty("java.io.tmpdir") val modelDirPath = tempDirPath + File.separator + "mnist/" - val tmpFile = new File(tempDirPath + "/mnist/mnist.zip") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"), - tmpFile) - } + Util.downloadUrl(baseUrl + "/mnist/mnist.zip", + tempDirPath + "/mnist/mnist.zip") // TODO: Need to confirm with Windows Process("unzip " + tempDirPath + "/mnist/mnist.zip -d " + tempDirPath + "/mnist/") ! diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala index 12459fb1cc19..8ab3a4b364a7 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/gan/GanExampleSuite.scala @@ -22,6 +22,7 @@ import java.net.URL import org.apache.commons.io.FileUtils import org.apache.mxnet.Context +import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -38,11 +39,8 @@ class GanExampleSuite extends FunSuite with BeforeAndAfterAll{ val tempDirPath = System.getProperty("java.io.tmpdir") val modelDirPath = tempDirPath + File.separator + "mnist/" logger.info("tempDirPath: %s".format(tempDirPath)) - val tmpFile = new File(tempDirPath + "/mnist/mnist.zip") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"), - tmpFile) - } + Util.downloadUrl(baseUrl + "/mnist/mnist.zip", + tempDirPath + "/mnist/mnist.zip") // TODO: Need to confirm with Windows Process("unzip " + tempDirPath + "/mnist/mnist.zip -d " + tempDirPath + "/mnist/") ! diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala index 3e91b5b0245d..7b1d6ddc38b5 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/imclassification/MNISTExampleSuite.scala @@ -22,6 +22,7 @@ import java.net.URL import org.apache.commons.io.FileUtils import org.apache.mxnet.Context +import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -41,11 +42,8 @@ class MNISTExampleSuite extends FunSuite with BeforeAndAfterAll { val tempDirPath = System.getProperty("java.io.tmpdir") val modelDirPath = tempDirPath + File.separator + "mnist/" logger.info("tempDirPath: %s".format(tempDirPath)) - val tmpFile = new File(tempDirPath + "/mnist/mnist.zip") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/mnist/mnist.zip"), - tmpFile) - } + Util.downloadUrl(baseUrl + "/mnist/mnist.zip", + tempDirPath + "/mnist/mnist.zip") // TODO: Need to confirm with Windows Process("unzip " + tempDirPath + "/mnist/mnist.zip -d " + tempDirPath + "/mnist/") ! diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala index 2b5ac7f8a2ae..f0bb07b4a398 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala @@ -24,6 +24,7 @@ import java.net.URL import org.apache.commons.io.FileUtils import org.apache.mxnet.Context +import org.apache.mxnetexamples.Util import sys.process.Process @@ -42,28 +43,14 @@ class ImageClassifierExampleSuite extends FunSuite with BeforeAndAfterAll { val baseUrl = "https://s3.us-east-2.amazonaws.com/scala-infer-models" - var tmpFile = new File(tempDirPath + "/resnet18/resnet-18-symbol.json") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/resnet-18/resnet-18-symbol.json"), - tmpFile) - } - tmpFile = new File(tempDirPath + "/resnet18/resnet-18-0000.params") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/resnet-18/resnet-18-0000.params"), - tmpFile) - } - tmpFile = new File(tempDirPath + "/resnet18/synset.txt") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(baseUrl + "/resnet-18/synset.txt"), - tmpFile) - } - tmpFile = new File(tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile( - new URL("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg"), - tmpFile - ) - } + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-symbol.json", + tempDirPath + "/resnet18/resnet-18-symbol.json") + Util.downloadUrl(baseUrl + "/resnet-18/resnet-18-0000.params", + tempDirPath + "/resnet18/resnet-18-0000.params") + Util.downloadUrl(baseUrl + "/resnet-18/synset.txt", + tempDirPath + "/resnet18/synset.txt") + Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg", + tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg") val modelDirPath = tempDirPath + File.separator + "resnet18/" val inputImagePath = tempDirPath + File.separator + diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala index 85b98381a433..31da38569281 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala @@ -22,6 +22,7 @@ import java.net.URL import org.apache.commons.io.FileUtils import org.apache.mxnet.Context +import org.apache.mxnetexamples.Util import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.slf4j.LoggerFactory @@ -39,27 +40,14 @@ class ObjectDetectorExampleSuite extends FunSuite with BeforeAndAfterAll { val modelBase = "https://s3.amazonaws.com/model-server/models/resnet50_ssd/" val imageBase = "https://s3.amazonaws.com/model-server/inputs/" - - var tmpFile = new File(tempDirPath + "/resnetssd/resnet50_ssd_model-symbol.json") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(modelBase + "resnet50_ssd_model-symbol.json"), - tmpFile) - } - tmpFile = new File(tempDirPath + "/resnetssd/resnet50_ssd_model-0000.params") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(modelBase + "resnet50_ssd_model-0000.params"), - tmpFile) - } - tmpFile = new File(tempDirPath + "/resnetssd/synset.txt") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(modelBase + "synset.txt"), - tmpFile) - } - tmpFile = new File(tempDirPath + "/inputImages/resnetssd/dog-ssd.jpg") - if (!tmpFile.exists()) { - FileUtils.copyURLToFile(new URL(imageBase + "dog-ssd.jpg"), - tmpFile) - } + Util.downloadUrl(modelBase + "resnet50_ssd_model-symbol.json", + tempDirPath + "/resnetssd/resnet50_ssd_model-symbol.json") + Util.downloadUrl(modelBase + "resnet50_ssd_model-0000.params", + tempDirPath + "/resnetssd/resnet50_ssd_model-0000.params") + Util.downloadUrl(modelBase + "synset.txt", + tempDirPath + "/resnetssd/synset.txt") + Util.downloadUrl(imageBase + "dog-ssd.jpg", + tempDirPath + "/inputImages/resnetssd/dog-ssd.jpg") val modelDirPath = tempDirPath + File.separator + "resnetssd/" val inputImagePath = tempDirPath + File.separator +