diff --git a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index f14289f984a2f..a079743c847ae 100644
--- a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -66,14 +66,19 @@ private[spark] class StreamingPythonRunner(
     envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
     envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
-    conf.set(PYTHON_USE_DAEMON, false)
     envVars.put("SPARK_CONNECT_LOCAL_URL", connectUrl)
 
-    val (worker, _) = env.createPythonWorker(
-      pythonExec, workerModule, envVars.asScala.toMap)
-    pythonWorker = Some(worker)
+    val prevConf = conf.get(PYTHON_USE_DAEMON)
+    conf.set(PYTHON_USE_DAEMON, false)
+    try {
+      val (worker, _) = env.createPythonWorker(
+        pythonExec, workerModule, envVars.asScala.toMap)
+      pythonWorker = Some(worker)
+    } finally {
+      conf.set(PYTHON_USE_DAEMON, prevConf)
+    }
 
-    val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+    val stream = new BufferedOutputStream(pythonWorker.get.getOutputStream, bufferSize)
     val dataOut = new DataOutputStream(stream)
 
     // TODO(SPARK-44461): verify python version
 
@@ -87,7 +92,8 @@ private[spark] class StreamingPythonRunner(
     dataOut.write(command.toArray)
     dataOut.flush()
 
-    val dataIn = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
+    val dataIn = new DataInputStream(
+      new BufferedInputStream(pythonWorker.get.getInputStream, bufferSize))
     val resFromPython = dataIn.readInt()
     logInfo(s"Runner initialization returned $resFromPython")
 