From 17cf97fc80c431071c1538e97490e83613020fdc Mon Sep 17 00:00:00 2001 From: vicennial Date: Thu, 29 Jun 2023 18:25:33 +0200 Subject: [PATCH 1/2] init --- .../connect/planner/SparkConnectPlanner.scala | 13 +-- .../sql/connect/service/SessionHolder.scala | 45 +------- .../service/SparkConnectAnalyzeHandler.scala | 2 +- .../artifact/ArtifactManagerSuite.scala | 21 +++- core/src/test/resources/TestHelloV2.jar | Bin 0 -> 3784 bytes core/src/test/resources/TestHelloV3.jar | Bin 0 -> 3595 bytes .../executor/ClassLoaderIsolationSuite.scala | 102 ++++++++++++++++++ 7 files changed, 127 insertions(+), 56 deletions(-) create mode 100644 core/src/test/resources/TestHelloV2.jar create mode 100644 core/src/test/resources/TestHelloV3.jar create mode 100644 core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index cecf14a70451a..ace981c3d826e 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -26,8 +26,6 @@ import com.google.protobuf.{Any => ProtoAny, ByteString} import io.grpc.{Context, Status, StatusRuntimeException} import io.grpc.stub.StreamObserver import org.apache.commons.lang3.exception.ExceptionUtils -import org.json4s._ -import org.json4s.jackson.JsonMethods.parse import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} @@ -91,15 +89,6 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) - // SparkConnectPlanner is used per request. - private lazy val pythonIncludes = { - implicit val formats = DefaultFormats - parse(session.conf.get("spark.connect.pythonUDF.includes", "[]")) - .extract[Array[String]] - .toList - .asJava - } - // The root of the query plan is a relation and we apply the transformations to it. 
def transformRelation(rel: proto.Relation): LogicalPlan = { val plan = rel.getRelTypeCase match { @@ -1519,7 +1508,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { command = fun.getCommand.toByteArray, // Empty environment variables envVars = Maps.newHashMap(), - pythonIncludes = pythonIncludes, + pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava, pythonExec = pythonExec, pythonVer = fun.getPythonVer, // Empty broadcast variables diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 004322097790f..0152f980f15d9 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -24,9 +24,6 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import scala.collection.JavaConverters._ import scala.util.control.NonFatal -import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render} - import org.apache.spark.JobArtifactSet import org.apache.spark.connect.proto import org.apache.spark.internal.Logging @@ -107,7 +104,7 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio * @param f * @tparam T */ - def withContext[T](f: => T): T = { + def withContextClassLoader[T](f: => T): T = { // Needed for deserializing and evaluating the UDF on the driver Utils.withContextClassLoader(classloader) { // Needed for propagating the dependencies to the executors. @@ -117,49 +114,15 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio } } - /** - * Set the session-based Python paths to include in Python UDF. - * @param f - * @tparam T - */ - def withSessionBasedPythonPaths[T](f: => T): T = { - try { - session.conf.set( - "spark.connect.pythonUDF.includes", - compact(render(artifactManager.getSparkConnectPythonIncludes))) - f - } finally { - session.conf.unset("spark.connect.pythonUDF.includes") - } - } - /** * Execute a block of code with this session as the active SparkConnect session. * @param f * @tparam T */ def withSession[T](f: SparkSession => T): T = { - withSessionBasedPythonPaths { - withContext { - session.withActive { - f(session) - } - } - } - } - - /** - * Execute a block of code using the session from this [[SessionHolder]] as the active - * SparkConnect session. 
- * @param f - * @tparam T - */ - def withSessionHolder[T](f: SessionHolder => T): T = { - withSessionBasedPythonPaths { - withContext { - session.withActive { - f(this) - } + withContextClassLoader { + session.withActive { + f(session) } } } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 5c069bfaf5d0e..414a852380fd2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -38,7 +38,7 @@ private[connect] class SparkConnectAnalyzeHandler( request.getSessionId) // `withSession` ensures that session-specific artifacts (such as JARs and class files) are // available during processing (such as deserialization). - sessionHolder.withSessionHolder { sessionHolder => + sessionHolder.withSession { _ => val response = process(request, sessionHolder) responseObserver.onNext(response) responseObserver.onCompleted() diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala index 42ab8ca18f6e9..5baa2240c45ad 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala @@ -224,6 +224,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { test("Classloaders for spark sessions are isolated") { val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", "session1") val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", "session2") + val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", "session3") def addHelloClass(holder: SessionHolder): Unit = { val copyDir = Utils.createTempDir().toPath @@ -234,7 +235,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { holder.addArtifact(remotePath, stagingPath, None) } - // Add the classfile only for the first user + // Add the "Hello" classfile for the first user addHelloClass(holder1) val classLoader1 = holder1.classloader @@ -246,7 +247,8 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { val udf1 = org.apache.spark.sql.functions.udf(instance1) holder1.withSession { session => - session.range(10).select(udf1(col("id").cast("string"))).collect() + val result = session.range(10).select(udf1(col("id").cast("string"))).collect() + assert(result.forall(_.getString(0).contains("Talon"))) } assertThrows[ClassNotFoundException] { @@ -257,6 +259,21 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper { .newInstance("Talon") .asInstanceOf[String => String] } + + // Add the "Hello" classfile for the third user + addHelloClass(holder3) + val instance3 = holder3 + .classloader + .loadClass("Hello") + .getDeclaredConstructor(classOf[String]) + .newInstance("Ahri") + .asInstanceOf[String => String] + val udf3 = org.apache.spark.sql.functions.udf(instance3) + + holder3.withSession { session => + val result = session.range(10).select(udf3(col("id").cast("string"))).collect() + assert(result.forall(_.getString(0).contains("Ahri"))) + } } } diff 
--git a/core/src/test/resources/TestHelloV2.jar b/core/src/test/resources/TestHelloV2.jar
new file mode 100644
index 0000000000000000000000000000000000000000..d89cf6543a20090abd953af39ac97608aa02050e
GIT binary patch
literal 3784
[3784 bytes of base85-encoded binary data omitted]
diff --git a/core/src/test/resources/TestHelloV3.jar b/core/src/test/resources/TestHelloV3.jar
new file mode 100644
index 0000000000000000000000000000000000000000..b175a6c8640793b36dd987a80fc39dd4aa0d3db0
GIT binary patch
literal 3595
[3595 bytes of base85-encoded binary data omitted]

diff --git a/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
new file mode 100644
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/executor/ClassLoaderIsolationSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{JobArtifactSet, LocalSparkContext, SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.util.Utils
+
+class ClassLoaderIsolationSuite extends SparkFunSuite with LocalSparkContext {
+
+  val jar1 = Thread.currentThread().getContextClassLoader.getResource("TestUDTF.jar").getPath
+  val jar2 = Thread.currentThread().getContextClassLoader
+    .getResource("TestHelloV2.jar").getPath
+  val jar3 = Thread.currentThread().getContextClassLoader
+    .getResource("TestHelloV3.jar").getPath
+
+  test("Executor classloader isolation with JobArtifactSet") {
+    sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local"))
+    sc.addJar(jar1)
+    sc.addJar(jar2)
+    sc.addJar(jar3)
+
+    // TestHelloV2's test method returns '2'
+    val artifactSetWithHelloV2 = new JobArtifactSet(
+      uuid = Some("hello2"),
+      replClassDirUri = None,
+      jars = Map(jar2 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithHelloV2) {
+      sc.parallelize(1 to 1).foreach { i =>
+        val cls = Utils.classForName("com.example.Hello$")
+        val module = cls.getField("MODULE$").get(null)
+        val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+        if (result != 2) {
+          throw new RuntimeException("Unexpected result: " + result)
+        }
+      }
+    }
+
+    // TestHelloV3's test method returns '3'
+    val artifactSetWithHelloV3 = new JobArtifactSet(
+      uuid = Some("hello3"),
+      replClassDirUri = None,
+      jars = Map(jar3 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithHelloV3) {
+      sc.parallelize(1 to 1).foreach { i =>
+        val cls = Utils.classForName("com.example.Hello$")
+        val module = cls.getField("MODULE$").get(null)
+        val result = cls.getMethod("test").invoke(module).asInstanceOf[Int]
+        if (result != 3) {
+          throw new RuntimeException("Unexpected result: " + result)
+        }
+      }
+    }
+
+    // Should not be able to see the "Hello" class if its jar is excluded from the artifact set
+    val artifactSetWithoutHello = new JobArtifactSet(
+      uuid = Some("Jar 1"),
+      replClassDirUri = None,
+      jars = Map(jar1 -> 1L),
+      files = Map.empty,
+      archives = Map.empty
+    )
+
+    JobArtifactSet.withActive(artifactSetWithoutHello) {
+      sc.parallelize(1 to 1).foreach { i =>
+        try {
+          Utils.classForName("com.example.Hello$")
+          throw new RuntimeException("Class loading should have failed")
+        } catch {
+          case _: ClassNotFoundException =>
+        }
+      }
+    }
+  }
+}

From 4f759470978527bea85671951135c910677d5f1c Mon Sep 17 00:00:00 2001
From: vicennial
Date: Fri, 30 Jun 2023 10:54:10 +0200
Subject: [PATCH 2/2] lint

---
 .../spark/sql/connect/artifact/ArtifactManagerSuite.scala | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 5baa2240c45ad..612bf096b22bd 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -262,8 +262,7 @@ class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
 
     // Add the "Hello" classfile for the third user
     addHelloClass(holder3)
-    val instance3 = holder3
-      .classloader
+    val instance3 = holder3.classloader
       .loadClass("Hello")
       .getDeclaredConstructor(classOf[String])
       .newInstance("Ahri")
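
As a companion to the suite in PATCH 1/2, the sketch below condenses the isolation pattern it exercises into one self-contained example: tasks launched while a JobArtifactSet is active resolve classes only against the jars named in that set. This is illustrative only, not part of the patch: the jar path /tmp/hello-v2.jar and the com.example.Hello class are hypothetical stand-ins, and the object is declared under org.apache.spark.executor because JobArtifactSet and Utils are Spark-internal APIs (the test suite sits in the same package for the same reason).

package org.apache.spark.executor

import org.apache.spark.{JobArtifactSet, SparkConf, SparkContext}
import org.apache.spark.util.Utils

// Minimal sketch of per-job classloader scoping, assuming a jar at the
// hypothetical path /tmp/hello-v2.jar that contains com.example.Hello.
object JobArtifactIsolationSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("demo").setMaster("local"))
    val helloJar = "/tmp/hello-v2.jar" // hypothetical; the suite resolves its jars from the test classpath
    sc.addJar(helloJar)

    val artifacts = new JobArtifactSet(
      uuid = Some("demo"),        // identifies this set's isolated state on the executor
      replClassDirUri = None,     // no REPL-generated class directory in this sketch
      jars = Map(helloJar -> 1L), // path -> add timestamp
      files = Map.empty,
      archives = Map.empty
    )

    JobArtifactSet.withActive(artifacts) {
      sc.parallelize(1 to 1).foreach { _ =>
        // Resolved through the executor's isolated classloader for this set;
        // classes from jars outside the set raise ClassNotFoundException.
        Utils.classForName("com.example.Hello$")
      }
    }
    sc.stop()
  }
}

Roughly, the active set is captured when a job is submitted and shipped with its tasks, so each job sees only its own artifacts even when many Spark Connect sessions share one SparkContext. That per-job scoping is what lets SessionHolder.withSession drop the spark.connect.pythonUDF.includes session-conf plumbing removed in PATCH 1/2 and read Python includes directly from the session's ArtifactManager instead.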