From c232ec63f80eea05d3756feb22e53aa5a1e67d93 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Mon, 27 Aug 2018 12:07:44 -0500
Subject: [PATCH 1/4] [SPARK-25253][PYSPARK] Refactor local connection & auth code

---
 .../spark/api/python/PythonRunner.scala |  3 +-
 python/pyspark/java_gateway.py          | 35 ++++++++++++++++++-
 python/pyspark/rdd.py                   | 27 ++------------
 python/pyspark/taskcontext.py           | 30 ++--------------
 python/pyspark/worker.py                |  7 ++--
 5 files changed, 41 insertions(+), 61 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 151c910bf1aee..a99e7f09455e5 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -206,6 +206,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         sock = serverSocket.get.accept()
         // Wait for function call from python side.
         sock.setSoTimeout(10000)
+        authHelper.authClient(sock)
         val input = new DataInputStream(sock.getInputStream())
         input.readInt() match {
           case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
@@ -324,8 +325,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     def barrierAndServe(sock: Socket): Unit = {
       require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
 
-      authHelper.authClient(sock)
-
       val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
       try {
         context.asInstanceOf[BarrierTaskContext].barrier()
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index fa2d5e8db716a..ace3b21dd413b 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -134,7 +134,7 @@ def killChild():
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
@@ -147,6 +147,39 @@ def do_server_auth(conn, auth_secret):
         raise Exception("Unexpected reply from iterator server.")
 
 
+def local_connect_and_auth(sock_info):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
+    Handles IPV4 & IPV6, does some error handling.
+    :param sock_info: a tuple of (port, auth_secret) for connecting
+    :return: a tuple with (sockfile, sock)
+    """
+    port, auth_secret = sock_info
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, canonname, sa = res
+        sock = socket.socket(af, socktype, proto)
+        try:
+            sock.settimeout(15)
+            sock.connect(sa)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+            continue
+        break
+    if not sock:
+        raise Exception("could not open socket: %s" % errors)
+
+    sockfile = sock.makefile("rwb", 65536)
+    _do_server_auth(sockfile, auth_secret)
+    return (sockfile, sock)
+
+
 def ensure_callback_server_started(gw):
     """
     Start callback server if not already started. The callback server is needed if the Java
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b061074a28ab4..fd14613fd40f3 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -141,33 +141,10 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    errors = []
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error as e:
-            emsg = _exception_message(e)
-            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket: %s" % errors)
+    (sockfile, sock) = local_connect_and_auth(sock_info)
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
-
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index c0312e5265c6e..168795709baaf 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -18,7 +18,7 @@
 from __future__ import print_function
 import socket
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import write_int, UTF8Deserializer
 
 
@@ -108,38 +108,12 @@ def _load_from_socket(port, auth_secret):
     """
     Load data from a given socket, this is a blocking method thus only return when the socket
     connection has been closed.
-
-    This is copied from context.py, while modified the message protocol.
     """
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            # Do not allow timeout for socket reading operation.
-            sock.settimeout(None)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
-
-    # We don't really need a socket file here, it's just for convenience that we can reuse the
-    # do_server_auth() function and data serialization methods.
-    sockfile = sock.makefile("rwb", 65536)
-
+    (sockfile, sock) = local_connect_and_auth((port, auth_secret))
     # Make a barrier() function call.
     write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
-
-    # Do server auth.
-    do_server_auth(sockfile, auth_secret)
-
     # Collect result.
     res = UTF8Deserializer().loads(sockfile)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d54a5b8e396ea..b642bc4e31981 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,7 +27,7 @@
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
@@ -364,8 +364,5 @@ def process():
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth((java_port, auth_secret))
     main(sock_file, sock_file)
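For context on what patch 1 consolidates: `_do_server_auth` speaks the client side of the JVM's SocketAuthHelper handshake, and moving `authHelper.authClient(sock)` to just after `accept()` in PythonRunner.scala means the JVM verifies the secret before reading any opcode. A minimal sketch of that handshake, assuming the 4-byte big-endian, length-prefixed UTF-8 framing used elsewhere in pyspark.serializers (write_with_length / UTF8Deserializer); this is an illustration, not the Spark source:

    import struct

    def sketch_server_auth(sockfile, auth_secret):
        # Send the shared secret, length-prefixed (framing assumed to be "!i").
        payload = auth_secret.encode("utf-8")
        sockfile.write(struct.pack("!i", len(payload)))
        sockfile.write(payload)
        sockfile.flush()
        # Read the length-prefixed UTF-8 reply; anything but "ok" is a failure.
        (n,) = struct.unpack("!i", sockfile.read(4))
        if sockfile.read(n).decode("utf-8") != "ok":
            raise Exception("Unexpected reply from iterator server.")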
From d07d21d83516492155207c06dfbde3a171412a68 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Tue, 28 Aug 2018 08:57:30 -0500
Subject: [PATCH 2/4] feedback

---
 python/pyspark/java_gateway.py | 8 ++++----
 python/pyspark/rdd.py          | 2 +-
 python/pyspark/taskcontext.py  | 2 +-
 python/pyspark/worker.py       | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index ace3b21dd413b..246a2853ccc56 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -147,20 +147,20 @@ def _do_server_auth(conn, auth_secret):
         raise Exception("Unexpected reply from iterator server.")
 
 
-def local_connect_and_auth(sock_info):
+def local_connect_and_auth(port, auth_secret):
     """
     Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
     Handles IPV4 & IPV6, does some error handling.
-    :param sock_info: a tuple of (port, auth_secret) for connecting
+    :param port
+    :param auth_secret
     :return: a tuple with (sockfile, sock)
     """
-    port, auth_secret = sock_info
     sock = None
     errors = []
     # Support for both IPv4 and IPv6.
     # On most of IPv6-ready systems, IPv6 will take precedence.
     for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
+        af, socktype, proto, _, sa = res
         sock = socket.socket(af, socktype, proto)
         try:
             sock.settimeout(15)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index fd14613fd40f3..380475e706fbe 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -141,7 +141,7 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    (sockfile, sock) = local_connect_and_auth(sock_info)
+    (sockfile, sock) = local_connect_and_auth(*sock_info)
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 168795709baaf..814142494e4c6 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -109,7 +109,7 @@ def _load_from_socket(port, auth_secret):
     Load data from a given socket, this is a blocking method thus only return when the socket
     connection has been closed.
     """
-    (sockfile, sock) = local_connect_and_auth((port, auth_secret))
+    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
     # Make a barrier() function call.
     write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index b642bc4e31981..fcca8708a232b 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -364,5 +364,5 @@ def process():
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth((java_port, auth_secret))
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
     main(sock_file, sock_file)
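Patch 2 turns the helper's single tuple parameter into explicit (port, auth_secret) arguments, so callers that still hold the tuple handed over by the JVM, like `_load_from_socket` in rdd.py, splat it with `*sock_info`. The dual-stack loop itself is untouched; a runnable illustration of the candidate list it walks (the port number below is made up and irrelevant to getaddrinfo):

    import socket

    # The candidates local_connect_and_auth iterates over. A literal
    # "127.0.0.1" normally yields a single IPv4 entry, while a hostname
    # such as "localhost" may yield IPv6 and IPv4 entries, tried in order.
    for af, socktype, proto, _, sa in socket.getaddrinfo(
            "127.0.0.1", 56789, socket.AF_UNSPEC, socket.SOCK_STREAM):
        print(af, socktype, proto, sa)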
From b0c4483e9506d8af6f12b6d41848ece52ffe9bc4 Mon Sep 17 00:00:00 2001
From: Imran Rashid
Date: Tue, 28 Aug 2018 12:10:09 -0500
Subject: [PATCH 3/4] feedback

---
 python/pyspark/taskcontext.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 814142494e4c6..53fc2b29e066f 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -110,6 +110,8 @@ def _load_from_socket(port, auth_secret):
     connection has been closed.
     """
     (sockfile, sock) = local_connect_and_auth(port, auth_secret)
+    # The barrier() call may block forever, so no timeout
+    sock.settimeout(None)
     # Make a barrier() function call.
     write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
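The `sock.settimeout(None)` added in patch 3 matters because the read that follows the BARRIER_FUNCTION request only returns once every task in the stage has reached the barrier, which can take arbitrarily long. A sketch of the request framing, assuming write_int packs a 4-byte big-endian signed int; the opcode value shown is hypothetical:

    import struct

    BARRIER_FUNCTION = 1  # hypothetical opcode value, for illustration only

    def sketch_write_int(value, stream):
        # pyspark.serializers.write_int is assumed to use "!i" framing.
        stream.write(struct.pack("!i", value))

    # After flushing the opcode, the client blocks reading the reply;
    # clearing the socket timeout keeps that read from timing out.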
""" - (sockfile, sock) = local_connect_and_auth((port, auth_secret)) + (sockfile, sock) = local_connect_and_auth(port, auth_secret) # Make a barrier() function call. write_int(BARRIER_FUNCTION, sockfile) sockfile.flush() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b642bc4e31981..fcca8708a232b 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -364,5 +364,5 @@ def process(): # Read information about how to connect back to the JVM from the environment. java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth((java_port, auth_secret)) + (sock_file, _) = local_connect_and_auth(java_port, auth_secret) main(sock_file, sock_file) From b0c4483e9506d8af6f12b6d41848ece52ffe9bc4 Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 28 Aug 2018 12:10:09 -0500 Subject: [PATCH 3/4] feedback --- python/pyspark/taskcontext.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py index 814142494e4c6..53fc2b29e066f 100644 --- a/python/pyspark/taskcontext.py +++ b/python/pyspark/taskcontext.py @@ -110,6 +110,8 @@ def _load_from_socket(port, auth_secret): connection has been closed. """ (sockfile, sock) = local_connect_and_auth(port, auth_secret) + # The barrier() call may block forever, so no timeout + sock.settimeout(None) # Make a barrier() function call. write_int(BARRIER_FUNCTION, sockfile) sockfile.flush() From 65ed777d11f1047bdb371259dc44969dad6ca0ac Mon Sep 17 00:00:00 2001 From: Imran Rashid Date: Tue, 28 Aug 2018 15:54:25 -0500 Subject: [PATCH 4/4] feedback --- python/pyspark/java_gateway.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 246a2853ccc56..b06503b53be90 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -161,24 +161,21 @@ def local_connect_and_auth(port, auth_secret): # On most of IPv6-ready systems, IPv6 will take precedence. for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM): af, socktype, proto, _, sa = res - sock = socket.socket(af, socktype, proto) try: + sock = socket.socket(af, socktype, proto) sock.settimeout(15) sock.connect(sa) + sockfile = sock.makefile("rwb", 65536) + _do_server_auth(sockfile, auth_secret) + return (sockfile, sock) except socket.error as e: emsg = _exception_message(e) errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg)) sock.close() sock = None - continue - break - if not sock: + else: raise Exception("could not open socket: %s" % errors) - sockfile = sock.makefile("rwb", 65536) - _do_server_auth(sockfile, auth_secret) - return (sockfile, sock) - def ensure_callback_server_started(gw): """