diff --git a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala index dab977019097..b86f6751e45b 100644 --- a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala +++ b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/multitask/MultiTaskSuite.scala @@ -44,21 +44,24 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer} * This will run as a part of "make scalatest" */ class MultiTaskSuite extends FunSuite { - test("Multitask Test") { val logger = LoggerFactory.getLogger(classOf[MultiTaskSuite]) - logger.info("Multitask Test...") + if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && + System.getenv("SCALA_TEST_ON_GPU").toInt == 1) { + logger.info("Multitask Test...") - val batchSize = 100 - val numEpoch = 10 - val ctx = Context.cpu() + val batchSize = 100 + val numEpoch = 3 + val ctx = Context.gpu() - val modelPath = ExampleMultiTask.getTrainingData - val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath) - evalMetric.get.foreach { case (name, value) => - assert(value >= 0.95f) + val modelPath = ExampleMultiTask.getTrainingData + val (executor, evalMetric) = ExampleMultiTask.train(batchSize, numEpoch, ctx, modelPath) + evalMetric.get.foreach { case (name, value) => + assert(value >= 0.95f) + } + executor.dispose() + } else { + logger.info("GPU test only, skipped...") } - executor.dispose() } - }