From 52fc982567a0c67f3855dfef352ae1bd10916e5b Mon Sep 17 00:00:00 2001 From: Ayres Date: Mon, 23 Jul 2018 16:42:43 -0700 Subject: [PATCH 1/2] Reduced test to 3 epochs and made gpu only --- .../multitask/MultiTaskSuite.scala | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala index dab977019097..a7dad3e6325d 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala @@ -44,21 +44,24 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * This will run as a part of "make scalatest" */ class MultiTaskSuite extends FunSuite { - test("Multitask Test") { - val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite]) - logger.info("Multitask Test...") + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + System.getenv("SCALA_TEST_ON_GPU").toInt == 1) { + val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite]) + logger.info("Multitask Test...") - val batchSize = 100 - val numEpoch = 10 - val ctx = Context.cpu() + val batchSize = 100 + val numEpoch = 3 + val ctx = Context.gpu() - val modelPath = ExampleMultiTask.getTrainingData - val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath) - evalMetric.get.foreach { case (name, value) => - assert(value >= 0.95f) + val modelPath = ExampleMultiTask.getTrainingData + val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath) + evalMetric.get.foreach { case (name, value) => + assert(value >= 0.95f) + } + executor.dispose() + } else { + logger.info("GPU test only, skipped...") } - executor.dispose() } - } From 97bf7c27e23ffd75271f76a930550473f802f6b8 Mon Sep 17 00:00:00 2001 From: Ayres Date: Mon, 23 Jul 2018 17:12:51 -0700 Subject: [PATCH 2/2] Moved logger variable so that it's accessible --- .../org/apache/mxnetexamples/multitask/MultiTaskSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala index a7dad3e6325d..b86f6751e45b 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala @@ -45,9 +45,9 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} */ class MultiTaskSuite extends FunSuite { test("Multitask Test") { + val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite]) if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && System.getenv("SCALA_TEST_ON_GPU").toInt == 1) { - val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite]) logger.info("Multitask Test...") val batchSize = 100