From 2e0b308ef8e4829171f97298fb340e060c410747 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Thu, 20 Jun 2019 18:15:17 +0100 Subject: [PATCH 01/10] initial commit of cogroup --- .../spark/api/python/PythonRunner.scala | 2 + python/pyspark/rdd.py | 1 + python/pyspark/serializers.py | 34 ++++ python/pyspark/sql/cogroup.py | 15 ++ python/pyspark/sql/functions.py | 3 + python/pyspark/sql/group.py | 4 + .../tests/test_pandas_udf_cogrouped_map.py | 113 +++++++++++ python/pyspark/worker.py | 39 +++- .../logical/pythonLogicalOperators.scala | 12 ++ .../spark/sql/RelationalGroupedDataset.scala | 31 ++- .../spark/sql/execution/SparkStrategies.scala | 3 + .../python/FlatMapCoGroupsInPandasExec.scala | 93 +++++++++ .../python/FlatMapGroupsInPandasExec.scala | 16 +- .../python/InterleavedArrowWriter.scala | 82 ++++++++ .../python/MultiDfArrowPythonRunner.scala | 191 ++++++++++++++++++ 15 files changed, 631 insertions(+), 8 deletions(-) create mode 100644 python/pyspark/sql/cogroup.py create mode 100644 python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/MultiDfArrowPythonRunner.scala diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 4dcc5eb5fbfcd..fe28ec6fc1c29 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -46,6 +46,7 @@ private[spark] object PythonEvalType { val SQL_GROUPED_MAP_PANDAS_UDF = 201 val SQL_GROUPED_AGG_PANDAS_UDF = 202 val SQL_WINDOW_AGG_PANDAS_UDF = 203 + val SQL_COGROUPED_MAP_PANDAS_UDF = 204 def toString(pythonEvalType: Int): String = pythonEvalType match { case NON_UDF => "NON_UDF" @@ -54,6 +55,7 @@ private[spark] object PythonEvalType { case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" + case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" } } diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f0682e71a1780..f198316d0c4a5 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -73,6 +73,7 @@ class PythonEvalType(object): SQL_GROUPED_MAP_PANDAS_UDF = 201 SQL_GROUPED_AGG_PANDAS_UDF = 202 SQL_WINDOW_AGG_PANDAS_UDF = 203 + SQL_COGROUPED_MAP_PANDAS_UDF = 204 def portable_hash(x): diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 516ee7e7b3084..e818241aaeebf 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -359,6 +359,24 @@ def __repr__(self): return "ArrowStreamPandasSerializer" +class InterleavedArrowReader(object): + + def __init__(self, stream): + import pyarrow as pa + self._schema1 = pa.read_schema(stream) + self._schema2 = pa.read_schema(stream) + self._reader = pa.MessageReader.open_stream(stream) + + def __iter__(self): + return self + + def __next__(self): + import pyarrow as pa + batch1 = pa.read_record_batch(self._reader.read_next_message(), self._schema1) + batch2 = pa.read_record_batch(self._reader.read_next_message(), self._schema2) + return batch1, batch2 + + class 
ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): """ Serializer used by Python worker to evaluate Pandas UDFs @@ -404,6 +422,22 @@ def __repr__(self): return "ArrowStreamPandasUDFSerializer" +class InterleavedArrowStreamPandasSerializer(ArrowStreamPandasUDFSerializer): + + def __init__(self, timezone, safecheck, assign_cols_by_name): + super(InterleavedArrowStreamPandasSerializer, self).__init__(timezone, safecheck, assign_cols_by_name) + + def load_stream(self, stream): + """ + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. + """ + import pyarrow as pa + reader = InterleavedArrowReader(pa.input_stream(stream)) + for batch1, batch2 in reader: + yield ( [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch1]).itercolumns()], + [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch2]).itercolumns()]) + + class BatchedSerializer(Serializer): """ diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py new file mode 100644 index 0000000000000..8946ab1202651 --- /dev/null +++ b/python/pyspark/sql/cogroup.py @@ -0,0 +1,15 @@ +from pyspark.sql.dataframe import DataFrame + + +class CoGroupedData(object): + + def __init__(self, gd1, gd2): + self._gd1 = gd1 + self._gd2 = gd2 + self.sql_ctx = gd1.sql_ctx + + def apply(self, udf): + df = self._gd1._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 613822b7edf2d..f156b9e9bd984 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2798,6 +2798,8 @@ class PandasUDFType(object): GROUPED_MAP = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF + COGROUPED_MAP = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF + GROUPED_AGG = PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF @@ -3179,6 +3181,7 @@ def pandas_udf(f=None, returnType=None, functionType=None): if eval_type not in [PythonEvalType.SQL_SCALAR_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF]: raise ValueError("Invalid functionType: " "functionType must be one the values from PandasUDFType") diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index cc1da8e7c1f72..04f42b1598376 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -22,6 +22,7 @@ from pyspark.sql.column import Column, _to_seq from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import * +from pyspark.sql.cogroup import CoGroupedData __all__ = ["GroupedData"] @@ -220,6 +221,9 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col, values) return GroupedData(jgd, self._df) + def cogroup(self, other): + return CoGroupedData(self, other) + @since(2.3) def apply(self, udf): """ diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py new file mode 100644 index 0000000000000..e98b66c64eaae --- /dev/null +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -0,0 +1,113 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import datetime +import unittest +import sys + +from collections import OrderedDict +from decimal import Decimal + +from pyspark.sql import Row +from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType +from pyspark.sql.types import * +from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \ + pandas_requirement_message, pyarrow_requirement_message +from pyspark.testing.utils import QuietTest + +if have_pandas: + import pandas as pd + from pandas.util.testing import assert_frame_equal + +if have_pyarrow: + import pyarrow as pa + + +""" +Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names +from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check +""" +if sys.version < '3': + _check_column_type = False +else: + _check_column_type = True + + +@unittest.skipIf( + not have_pandas or not have_pyarrow, + pandas_requirement_message or pyarrow_requirement_message) +class CoGroupedMapPandasUDFTests(ReusedSQLTestCase): + + @property + def data(self): + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))).drop('vs') + + def test_supported_types(self): + + df1 = self.spark.createDataFrame( + pd.DataFrame.from_dict({ + 'id' : [1,1,10, 10, 1,1], + 'x' : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + })) + + df2 = self.spark.createDataFrame( + pd.DataFrame.from_dict({ + 'id2': [1,1,10, 10, 1,1], + 'a': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + })) + + output_schema = StructType([ + StructField("id", LongType()), + StructField("x", DoubleType()), + ]) + + @pandas_udf(output_schema, functionType=PandasUDFType.COGROUPED_MAP) + def foo(left, right): + print("hello") + print(left) + print("goodbye") + print(right) + return left + + output_schema2 = StructType([ + StructField("id", LongType()) + ]) + @pandas_udf(output_schema, functionType=PandasUDFType.GROUPED_MAP) + def foo2(key, df): + print('key is ' + str(key)) + print(df) + return df + + + df1.groupby(col("id") > 5)\ + .apply(foo2)\ + .show() + + + + +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_cogrouped_map import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 16257bef6b320..2ccc28cffce78 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -38,7 +38,7 @@ from pyspark.rdd import PythonEvalType from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowStreamPandasUDFSerializer + BatchedSerializer, ArrowStreamPandasUDFSerializer, InterleavedArrowStreamPandasSerializer from pyspark.sql.types import 
to_arrow_type, StructType from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle @@ -103,8 +103,25 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type, argspec): +def wrap_cogrouped_map_pandas_udf(f, return_type): + def wrapped(left, right): + import pandas as pd + result = f(pd.concat(left, axis=1), pd.concat(right, axis=1)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(return_type): + raise RuntimeError( + "Number of columns of the returned pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(return_type), len(result.columns))) + return result + + return lambda v: [(wrapped(v[0], v[1]), to_arrow_type(return_type))] + + +def wrap_grouped_map_pandas_udf(f, return_type, argspec): def wrapped(key_series, value_series): import pandas as pd @@ -233,6 +250,7 @@ def read_udfs(pickleSer, infile, eval_type): runner_conf = {} if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF): @@ -255,9 +273,12 @@ def read_udfs(pickleSer, infile, eval_type): # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. - df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF - ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, - df_for_struct) + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + ser = InterleavedArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name) + else: + df_for_struct = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name, + df_for_struct) else: ser = BatchedSerializer(PickleSerializer(), 100) @@ -282,6 +303,14 @@ def read_udfs(pickleSer, infile, eval_type): arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + # We assume there is only one UDF here because cogrouped map doesn't + # support combining multiple UDFs. 
+ assert num_udfs == 1 + arg_offsets, udf = read_single_udf( + pickleSer, infile, eval_type, runner_conf, udf_index=0) + udfs['f'] = udf + mapper_str = "lambda a: f(a)" else: # Create function like this: # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 2df30a1a53ad7..35adbae423f25 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -39,6 +39,18 @@ case class FlatMapGroupsInPandas( override val producedAttributes = AttributeSet(output) } + +case class FlatMapCoGroupsInPandas( + leftAttributes: Seq[Attribute], + rightAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode { + override val producedAttributes = AttributeSet(output) +} + + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index e85636d82a62c..147cc00c0ba91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType} */ @Stable class RelationalGroupedDataset protected[sql]( - df: DataFrame, - groupingExprs: Seq[Expression], + private val df: DataFrame, + private val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @@ -523,6 +523,33 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + private[sql] def flatMapCoGroupsInPandas + (r: RelationalGroupedDataset, expr: PythonUDF): DataFrame = { + require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + "Must pass a cogrouped map udf") + require(expr.dataType.isInstanceOf[StructType], + s"The returnType of the udf must be a ${StructType.simpleString}") + + val leftGroupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val rightGroupingNamedExpressions = r.groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + + val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute) + val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute) + val left = df.logicalPlan + val right = r.df.logicalPlan + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right) + + Dataset.ofRows(df.sparkSession, plan) + } + override def toString: String = { val builder = new StringBuilder builder.append("RelationalGroupedDataset: [grouping expressions: [") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c4031496f610f..965f04c058966 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -679,6 +679,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { f, p, b, is, ot, planLater(child)) :: Nil case logical.FlatMapGroupsInPandas(grouping, func, output, child) => execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil + case logical.FlatMapCoGroupsInPandas(leftGroup, rightGroup, func, output, left, right) => + execution.python.FlatMapCoGroupsInPandasExec( + leftGroup, rightGroup, func, output, planLater(left), planLater(right)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala new file mode 100644 index 0000000000000..2dc517354812a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, GroupedIterator, SparkPlan} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +case class FlatMapCoGroupsInPandasExec( + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + left: SparkPlan, + right: SparkPlan) + extends BinaryExecNode { + + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + leftGroup + .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + } + + + override protected def doExecute(): RDD[InternalRow] = { + + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val sessionLocalTimeZone = conf.sessionLocalTimeZone + val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) + val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) + val cogroup = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) + .map{case (k, l, r) => (l, r)} + val context = TaskContext.get() + val columnarBatchIter = new InterleavedArrowPythonRunner( + chainedFunc, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + Array(Array.empty), + left.schema, + right.schema, + sessionLocalTimeZone, + pythonRunnerConf).compute(cogroup, context.partitionId(), context) + + + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + } + + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 7b0e014f9ca48..46db29a250520 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -90,10 +90,14 @@ case class FlatMapGroupsInPandasExec( // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes // argOffsets[argOffsets[0]+1 .. 
] is the arg offsets for data attributes + println("grouping attributes are " + groupingAttributes.mkString(",")) + println("data attributes are " + child.output.mkString(",")) val dataAttributes = child.output.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } + println("dataAttributes attributes are " + dataAttributes.mkString(",")) + println("groupingIndicesInData are " + groupingIndicesInData.mkString(",")) val groupingArgOffsets = new ArrayBuffer[Int] val nonDupGroupingAttributes = new ArrayBuffer[Attribute] @@ -116,15 +120,25 @@ case class FlatMapGroupsInPandasExec( } } + println("nonDupGroupingAttributes.length: " + nonDupGroupingAttributes.length) + println("nonDupGroupingAttributes.length: " + nonDupGroupingAttributes.length) + println("dataAttributes.length: " + dataAttributes.length) + val dataArgOffsets = nonDupGroupingAttributes.length until (nonDupGroupingAttributes.length + dataAttributes.length) + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + println("numAttributes are " + groupingAttributes.length) + println("groupingArgOffsets are " + groupingArgOffsets.mkString(",")) + println("dataArgOffsets are " + dataArgOffsets.mkString(",")) + + // Attributes after deduplication val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes val dedupSchema = StructType.fromAttributes(dedupAttributes) - + println("dedupSchema is " + dedupSchema) inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala new file mode 100644 index 0000000000000..f08763532eb23 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io.OutputStream +import java.nio.channels.Channels + +import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader} +import org.apache.arrow.vector.ipc.WriteChannel +import org.apache.arrow.vector.ipc.message.MessageSerializer + + +class InterleavedArrowWriter( leftRoot: VectorSchemaRoot, + rightRoot: VectorSchemaRoot, + out: WriteChannel) extends AutoCloseable{ + + + private var started = false + private val leftUnloader = new VectorUnloader(leftRoot) + private val rightUnloader = new VectorUnloader(rightRoot) + + def start(): Unit = { + this.ensureStarted() + } + + def writeBatch(): Unit = { + this.ensureStarted() + val leftBatch = leftUnloader.getRecordBatch + val rightBatch = rightUnloader.getRecordBatch + MessageSerializer.serialize(out, leftBatch) + MessageSerializer.serialize(out, rightBatch) + leftBatch.close() + rightBatch.close() + } + + private def ensureStarted(): Unit = { + if (!started) { + started = true + MessageSerializer.serialize(out, leftRoot.getSchema) + MessageSerializer.serialize(out, rightRoot.getSchema) + } + } + + def end(): Unit = { + ensureStarted() + ensureEnded() + } + + def ensureEnded(): Unit = { + out.writeIntLittleEndian(0) + } + + def close(): Unit = { + out.close() + } + +} + +object InterleavedArrowWriter{ + + def apply(leftRoot: VectorSchemaRoot, + rightRoot: VectorSchemaRoot, + out: OutputStream): InterleavedArrowWriter = { + new InterleavedArrowWriter(leftRoot, rightRoot, new WriteChannel(Channels.newChannel(out))) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MultiDfArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MultiDfArrowPythonRunner.scala new file mode 100644 index 0000000000000..f40ae92ac284e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MultiDfArrowPythonRunner.scala @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.util.Utils + + +class InterleavedArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + leftSchema: StructType, + rightSchema: StructType, + timeZoneId: String, + conf: Map[String, String]) + extends BasePythonRunner[(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch]( + funcs, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[(Iterator[InternalRow], Iterator[InternalRow])], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + + // Write config for the worker as a number of key -> value pairs of strings + dataOut.writeInt(conf.size) + for ((k, v) <- conf) { + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val leftArrowSchema = ArrowUtils.toArrowSchema(leftSchema, timeZoneId) + val rightArrowSchema = ArrowUtils.toArrowSchema(rightSchema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + val leftRoot = VectorSchemaRoot.create(leftArrowSchema, allocator) + val rightRoot = VectorSchemaRoot.create(rightArrowSchema, allocator) + + Utils.tryWithSafeFinally { + val leftArrowWriter = ArrowWriter.create(leftRoot) + val rightArrowWriter = ArrowWriter.create(rightRoot) + val writer = InterleavedArrowWriter(leftRoot, rightRoot, dataOut) + writer.start() + + while (inputIterator.hasNext) { + + val (nextLeft, nextRight) = inputIterator.next() + + while (nextLeft.hasNext) { + leftArrowWriter.write(nextLeft.next()) + } + while (nextRight.hasNext) { + rightArrowWriter.write(nextRight.next()) + } + leftArrowWriter.finish() + rightArrowWriter.finish() + writer.writeBatch() + leftArrowWriter.reset() + rightArrowWriter.reset() + } + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. 
+ // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. + leftRoot.close() + rightRoot.close() + allocator.close() + } + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + batch + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} From 64ff5acc55392fb2a6fb3bd320820f2800769503 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Thu, 20 Jun 2019 20:45:59 +0100 Subject: [PATCH 02/10] minor tidy up --- .../tests/test_pandas_udf_cogrouped_map.py | 83 ++++++------- python/pyspark/worker.py | 2 + .../execution/python/ArrowPythonRunner.scala | 70 +---------- .../python/BaseArrowPythonRunner.scala | 114 ++++++++++++++++++ .../python/FlatMapCoGroupsInPandasExec.scala | 4 + .../python/FlatMapGroupsInPandasExec.scala | 16 +-- ...ala => InterleavedArrowPythonRunner.scala} | 77 +----------- .../python/InterleavedArrowWriter.scala | 17 +-- .../apache/spark/sql/GroupedDataTest.scala | 33 +++++ 9 files changed, 203 insertions(+), 213 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/python/{MultiDfArrowPythonRunner.scala => InterleavedArrowPythonRunner.scala} (62%) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py index e98b66c64eaae..508825123df07 100644 --- 
a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -52,62 +52,51 @@ pandas_requirement_message or pyarrow_requirement_message) class CoGroupedMapPandasUDFTests(ReusedSQLTestCase): - @property - def data(self): - return self.spark.range(10).toDF('id') \ - .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ - .withColumn("v", explode(col('vs'))).drop('vs') - - def test_supported_types(self): - - df1 = self.spark.createDataFrame( - pd.DataFrame.from_dict({ - 'id' : [1,1,10, 10, 1,1], - 'x' : [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] - })) - - df2 = self.spark.createDataFrame( - pd.DataFrame.from_dict({ - 'id2': [1,1,10, 10, 1,1], - 'a': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] - })) + def test_simple(self): + + pdf1 = pd.DataFrame.from_dict({ + 'id': ['a', 'a', 'b', 'b'], + 't': [1.0, 2.0, 1.0, 2.0], + 'x': [10, 10, 30, 40] + + }) + + pdf2 = pd.DataFrame.from_dict({ + 'id2': ['a', 'b'], + 't': [0.5, 0.5], + 'y': [7.0, 8.0] + }) output_schema = StructType([ - StructField("id", LongType()), - StructField("x", DoubleType()), + StructField("id", StringType()), + StructField("t", DoubleType()), + StructField("x", IntegerType()), + StructField("y", DoubleType()), ]) + @pandas_udf(output_schema, functionType=PandasUDFType.COGROUPED_MAP) - def foo(left, right): - print("hello") + def pandas_merge(left, right): print(left) - print("goodbye") + print("#########") print(right) - return left - - output_schema2 = StructType([ - StructField("id", LongType()) - ]) - @pandas_udf(output_schema, functionType=PandasUDFType.GROUPED_MAP) - def foo2(key, df): - print('key is ' + str(key)) - print(df) - return df - - - df1.groupby(col("id") > 5)\ - .apply(foo2)\ - .show() + print("#########") + import pandas as pd + left.sort_values(by='t', inplace=True) + right.sort_values(by='t', inplace=True) + result = pd.merge_asof(left, right, on='t').reset_index() + print(result) + return result + df1 = self.spark.createDataFrame(pdf1) + df2 = self.spark.createDataFrame(pdf2) + gd1 = df1.groupby('id') + gd2 = df2.groupby('id2') -if __name__ == "__main__": - from pyspark.sql.tests.test_pandas_udf_cogrouped_map import * + gd1\ + .cogroup(gd2)\ + .apply(pandas_merge)\ + .explain() - try: - import xmlrunner - testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2ccc28cffce78..1a423920b77f5 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -236,6 +236,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = _get_argspec(row_func) # signature was lost when wrapping it return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 3710218b2af5f..5e00eecf1b230 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -45,7 +45,7 @@ class ArrowPythonRunner( schema: StructType, timeZoneId: String, conf: Map[String, String]) - extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + extends BaseArrowPythonRunner[Iterator[InternalRow]]( funcs, evalType, argOffsets) { protected override def newWriterThread( @@ -112,72 +112,4 @@ class ArrowPythonRunner( } } - protected override def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - batch - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. - read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala new file mode 100644 index 0000000000000..3cba06dcf7d52 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BaseArrowPythonRunner.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} + + +abstract class BaseArrowPythonRunner[T]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]]) + extends BasePythonRunner[T, ColumnarBatch]( + funcs, evalType, argOffsets) { + + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + + new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + batch + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. 
+ read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index 2dc517354812a..22c781b9fe42c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -60,12 +60,16 @@ case class FlatMapCoGroupsInPandasExec( val sessionLocalTimeZone = conf.sessionLocalTimeZone val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) val cogroup = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) .map{case (k, l, r) => (l, r)} val context = TaskContext.get() + println("in zipPartitions: left schema is " + left.schema) + println("in zipPartitions: right schema is " + right.schema) + val columnarBatchIter = new InterleavedArrowPythonRunner( chainedFunc, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 46db29a250520..7b0e014f9ca48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -90,14 +90,10 @@ case class FlatMapGroupsInPandasExec( // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes // argOffsets[argOffsets[0]+1 .. 
] is the arg offsets for data attributes - println("grouping attributes are " + groupingAttributes.mkString(",")) - println("data attributes are " + child.output.mkString(",")) val dataAttributes = child.output.drop(groupingAttributes.length) val groupingIndicesInData = groupingAttributes.map { attribute => dataAttributes.indexWhere(attribute.semanticEquals) } - println("dataAttributes attributes are " + dataAttributes.mkString(",")) - println("groupingIndicesInData are " + groupingIndicesInData.mkString(",")) val groupingArgOffsets = new ArrayBuffer[Int] val nonDupGroupingAttributes = new ArrayBuffer[Attribute] @@ -120,25 +116,15 @@ case class FlatMapGroupsInPandasExec( } } - println("nonDupGroupingAttributes.length: " + nonDupGroupingAttributes.length) - println("nonDupGroupingAttributes.length: " + nonDupGroupingAttributes.length) - println("dataAttributes.length: " + dataAttributes.length) - val dataArgOffsets = nonDupGroupingAttributes.length until (nonDupGroupingAttributes.length + dataAttributes.length) - val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) - println("numAttributes are " + groupingAttributes.length) - println("groupingArgOffsets are " + groupingArgOffsets.mkString(",")) - println("dataArgOffsets are " + dataArgOffsets.mkString(",")) - - // Attributes after deduplication val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes val dedupSchema = StructType.fromAttributes(dedupAttributes) - println("dedupSchema is " + dedupSchema) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MultiDfArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala similarity index 62% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/python/MultiDfArrowPythonRunner.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala index f40ae92ac284e..b39885ee47a2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MultiDfArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala @@ -19,12 +19,8 @@ package org.apache.spark.sql.execution.python import java.io._ import java.net._ -import java.util.concurrent.atomic.AtomicBoolean - -import scala.collection.JavaConverters._ import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.apache.spark._ import org.apache.spark.api.python._ @@ -32,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.Utils @@ -44,7 +40,7 @@ class InterleavedArrowPythonRunner( rightSchema: StructType, timeZoneId: String, conf: Map[String, String]) - extends BasePythonRunner[(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch]( + extends BaseArrowPythonRunner[(Iterator[InternalRow], Iterator[InternalRow])]( funcs, evalType, argOffsets) { protected override def newWriterThread( @@ -119,73 +115,4 @@ class InterleavedArrowPythonRunner( } } } - - protected override def newReaderIterator( 
- stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[ColumnarBatch] = { - new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) { - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - batch - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. - read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } - } catch handleException - } - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala index f08763532eb23..eb9f1d4494b91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowWriter.scala @@ -22,7 +22,7 @@ import java.nio.channels.Channels import org.apache.arrow.vector.{VectorSchemaRoot, VectorUnloader} import org.apache.arrow.vector.ipc.WriteChannel -import org.apache.arrow.vector.ipc.message.MessageSerializer +import org.apache.arrow.vector.ipc.message.{ArrowRecordBatch, MessageSerializer} class InterleavedArrowWriter( leftRoot: VectorSchemaRoot, @@ -40,12 +40,15 @@ class InterleavedArrowWriter( leftRoot: VectorSchemaRoot, def writeBatch(): Unit = { this.ensureStarted() - val leftBatch = leftUnloader.getRecordBatch - val rightBatch = rightUnloader.getRecordBatch - MessageSerializer.serialize(out, leftBatch) - MessageSerializer.serialize(out, rightBatch) - leftBatch.close() - rightBatch.close() + writeRecordBatch(leftUnloader.getRecordBatch) + writeRecordBatch(rightUnloader.getRecordBatch) + } + + private def writeRecordBatch(b: ArrowRecordBatch): Unit = { + try + MessageSerializer.serialize(out, b) + finally + b.close() } private def ensureStarted(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala new file mode 100644 index 0000000000000..cc9ba8bfc5325 --- /dev/null +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions.sum +import org.apache.spark.sql.test.{SharedSQLContext, SharedSparkSession} +import org.scalatest.concurrent.Eventually +import org.apache.spark.sql.functions._ + +class GroupedDataTest extends QueryTest with SharedSQLContext with Eventually { + + test("SPARK-7150 range api") { + val df = spark + .range(0, 100) + .withColumn("x", lit("a")) + } + +} From 6d039e366488945f5e673c6aba29bfd710d46ef8 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Fri, 21 Jun 2019 06:47:28 +0100 Subject: [PATCH 03/10] removed incorrect test --- python/pyspark/sql/tests/test_group.py | 46 -------------------------- 1 file changed, 46 deletions(-) delete mode 100644 python/pyspark/sql/tests/test_group.py diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py deleted file mode 100644 index 6de1b8ea0b3ce..0000000000000 --- a/python/pyspark/sql/tests/test_group.py +++ /dev/null @@ -1,46 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from pyspark.sql import Row -from pyspark.testing.sqlutils import ReusedSQLTestCase - - -class GroupTests(ReusedSQLTestCase): - - def test_aggregator(self): - df = self.df - g = df.groupBy() - self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) - self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) - - from pyspark.sql import functions - self.assertEqual((0, u'99'), - tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) - self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0]) - self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) - - -if __name__ == "__main__": - import unittest - from pyspark.sql.tests.test_group import * - - try: - import xmlrunner - testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') - except ImportError: - testRunner = None - unittest.main(testRunner=testRunner, verbosity=2) From d8a5c5dad1acfb4efb0a6255280bc95bb0ce99ba Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Tue, 25 Jun 2019 10:50:40 +0100 Subject: [PATCH 04/10] tidies up test, fixed output cols --- python/pyspark/sql/cogroup.py | 10 ++- .../tests/test_pandas_udf_cogrouped_map.py | 81 +++++++++---------- 2 files changed, 44 insertions(+), 47 deletions(-) diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py index 8946ab1202651..18dc397c8e348 100644 --- a/python/pyspark/sql/cogroup.py +++ b/python/pyspark/sql/cogroup.py @@ -9,7 +9,13 @@ def __init__(self, gd1, gd2): self.sql_ctx = gd1.sql_ctx def apply(self, udf): - df = self._gd1._df - udf_column = udf(*[df[col] for col in df.columns]) + all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) + udf_column = udf(*all_cols) jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) return DataFrame(jdf, self.sql_ctx) + + @staticmethod + def _extract_cols(gd): + df = gd._df + return [df[col] for col in df.columns] + diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py index 508825123df07..d74f9b10325ed 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -52,51 +52,42 @@ pandas_requirement_message or pyarrow_requirement_message) class CoGroupedMapPandasUDFTests(ReusedSQLTestCase): + @property + def data1(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks')))\ + .withColumn("v", col('k') * 10)\ + .drop('ks') + + @property + def data2(self): + return self.spark.range(10).toDF('id') \ + .withColumn("ks", array([lit(i) for i in range(20, 30)])) \ + .withColumn("k", explode(col('ks'))) \ + .withColumn("v2", col('k') * 100) \ + .drop('ks') + def test_simple(self): + import pandas as pd + + l = self.data1 + r = self.data2 + + @pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP) + def merge_pandas(left, right): + return pd.merge(left, right, how='outer', on=['k', 'id']) + + # TODO: Grouping by a string fails to resolve here as analyzer cannot determine side + result = l\ + .groupby(l.id)\ + .cogroup(r.groupby(r.id))\ + .apply(merge_pandas)\ + .sort(['id', 'k'])\ + .toPandas() + + expected = pd\ + .merge(l.toPandas(), r.toPandas(), how='outer', on=['k', 'id']) - pdf1 = pd.DataFrame.from_dict({ - 'id': ['a', 'a', 'b', 'b'], - 't': [1.0, 2.0, 1.0, 2.0], - 'x': [10, 10, 30, 40] - - }) - - pdf2 = 
pd.DataFrame.from_dict({ - 'id2': ['a', 'b'], - 't': [0.5, 0.5], - 'y': [7.0, 8.0] - }) - - output_schema = StructType([ - StructField("id", StringType()), - StructField("t", DoubleType()), - StructField("x", IntegerType()), - StructField("y", DoubleType()), - ]) - - - @pandas_udf(output_schema, functionType=PandasUDFType.COGROUPED_MAP) - def pandas_merge(left, right): - print(left) - print("#########") - print(right) - print("#########") - import pandas as pd - left.sort_values(by='t', inplace=True) - right.sort_values(by='t', inplace=True) - result = pd.merge_asof(left, right, on='t').reset_index() - print(result) - return result - - - df1 = self.spark.createDataFrame(pdf1) - df2 = self.spark.createDataFrame(pdf2) - - gd1 = df1.groupby('id') - gd2 = df2.groupby('id2') - - gd1\ - .cogroup(gd2)\ - .apply(pandas_merge)\ - .explain() + assert_frame_equal(expected, result, check_column_type=_check_column_type) From 73188f632a85098567302d1470e8190dbe65a783 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Tue, 25 Jun 2019 10:55:02 +0100 Subject: [PATCH 05/10] removed incorrect file --- .../python/FlatMapCoGroupsInPandasExec.scala | 4 +-- .../apache/spark/sql/GroupedDataTest.scala | 33 ------------------- 2 files changed, 1 insertion(+), 36 deletions(-) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index 22c781b9fe42c..12620264de087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -67,8 +67,6 @@ case class FlatMapCoGroupsInPandasExec( val cogroup = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) .map{case (k, l, r) => (l, r)} val context = TaskContext.get() - println("in zipPartitions: left schema is " + left.schema) - println("in zipPartitions: right schema is " + right.schema) val columnarBatchIter = new InterleavedArrowPythonRunner( chainedFunc, @@ -83,7 +81,7 @@ case class FlatMapCoGroupsInPandasExec( val unsafeProj = UnsafeProjection.create(output, output) columnarBatchIter.flatMap { batch => - // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here + // UDF returns a StructType column in ColumnarBatch, select the children here val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] val outputVectors = output.indices.map(structVector.getChild) val flattenedBatch = new ColumnarBatch(outputVectors.toArray) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala deleted file mode 100644 index cc9ba8bfc5325..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDataTest.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.sql.functions.sum -import org.apache.spark.sql.test.{SharedSQLContext, SharedSparkSession} -import org.scalatest.concurrent.Eventually -import org.apache.spark.sql.functions._ - -class GroupedDataTest extends QueryTest with SharedSQLContext with Eventually { - - test("SPARK-7150 range api") { - val df = spark - .range(0, 100) - .withColumn("x", lit("a")) - } - -} From 690fa14e4ca511b82de08e22f721caf5ec930e0b Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Tue, 25 Jun 2019 10:55:09 +0100 Subject: [PATCH 06/10] Revert: removed incorrect test --- python/pyspark/sql/tests/test_group.py | 46 ++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 python/pyspark/sql/tests/test_group.py diff --git a/python/pyspark/sql/tests/test_group.py b/python/pyspark/sql/tests/test_group.py new file mode 100644 index 0000000000000..6de1b8ea0b3ce --- /dev/null +++ b/python/pyspark/sql/tests/test_group.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from pyspark.sql import Row +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class GroupTests(ReusedSQLTestCase): + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + + from pyspark.sql import functions + self.assertEqual((0, u'99'), + tuple(g.agg(functions.first(df.key), functions.last(df.value)).first())) + self.assertTrue(95 < g.agg(functions.approx_count_distinct(df.key)).first()[0]) + self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0]) + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_group import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports') + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From e3b66acf4849f3ff0278ac4c5fcfc5b11e4fc1fe Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Tue, 25 Jun 2019 17:25:39 +0100 Subject: [PATCH 07/10] fix for resolving key cols --- python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py | 3 +-- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py index d74f9b10325ed..332bc260541a6 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -78,9 +78,8 @@ def test_simple(self): def merge_pandas(left, right): return pd.merge(left, right, how='outer', on=['k', 'id']) - # TODO: Grouping by a string fails to resolve here as analyzer cannot determine side result = l\ - .groupby(l.id)\ + .groupby('id')\ .cogroup(r.groupby(r.id))\ .apply(merge_pandas)\ .sort(['id', 'k'])\ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 60517f11a2491..9dad538686b30 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -970,6 +970,10 @@ class Analyzer( // To resolve duplicate expression IDs for Join and Intersect case j @ Join(left, right, _, _, _) if !j.duplicateResolved => j.copy(right = dedupRight(left, right)) + case f @ FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, _, _, left, right) => + val leftAttributes2 = leftAttributes.map(x => resolveExpressionBottomUp(x, left).asInstanceOf[Attribute]) + val rightAttributes2 = rightAttributes.map(x => resolveExpressionBottomUp(x, right).asInstanceOf[Attribute]) + f.copy(leftAttributes=leftAttributes2, rightAttributes=rightAttributes2) case i @ Intersect(left, right, _) if !i.duplicateResolved => i.copy(right = dedupRight(left, right)) case e @ Except(left, right, _) if !e.duplicateResolved => @@ -2269,6 +2273,7 @@ class Analyzer( } } + /** * Removes natural or using joins by calculating output columns based on output from two sides, * Then apply a Project on a normal Join to eliminate natural or using join. 
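
The test change above now groups the left side by the plain column name, which the new Analyzer rule resolves against the correct child of FlatMapCoGroupsInPandas. For reference, a minimal end-to-end sketch of the user-facing API these patches introduce follows; it assumes a live SparkSession bound to `spark`, only runs with this patch series applied, and its column names simply mirror the test data above rather than any fixed contract.

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Two small frames sharing the key columns 'id' and 'k' (illustrative data).
df1 = spark.createDataFrame(
    [(1, 20, 200), (1, 21, 210), (2, 20, 200)], ("id", "k", "v"))
df2 = spark.createDataFrame(
    [(1, 20, 2000), (2, 20, 2000)], ("id", "k", "v2"))

@pandas_udf("id long, k int, v int, v2 int", PandasUDFType.COGROUPED_MAP)
def merge_pandas(left, right):
    # Each invocation receives the rows for one key from each side as a
    # pandas.DataFrame and returns a single merged pandas.DataFrame.
    return pd.merge(left, right, how="outer", on=["id", "k"])

result = (df1.groupby("id")                     # string keys now resolve
             .cogroup(df2.groupby("id"))
             .apply(merge_pandas))
result.show()
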
From 8007fa66dd0810cd822137ae24f9ede08e498114 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Thu, 27 Jun 2019 07:54:06 +0100 Subject: [PATCH 08/10] common trait for grouped mandas udfs --- .../spark/sql/RelationalGroupedDataset.scala | 14 +- .../python/AbstractPandasGroupExec.scala | 128 ++++++++++++++++++ .../python/FlatMapCoGroupsInPandasExec.scala | 44 ++---- .../python/FlatMapGroupsInPandasExec.scala | 80 +---------- 4 files changed, 156 insertions(+), 110 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 147cc00c0ba91..0018f6379e8fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -47,8 +47,8 @@ import org.apache.spark.sql.types.{NumericType, StructType} */ @Stable class RelationalGroupedDataset protected[sql]( - private val df: DataFrame, - private val groupingExprs: Seq[Expression], + val df: DataFrame, + val groupingExprs: Seq[Expression], groupType: RelationalGroupedDataset.GroupType) { private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { @@ -542,11 +542,15 @@ class RelationalGroupedDataset protected[sql]( val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute) val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute) - val left = df.logicalPlan - val right = r.df.logicalPlan + + val leftChild = df.logicalPlan + val rightChild = r.df.logicalPlan + + val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild) + val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild) + val output = expr.dataType.asInstanceOf[StructType].toAttributes val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right) - Dataset.ofRows(df.sparkSession, plan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala new file mode 100644 index 0000000000000..0305f79557781 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AbstractPandasGroupExec.scala @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF, UnsafeProjection} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} + +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConverters._ + +trait AbstractPandasGroupExec extends SparkPlan { + + protected val sessionLocalTimeZone = conf.sessionLocalTimeZone + + protected val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) + + protected def chainedFunc = Seq( + ChainedPythonFunctions(Seq(func.asInstanceOf[PythonUDF].func))) + + def output: Seq[Attribute] + + def func: Expression + + protected def executePython[T](data: Iterator[T], + runner: BasePythonRunner[T, ColumnarBatch]): Iterator[InternalRow] = { + + val context = TaskContext.get() + val columnarBatchIter = runner.compute(data, context.partitionId(), context) + val unsafeProj = UnsafeProjection.create(output, output) + + columnarBatchIter.flatMap { batch => + // UDF returns a StructType column in ColumnarBatch, select the children here + val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] + val outputVectors = output.indices.map(structVector.getChild) + val flattenedBatch = new ColumnarBatch(outputVectors.toArray) + flattenedBatch.setNumRows(batch.numRows()) + flattenedBatch.rowIterator.asScala + }.map(unsafeProj) + + } + + protected def groupAndDedup( + input: Iterator[InternalRow], groupingAttributes: Seq[Attribute], + inputSchema: Seq[Attribute], dedupSchema: Seq[Attribute]): Iterator[Iterator[InternalRow]] = { + if (groupingAttributes.isEmpty) { + Iterator(input) + } else { + val groupedIter = GroupedIterator(input, groupingAttributes, inputSchema) + val dedupProj = UnsafeProjection.create(dedupSchema, inputSchema) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) + } + } + } + + protected def createSchema(child: SparkPlan, groupingAttributes: Seq[Attribute]) + : (StructType, Seq[Attribute], Array[Array[Int]]) = { + + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... 
+ // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. + + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + (dedupSchema, dedupAttributes, argOffsets) + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala index 12620264de087..a70b4763fc0f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala @@ -17,17 +17,12 @@ package org.apache.spark.sql.execution.python -import scala.collection.JavaConverters._ - -import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.api.python.PythonEvalType import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, GroupedIterator, SparkPlan} -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} case class FlatMapCoGroupsInPandasExec( leftGroup: Seq[Attribute], @@ -36,9 +31,7 @@ case class FlatMapCoGroupsInPandasExec( output: Seq[Attribute], left: SparkPlan, right: SparkPlan) - extends BinaryExecNode { - - private val pandasFunction = func.asInstanceOf[PythonUDF].func + extends BinaryExecNode with AbstractPandasGroupExec { override def outputPartitioning: Partitioning = left.outputPartitioning @@ -53,41 +46,30 @@ case class FlatMapCoGroupsInPandasExec( .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil } - override protected def doExecute(): RDD[InternalRow] = { - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - + val (schemaLeft, attrLeft, _) = createSchema(left, leftGroup) + val (schemaRight, attrRight, _) = createSchema(right, rightGroup) left.execute().zipPartitions(right.execute()) { (leftData, rightData) => val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val cogroup = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) - .map{case (k, l, r) => (l, r)} - val context = TaskContext.get() + val projLeft = UnsafeProjection.create(attrLeft, left.output) + val 
projRight = UnsafeProjection.create(attrRight, right.output) + val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) + .map{case (k, l, r) => (l.map(projLeft), r.map(projRight))} - val columnarBatchIter = new InterleavedArrowPythonRunner( + val runner = new InterleavedArrowPythonRunner( chainedFunc, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, Array(Array.empty), - left.schema, - right.schema, + schemaLeft, + schemaRight, sessionLocalTimeZone, - pythonRunnerConf).compute(cogroup, context.partitionId(), context) - + pythonRunnerConf) - val unsafeProj = UnsafeProjection.create(output, output) + executePython(data, runner) - columnarBatchIter.flatMap { batch => - // UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 267698d1bca50..474bbe04b8d62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -53,9 +53,7 @@ case class FlatMapGroupsInPandasExec( func: Expression, output: Seq[Attribute], child: SparkPlan) - extends UnaryExecNode { - - private val pandasFunction = func.asInstanceOf[PythonUDF].func + extends UnaryExecNode with AbstractPandasGroupExec { override def outputPartitioning: Partitioning = child.outputPartitioning @@ -75,88 +73,22 @@ case class FlatMapGroupsInPandasExec( override protected def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute() - val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val sessionLocalTimeZone = conf.sessionLocalTimeZone - val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf) - - // Deduplicate the grouping attributes. - // If a grouping attribute also appears in data attributes, then we don't need to send the - // grouping attribute to Python worker. If a grouping attribute is not in data attributes, - // then we need to send this grouping attribute to python worker. - // - // We use argOffsets to distinguish grouping attributes and data attributes as following: - // - // argOffsets[0] is the length of grouping attributes - // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes - // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes - - val dataAttributes = child.output.drop(groupingAttributes.length) - val groupingIndicesInData = groupingAttributes.map { attribute => - dataAttributes.indexWhere(attribute.semanticEquals) - } - - val groupingArgOffsets = new ArrayBuffer[Int] - val nonDupGroupingAttributes = new ArrayBuffer[Attribute] - val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) - - // Non duplicate grouping attributes are added to nonDupGroupingAttributes and - // their offsets are 0, 1, 2 ... 
- // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and - // their offsets are n + index, where n is the total number of non duplicate grouping - // attributes and index is the index in the data attributes that the grouping attribute - // is a duplicate of. - - groupingAttributes.zip(groupingIndicesInData).foreach { - case (attribute, index) => - if (index == -1) { - groupingArgOffsets += nonDupGroupingAttributes.length - nonDupGroupingAttributes += attribute - } else { - groupingArgOffsets += index + nonDupGroupingSize - } - } - - val dataArgOffsets = nonDupGroupingAttributes.length until - (nonDupGroupingAttributes.length + dataAttributes.length) - - val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) - - // Attributes after deduplication - val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes - val dedupSchema = StructType.fromAttributes(dedupAttributes) + val (dedupSchema, dedupAttributes, argOffsets) = createSchema(child, groupingAttributes) // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { - val grouped = if (groupingAttributes.isEmpty) { - Iterator(iter) - } else { - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) - groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dedupProj) - } - } - val context = TaskContext.get() + val data = groupAndDedup(iter, groupingAttributes, child.output, dedupAttributes) - val columnarBatchIter = new ArrowPythonRunner( + val runner = new ArrowPythonRunner( chainedFunc, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, - pythonRunnerConf).compute(grouped, context.partitionId(), context) - - val unsafeProj = UnsafeProjection.create(output, output) + pythonRunnerConf) - columnarBatchIter.flatMap { batch => - // Grouped Map UDF returns a StructType column in ColumnarBatch, select the children here - val structVector = batch.column(0).asInstanceOf[ArrowColumnVector] - val outputVectors = output.indices.map(structVector.getChild) - val flattenedBatch = new ColumnarBatch(outputVectors.toArray) - flattenedBatch.setNumRows(batch.numRows()) - flattenedBatch.rowIterator.asScala - }.map(unsafeProj) + executePython(data, runner) }} } } From 86d1385064e754bf52ae5d1ed6e5806b2fa47949 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Tue, 2 Jul 2019 09:46:11 +0100 Subject: [PATCH 09/10] fixed iterator issue under python2 --- python/pyspark/serializers.py | 3 +++ .../pyspark/sql/tests/test_pandas_udf_cogrouped_map.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9616450b1803c..d8925b8091730 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -373,6 +373,9 @@ def __next__(self): batch2 = pa.read_record_batch(self._reader.read_next_message(), self._schema2) return batch1, batch2 + def next(self): + return self.__next__() + class ArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): """ diff --git a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py index 332bc260541a6..9104aa2193dac 100644 --- a/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py @@ -90,3 +90,12 @@ def 
merge_pandas(left, right): assert_frame_equal(expected, result, check_column_type=_check_column_type) +if __name__ == "__main__": + from pyspark.sql.tests.test_pandas_udf_cogrouped_map import * + + try: + import xmlrunner + testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From d15dabbf71ad3007ea0c37e71c997e6fa1799e51 Mon Sep 17 00:00:00 2001 From: Chris Martin Date: Tue, 2 Jul 2019 09:59:14 +0100 Subject: [PATCH 10/10] add license --- python/pyspark/sql/cogroup.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/python/pyspark/sql/cogroup.py b/python/pyspark/sql/cogroup.py index 18dc397c8e348..b758d1b7d8c45 100644 --- a/python/pyspark/sql/cogroup.py +++ b/python/pyspark/sql/cogroup.py @@ -1,3 +1,20 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + from pyspark.sql.dataframe import DataFrame
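
The comment block carried into AbstractPandasGroupExec.createSchema describes how grouping attributes are deduplicated against the data attributes before rows are shipped to the Python worker. A plain-Python illustration of that offset layout, written here only to make the scheme concrete (it is not Spark code), is:

def arg_offsets(grouping_cols, data_cols):
    # Grouping columns that also appear in the data are not sent twice;
    # they reuse the matching data column's offset instead.
    non_dup = [g for g in grouping_cols if g not in data_cols]
    grouping_offsets = []
    for g in grouping_cols:
        if g in data_cols:
            grouping_offsets.append(len(non_dup) + data_cols.index(g))
        else:
            grouping_offsets.append(non_dup.index(g))
    data_offsets = list(range(len(non_dup), len(non_dup) + len(data_cols)))
    # Layout: [number of grouping attrs, grouping offsets..., data offsets...]
    return [len(grouping_cols)] + grouping_offsets + data_offsets

# 'id' is both a grouping and a data column, so nothing is duplicated:
# prints [1, 1, 0, 1, 2] -- one grouping attribute, found at data offset 1.
print(arg_offsets(["id"], ["k", "id", "v"]))
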