From 5d103b9473af13d0d49187b52fd25656e74f2acb Mon Sep 17 00:00:00 2001 From: jiangtian Date: Thu, 25 Dec 2025 16:47:56 +0800 Subject: [PATCH 1/2] Make PartialProject support array and map with null values --- .../org/apache/gluten/udf/DuplicateArray.java | 85 +++++++ .../sql/execution/GlutenHiveUDFSuite.scala | 55 ++++- .../gluten/vectorized/ArrowColumnarArray.java | 223 ++++++++++++++++++ .../gluten/vectorized/ArrowColumnarMap.java | 55 +++++ .../vectorized/ArrowWritableColumnVector.java | 30 ++- .../gluten/vectorized/ArrowColumnarRow.scala | 10 +- 6 files changed, 444 insertions(+), 14 deletions(-) create mode 100644 backends-velox/src/test/java/org/apache/gluten/udf/DuplicateArray.java create mode 100644 gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java create mode 100644 gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarMap.java diff --git a/backends-velox/src/test/java/org/apache/gluten/udf/DuplicateArray.java b/backends-velox/src/test/java/org/apache/gluten/udf/DuplicateArray.java new file mode 100644 index 000000000000..dc7905859aef --- /dev/null +++ b/backends-velox/src/test/java/org/apache/gluten/udf/DuplicateArray.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.udf; + +import org.apache.hadoop.hive.ql.exec.Description; +import org.apache.hadoop.hive.ql.exec.UDFArgumentException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF; +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** UDF for duplicating array. */ +@Description( + name = "array_duplicate", + value = + "_FUNC_(array(obj1, obj2,...)) - " + + "The function returns an array of the same type as every element" + + "in array is duplicated.", + extended = + "Example:\n" + + " > SELECT _FUNC_(array('b', 'd')) FROM src LIMIT 1;\n" + + " ['b', 'b', 'd', 'd']") +public class DuplicateArray extends GenericUDF { + + ListObjectInspector arrayOI; + + public DuplicateArray() {} + + @Override + public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException { + if (arguments.length != 1) { + throw new UDFArgumentException("Argument size of array_duplicate must be 1."); + } + + arrayOI = (ListObjectInspector) arguments[0]; + return ObjectInspectorFactory.getStandardListObjectInspector( + arrayOI.getListElementObjectInspector()); + } + + @Override + public Object evaluate(DeferredObject[] arguments) throws HiveException { + + Object array = arguments[0].get(); + + // If the array is empty, return back the empty array + if (arrayOI.getListLength(array) == 0) { + return Collections.emptyList(); + } else if (arrayOI.getListLength(array) < 0) { + return null; + } + + List retArray = arrayOI.getList(array); + List result = new ArrayList<>(); + retArray.forEach( + element -> { + result.add(element); 
+ result.add(element); + }); + return result; + } + + @Override + public String getDisplayString(String[] children) { + return "array_duplicate"; + } +} diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala index c5c6981a0800..6919a303903a 100644 --- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala +++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.gluten.config.GlutenConfig import org.apache.gluten.execution.{ColumnarPartialGenerateExec, ColumnarPartialProjectExec, GlutenQueryComparisonTest} import org.apache.gluten.expression.UDFMappings -import org.apache.gluten.udf.CustomerUDF +import org.apache.gluten.udf.{CustomerUDF, DuplicateArray} import org.apache.gluten.udtf.{ConditionalOutputUDTF, CustomerUDTF, NoInputUDTF, SimpleUDTF} import org.apache.spark.SparkConf @@ -326,4 +326,57 @@ class GlutenHiveUDFSuite extends GlutenQueryComparisonTest with SQLTestUtils { } } } + + test("udf with map with null values") { + withTempFunction("udf_map_values") { + sql(""" + |CREATE TEMPORARY FUNCTION udf_map_values AS + |'org.apache.hadoop.hive.ql.udf.generic.GenericUDFMapValues'; + |""".stripMargin) + + runQueryAndCompare(""" + |SELECT + | l_partkey, + | udf_map_values(map_data) + |FROM ( + | SELECT l_partkey, + | map( + | concat('hello', l_orderkey % 2), + | CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END, + | concat('world', l_orderkey % 2), + | CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END + | ) as map_data + | FROM lineitem + |) + |""".stripMargin) { + checkOperatorMatch[ColumnarPartialProjectExec] + } + } + } + + test("udf with array with null values") { + withTempFunction("udf_array_distinct") { + sql(s""" + |CREATE TEMPORARY FUNCTION udf_array_distinct AS 
'${classOf[DuplicateArray].getName}' + |""".stripMargin) + + runQueryAndCompare(""" + |SELECT + | l_partkey, + | udf_array_distinct(map_data) + |FROM ( + | SELECT l_partkey, + | array( + | l_orderkey % 2, + | CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END, + | l_orderkey % 2, + | CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END + | ) as map_data + | FROM lineitem + |) + |""".stripMargin) { + checkOperatorMatch[ColumnarPartialProjectExec] + } + } + } } diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java new file mode 100644 index 000000000000..f1edced075ee --- /dev/null +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.vectorized; + +import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Because `get` method in `ColumnarArray` don't check whether the data to get is null and arrow + * vectors will throw exception when we try to access null value, so we define the following class + * as a workaround. Its implementation is copied from Spark-3.5, except that the `handleNull` + * parameter is set to true when we call `SpecializedGettersReader.read` in `get`. + */ +public class ArrowColumnarArray extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). 
+ private final ColumnVector data; + private final int offset; + private final int length; + + public ArrowColumnarArray(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int numElements() { + return length; + } + + @Override + public ArrayData copy() { + DataType dt = data.dataType(); + + if (dt instanceof BooleanType) { + return UnsafeArrayData.fromPrimitiveArray(toBooleanArray()); + } else if (dt instanceof ByteType) { + return UnsafeArrayData.fromPrimitiveArray(toByteArray()); + } else if (dt instanceof ShortType) { + return UnsafeArrayData.fromPrimitiveArray(toShortArray()); + } else if (dt instanceof IntegerType + || dt instanceof DateType + || dt instanceof YearMonthIntervalType) { + return UnsafeArrayData.fromPrimitiveArray(toIntArray()); + } else if (dt instanceof LongType + || dt instanceof TimestampType + || dt instanceof DayTimeIntervalType) { + return UnsafeArrayData.fromPrimitiveArray(toLongArray()); + } else if (dt instanceof FloatType) { + return UnsafeArrayData.fromPrimitiveArray(toFloatArray()); + } else if (dt instanceof DoubleType) { + return UnsafeArrayData.fromPrimitiveArray(toDoubleArray()); + } else { + return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
+ } + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(offset + ordinal); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(offset + ordinal); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(offset + ordinal); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(offset + ordinal); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(offset + ordinal); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public 
UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return data.getInterval(offset + ordinal); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); + } + + @Override + public Object get(int ordinal, DataType dataType) { + return SpecializedGettersReader.read(this, ordinal, dataType, true, false); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setNullAt(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarMap.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarMap.java new file mode 100644 index 000000000000..b6bfacb835b5 --- /dev/null +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarMap.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.vectorized; + +import org.apache.spark.sql.catalyst.util.ArrayBasedMapData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.vectorized.ColumnVector; + +/** See [[ArrowColumnarArray]]. */ +public class ArrowColumnarMap extends MapData { + private final ArrowColumnarArray keys; + private final ArrowColumnarArray values; + private final int length; + + public ArrowColumnarMap(ColumnVector keys, ColumnVector values, int offset, int length) { + this.length = length; + this.keys = new ArrowColumnarArray(keys, offset, length); + this.values = new ArrowColumnarArray(values, offset, length); + } + + @Override + public int numElements() { + return length; + } + + @Override + public ArrayData keyArray() { + return keys; + } + + @Override + public ArrayData valueArray() { + return values; + } + + @Override + public MapData copy() { + return new ArrayBasedMapData(keys.copy(), values.copy()); + } +} diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java index d00786f3f4ca..4086b2db0e20 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java @@ -38,8 +38,6 @@ import org.apache.spark.sql.types.*; import org.apache.spark.sql.utils.SparkArrowUtil; import 
org.apache.spark.sql.utils.SparkSchemaUtil; -import org.apache.spark.sql.vectorized.ColumnarArray; -import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.types.UTF8String; import org.slf4j.Logger; @@ -411,6 +409,22 @@ public static String stat() { return "vectorCounter is " + vectorCount.get(); } + // `get` method in Spark `ColumnarMap` doesn't check null values and arrow vectors will throw + // exception when we try to access a null value, so here we return `ArrowColumnarMap` + // as a workaround. + public MapData getMapInternal(int rowId) { + return accessor.getMap(rowId); + } + + // `get` method in Spark `ColumnarArray` doesn't check null values and arrow vectors will throw + // exception when we try to access a null value, so here we return `ArrowColumnarArray` + // as a workaround. + public ArrayData getArrayInternal(int rowId) { + return accessor.getArray(rowId); + } + + // `get` method in Spark `ColumnarRow` doesn't check whether the data to get is a null value, + // so we return `ArrowColumnarRow` as a workaround. 
public ArrowColumnarRow getStructInternal(int rowId) { if (isNullAt(rowId)) return null; ArrowWritableColumnVector[] writableColumns = @@ -893,7 +907,7 @@ byte[] getBinary(int rowId) { throw new UnsupportedOperationException(); } - ColumnarArray getArray(int rowId) { + ArrayData getArray(int rowId) { throw new UnsupportedOperationException(); } @@ -905,7 +919,7 @@ int getArrayOffset(int rowId) { throw new UnsupportedOperationException(); } - ColumnarMap getMap(int rowId) { + MapData getMap(int rowId) { throw new UnsupportedOperationException(); } } @@ -1239,8 +1253,8 @@ public int getArrayOffset(int rowId) { } @Override - final ColumnarArray getArray(int rowId) { - return new ColumnarArray(elements, getArrayOffset(rowId), getArrayLength(rowId)); + final ArrayData getArray(int rowId) { + return new ArrowColumnarArray(elements, getArrayOffset(rowId), getArrayLength(rowId)); } @Override @@ -1264,11 +1278,11 @@ private static class MapAccessor extends ArrowVectorAccessor { } @Override - final ColumnarMap getMap(int rowId) { + final MapData getMap(int rowId) { int index = rowId * MapVector.OFFSET_WIDTH; int offset = accessor.getOffsetBuffer().getInt(index); int length = accessor.getInnerValueCountAt(rowId); - return new ColumnarMap(keys, values, offset, length); + return new ArrowColumnarMap(keys, values, offset, length); } @Override diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala index f0e2c4dabf10..b768d237edc3 100644 --- a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala +++ b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala @@ -21,8 +21,8 @@ import org.apache.gluten.execution.InternalRowGetVariantCompatible import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, 
MapData} import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap} import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import java.math.BigDecimal @@ -111,11 +111,11 @@ final class ArrowColumnarRow(writableColumns: Array[ArrowWritableColumnVector], override def getStruct(ordinal: Int, numFields: Int): ArrowColumnarRow = columns(ordinal).getStructInternal(rowId) - override def getArray(ordinal: Int): ColumnarArray = - columns(ordinal).getArray(rowId) + override def getArray(ordinal: Int): ArrayData = + columns(ordinal).getArrayInternal(rowId) - override def getMap(ordinal: Int): ColumnarMap = - columns(ordinal).getMap(rowId) + override def getMap(ordinal: Int): MapData = + columns(ordinal).getMapInternal(rowId) override def get(ordinal: Int, dataType: DataType): AnyRef = { if (isNullAt(ordinal)) { From 16ab9d6311a94ddaf920d2402f04c227f3ab6337 Mon Sep 17 00:00:00 2001 From: jiangtian Date: Thu, 25 Dec 2025 19:22:47 +0800 Subject: [PATCH 2/2] fix --- .../gluten/vectorized/ArrowColumnarArray.java | 205 +-------------- .../vectorized/ColumnarArrayShim.java | 234 +++++++++++++++++ .../vectorized/ColumnarArrayShim.java | 234 +++++++++++++++++ .../vectorized/ColumnarArrayShim.java | 234 +++++++++++++++++ .../vectorized/ColumnarArrayShim.java | 234 +++++++++++++++++ .../vectorized/ColumnarArrayShim.java | 241 ++++++++++++++++++ 6 files changed, 1186 insertions(+), 196 deletions(-) create mode 100644 shims/spark32/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java create mode 100644 shims/spark33/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java create mode 100644 shims/spark34/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java create mode 100644 shims/spark35/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java create mode 100644 
shims/spark40/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java index f1edced075ee..3ea0444ee0f1 100644 --- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java +++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java @@ -16,208 +16,21 @@ */ package org.apache.gluten.vectorized; -import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; -import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; -import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.GenericArrayData; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.execution.vectorized.ColumnarArrayShim; import org.apache.spark.sql.vectorized.ColumnVector; -import org.apache.spark.sql.vectorized.ColumnarArray; -import org.apache.spark.sql.vectorized.ColumnarMap; -import org.apache.spark.sql.vectorized.ColumnarRow; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; /** * Because `get` method in `ColumnarArray` don't check whether the data to get is null and arrow * vectors will throw exception when we try to access null value, so we define the following class - * as a workaround. Its implementation is copied from Spark-3.5, except that the `handleNull` - * parameter is set to true when we call `SpecializedGettersReader.read` in `get`. + * as a workaround. Its implementation is copied from Spark-4.0, except that the `handleNull` + * parameter is set to true when we call `SpecializedGettersReader.read` in `get`, which means that + * when trying to access a value of the array, we will check whether the value to get is null first. + * + *

The actual implementation is put in [[ColumnarArrayShim]] because Variant data type is + * introduced in Spark-4.0. */ -public class ArrowColumnarArray extends ArrayData { - // The data for this array. This array contains elements from - // data[offset] to data[offset + length). - private final ColumnVector data; - private final int offset; - private final int length; - +public class ArrowColumnarArray extends ColumnarArrayShim { public ArrowColumnarArray(ColumnVector data, int offset, int length) { - this.data = data; - this.offset = offset; - this.length = length; - } - - @Override - public int numElements() { - return length; - } - - @Override - public ArrayData copy() { - DataType dt = data.dataType(); - - if (dt instanceof BooleanType) { - return UnsafeArrayData.fromPrimitiveArray(toBooleanArray()); - } else if (dt instanceof ByteType) { - return UnsafeArrayData.fromPrimitiveArray(toByteArray()); - } else if (dt instanceof ShortType) { - return UnsafeArrayData.fromPrimitiveArray(toShortArray()); - } else if (dt instanceof IntegerType - || dt instanceof DateType - || dt instanceof YearMonthIntervalType) { - return UnsafeArrayData.fromPrimitiveArray(toIntArray()); - } else if (dt instanceof LongType - || dt instanceof TimestampType - || dt instanceof DayTimeIntervalType) { - return UnsafeArrayData.fromPrimitiveArray(toLongArray()); - } else if (dt instanceof FloatType) { - return UnsafeArrayData.fromPrimitiveArray(toFloatArray()); - } else if (dt instanceof DoubleType) { - return UnsafeArrayData.fromPrimitiveArray(toDoubleArray()); - } else { - return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
- } - } - - @Override - public boolean[] toBooleanArray() { - return data.getBooleans(offset, length); - } - - @Override - public byte[] toByteArray() { - return data.getBytes(offset, length); - } - - @Override - public short[] toShortArray() { - return data.getShorts(offset, length); - } - - @Override - public int[] toIntArray() { - return data.getInts(offset, length); - } - - @Override - public long[] toLongArray() { - return data.getLongs(offset, length); - } - - @Override - public float[] toFloatArray() { - return data.getFloats(offset, length); - } - - @Override - public double[] toDoubleArray() { - return data.getDoubles(offset, length); - } - - // TODO: this is extremely expensive. - @Override - public Object[] array() { - DataType dt = data.dataType(); - Object[] list = new Object[length]; - try { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = get(i, dt); - } - } - return list; - } catch (Exception e) { - throw new RuntimeException("Could not get the array", e); - } - } - - @Override - public boolean isNullAt(int ordinal) { - return data.isNullAt(offset + ordinal); - } - - @Override - public boolean getBoolean(int ordinal) { - return data.getBoolean(offset + ordinal); - } - - @Override - public byte getByte(int ordinal) { - return data.getByte(offset + ordinal); - } - - @Override - public short getShort(int ordinal) { - return data.getShort(offset + ordinal); - } - - @Override - public int getInt(int ordinal) { - return data.getInt(offset + ordinal); - } - - @Override - public long getLong(int ordinal) { - return data.getLong(offset + ordinal); - } - - @Override - public float getFloat(int ordinal) { - return data.getFloat(offset + ordinal); - } - - @Override - public double getDouble(int ordinal) { - return data.getDouble(offset + ordinal); - } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - return data.getDecimal(offset + ordinal, precision, scale); - } - - @Override - public 
UTF8String getUTF8String(int ordinal) { - return data.getUTF8String(offset + ordinal); - } - - @Override - public byte[] getBinary(int ordinal) { - return data.getBinary(offset + ordinal); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - return data.getInterval(offset + ordinal); - } - - @Override - public ColumnarRow getStruct(int ordinal, int numFields) { - return data.getStruct(offset + ordinal); - } - - @Override - public ColumnarArray getArray(int ordinal) { - return data.getArray(offset + ordinal); - } - - @Override - public ColumnarMap getMap(int ordinal) { - return data.getMap(offset + ordinal); - } - - @Override - public Object get(int ordinal, DataType dataType) { - return SpecializedGettersReader.read(this, ordinal, dataType, true, false); - } - - @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setNullAt(int ordinal) { - throw new UnsupportedOperationException(); + super(data, offset, length); } } diff --git a/shims/spark32/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark32/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java new file mode 100644 index 000000000000..21594a155a8b --- /dev/null +++ b/shims/spark32/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public class ColumnarArrayShim extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + private final ColumnVector data; + private final int offset; + private final int length; + + public ColumnarArrayShim(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int numElements() { + return length; + } + + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. 
+ * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + + @Override + public ArrayData copy() { + DataType dt = data.dataType(); + + if (dt instanceof BooleanType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); + } else if (dt instanceof ByteType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); + } else if (dt instanceof ShortType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); + } else if (dt instanceof IntegerType + || dt instanceof DateType + || dt instanceof YearMonthIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); + } else if (dt instanceof LongType + || dt instanceof TimestampType + || dt instanceof DayTimeIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); + } else if (dt instanceof FloatType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); + } else if (dt instanceof DoubleType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); + } else { + return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
+ } + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(offset + ordinal); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(offset + ordinal); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(offset + ordinal); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(offset + ordinal); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(offset + ordinal); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public 
UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return data.getInterval(offset + ordinal); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); + } + + @Override + public Object get(int ordinal, DataType dataType) { + return SpecializedGettersReader.read(this, ordinal, dataType, true, false); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setNullAt(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/shims/spark33/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark33/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java new file mode 100644 index 000000000000..21594a155a8b --- /dev/null +++ b/shims/spark33/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public class ColumnarArrayShim extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + private final ColumnVector data; + private final int offset; + private final int length; + + public ColumnarArrayShim(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int numElements() { + return length; + } + + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. 
+ * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + + @Override + public ArrayData copy() { + DataType dt = data.dataType(); + + if (dt instanceof BooleanType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); + } else if (dt instanceof ByteType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); + } else if (dt instanceof ShortType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); + } else if (dt instanceof IntegerType + || dt instanceof DateType + || dt instanceof YearMonthIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); + } else if (dt instanceof LongType + || dt instanceof TimestampType + || dt instanceof DayTimeIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); + } else if (dt instanceof FloatType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); + } else if (dt instanceof DoubleType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); + } else { + return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
+ } + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(offset + ordinal); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(offset + ordinal); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(offset + ordinal); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(offset + ordinal); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(offset + ordinal); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public 
UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return data.getInterval(offset + ordinal); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); + } + + @Override + public Object get(int ordinal, DataType dataType) { + return SpecializedGettersReader.read(this, ordinal, dataType, true, false); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setNullAt(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/shims/spark34/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark34/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java new file mode 100644 index 000000000000..21594a155a8b --- /dev/null +++ b/shims/spark34/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public class ColumnarArrayShim extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + private final ColumnVector data; + private final int offset; + private final int length; + + public ColumnarArrayShim(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int numElements() { + return length; + } + + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. 
+ * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + + @Override + public ArrayData copy() { + DataType dt = data.dataType(); + + if (dt instanceof BooleanType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); + } else if (dt instanceof ByteType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); + } else if (dt instanceof ShortType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); + } else if (dt instanceof IntegerType + || dt instanceof DateType + || dt instanceof YearMonthIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); + } else if (dt instanceof LongType + || dt instanceof TimestampType + || dt instanceof DayTimeIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); + } else if (dt instanceof FloatType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); + } else if (dt instanceof DoubleType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); + } else { + return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
+ } + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(offset + ordinal); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(offset + ordinal); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(offset + ordinal); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(offset + ordinal); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(offset + ordinal); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public 
UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return data.getInterval(offset + ordinal); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); + } + + @Override + public Object get(int ordinal, DataType dataType) { + return SpecializedGettersReader.read(this, ordinal, dataType, true, false); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setNullAt(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/shims/spark35/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark35/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java new file mode 100644 index 000000000000..21594a155a8b --- /dev/null +++ b/shims/spark35/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java @@ -0,0 +1,234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +public class ColumnarArrayShim extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + private final ColumnVector data; + private final int offset; + private final int length; + + public ColumnarArrayShim(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int numElements() { + return length; + } + + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. 
+ * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + + @Override + public ArrayData copy() { + DataType dt = data.dataType(); + + if (dt instanceof BooleanType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); + } else if (dt instanceof ByteType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); + } else if (dt instanceof ShortType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); + } else if (dt instanceof IntegerType + || dt instanceof DateType + || dt instanceof YearMonthIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); + } else if (dt instanceof LongType + || dt instanceof TimestampType + || dt instanceof DayTimeIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); + } else if (dt instanceof FloatType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); + } else if (dt instanceof DoubleType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); + } else { + return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
+ } + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(offset + ordinal); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(offset + ordinal); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(offset + ordinal); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(offset + ordinal); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(offset + ordinal); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public 
UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return data.getInterval(offset + ordinal); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); + } + + @Override + public Object get(int ordinal, DataType dataType) { + return SpecializedGettersReader.read(this, ordinal, dataType, true, false); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setNullAt(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/shims/spark40/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark40/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java new file mode 100644 index 000000000000..25adf5d233bb --- /dev/null +++ b/shims/spark40/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.vectorized; + +import org.apache.spark.SparkUnsupportedOperationException; +import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader; +import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData; +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.GenericArrayData; +import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.unsafe.types.VariantVal; + +public class ColumnarArrayShim extends ArrayData { + // The data for this array. This array contains elements from + // data[offset] to data[offset + length). + private final ColumnVector data; + private final int offset; + private final int length; + + public ColumnarArrayShim(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public int numElements() { + return length; + } + + /** + * Sets all the appropriate null bits in the input UnsafeArrayData. 
+ * + * @param arrayData The UnsafeArrayData to set the null bits for + * @return The UnsafeArrayData with the null bits set + */ + private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { + if (data.hasNull()) { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + arrayData.setNullAt(i); + } + } + } + return arrayData; + } + + @Override + public ArrayData copy() { + DataType dt = data.dataType(); + + if (dt instanceof BooleanType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray())); + } else if (dt instanceof ByteType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray())); + } else if (dt instanceof ShortType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray())); + } else if (dt instanceof IntegerType + || dt instanceof DateType + || dt instanceof YearMonthIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray())); + } else if (dt instanceof LongType + || dt instanceof TimestampType + || dt instanceof DayTimeIntervalType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray())); + } else if (dt instanceof FloatType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray())); + } else if (dt instanceof DoubleType) { + return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray())); + } else { + return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied. 
+ } + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. + @Override + public Object[] array() { + DataType dt = data.dataType(); + Object[] list = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (!data.isNullAt(offset + i)) { + list[i] = get(i, dt); + } + } + return list; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(offset + ordinal); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(offset + ordinal); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(offset + ordinal); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(offset + ordinal); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(offset + ordinal); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(offset + ordinal); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(offset + ordinal); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(offset + ordinal); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(offset + ordinal, precision, scale); + } + + @Override + public 
UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(offset + ordinal); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(offset + ordinal); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return data.getInterval(offset + ordinal); + } + + @Override + public VariantVal getVariant(int ordinal) { + return data.getVariant(offset + ordinal); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return data.getStruct(offset + ordinal); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(offset + ordinal); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return data.getMap(offset + ordinal); + } + + @Override + public Object get(int ordinal, DataType dataType) { + return SpecializedGettersReader.read(this, ordinal, dataType, true, false); + } + + @Override + public void update(int ordinal, Object value) { + throw SparkUnsupportedOperationException.apply(); + } + + @Override + public void setNullAt(int ordinal) { + throw SparkUnsupportedOperationException.apply(); + } +}