---- VariantTest (new file) ----
@@ -0,0 +1,21 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.sql

class VariantTest extends VariantTestBase {}
---- Spark row wrappers: getVariant ----
@@ -35,6 +35,7 @@
import org.apache.paimon.utils.DateTimeUtils;

import org.apache.spark.sql.Row;
+import org.apache.spark.sql.paimon.shims.SparkShimLoader;

import java.io.Serializable;
import java.sql.Date;
@@ -145,8 +146,8 @@ public byte[] getBinary(int i) {
}

@Override
-public Variant getVariant(int pos) {
-    throw new UnsupportedOperationException();
+public Variant getVariant(int i) {
+    return SparkShimLoader.getSparkShim().toPaimonVariant(row.getAs(i));
}

@Override
@@ -307,8 +308,8 @@ public byte[] getBinary(int i) {
}

@Override
-public Variant getVariant(int pos) {
-    throw new UnsupportedOperationException();
+public Variant getVariant(int i) {
+    return SparkShimLoader.getSparkShim().toPaimonVariant(getAs(i));
}

@Override
---- Paimon &lt;-&gt; Spark type conversion ----
@@ -42,7 +42,9 @@
import org.apache.paimon.types.TinyIntType;
import org.apache.paimon.types.VarBinaryType;
import org.apache.paimon.types.VarCharType;
+import org.apache.paimon.types.VariantType;

+import org.apache.spark.sql.paimon.shims.SparkShimLoader;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.LongType;
@@ -217,6 +219,11 @@ public DataType visit(LocalZonedTimestampType localZonedTimestampType) {
return DataTypes.TimestampType;
}

+@Override
+public DataType visit(VariantType variantType) {
+    return SparkShimLoader.getSparkShim().SparkVariantType();
+}

@Override
public DataType visit(ArrayType arrayType) {
org.apache.paimon.types.DataType elementType = arrayType.getElementType();
@@ -381,6 +388,8 @@ public org.apache.paimon.types.DataType atomic(DataType atomic) {
} else if (atomic instanceof org.apache.spark.sql.types.TimestampNTZType) {
// Move TimestampNTZType to the end for compatibility with spark3.3 and below
return new TimestampType();
+} else if (SparkShimLoader.getSparkShim().isSparkVariantType(atomic)) {
+    return new VariantType();
}

throw new UnsupportedOperationException(
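Taken together, the two visitor changes make the Paimon-to-Spark type mapping symmetric for variants. A minimal round-trip sketch, assuming SparkTypeUtils exposes fromPaimonType/toPaimonType entry points over these visitors (those names and the package import are assumptions, not shown in this diff):

import org.apache.paimon.spark.SparkTypeUtils // assumed location of the visitors above
import org.apache.paimon.types.VariantType

object VariantTypeMappingSketch {
  def roundTrip(): Unit = {
    // Paimon -> Spark: the new visit(VariantType) branch delegates to the
    // shim's SparkVariantType(), so this yields Spark's variant type.
    val sparkType = SparkTypeUtils.fromPaimonType(new VariantType())
    // Spark -> Paimon: the new atomic(...) branch recognizes the type via
    // isSparkVariantType and maps it back to Paimon's VariantType.
    val paimonType = SparkTypeUtils.toPaimonType(sparkType)
    assert(paimonType.isInstanceOf[VariantType])
  }
}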
---- SparkShim ----
@@ -18,6 +18,7 @@

package org.apache.spark.sql.paimon.shims

+import org.apache.paimon.data.variant.Variant
import org.apache.paimon.spark.data.{SparkArrayData, SparkInternalRow}
import org.apache.paimon.types.{DataType, RowType}

@@ -33,7 +34,7 @@ import org.apache.spark.sql.types.StructType
import java.util.{Map => JMap}

/**
- * A spark shim trait. It declare methods which have incompatible implementations between Spark 3
+ * A spark shim trait. It declares methods which have incompatible implementations between Spark 3
* and Spark 4. The specific SparkShim implementation will be loaded through Service Provider
* Interface.
*/
@@ -62,4 +63,10 @@ trait SparkShim {

def convertToExpression(spark: SparkSession, column: Column): Expression

+  // for variant
+  def toPaimonVariant(o: Object): Variant
+
+  def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean
+
+  def SparkVariantType(): org.apache.spark.sql.types.DataType
}
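The three new members extend the shim contract that each Spark-version module must implement; which implementation is active is decided at runtime through SPI, as the scaladoc above notes. A minimal sketch of that resolution, assuming SparkShimLoader wraps java.util.ServiceLoader (the loader itself is not part of this diff):

import java.util.ServiceLoader
import scala.jdk.CollectionConverters._

object ShimResolutionSketch {
  // ServiceLoader picks up whichever SparkShim implementation is registered
  // under META-INF/services on the classpath (Spark3Shim or Spark4Shim).
  lazy val shim: SparkShim = {
    val shims = ServiceLoader.load(classOf[SparkShim]).asScala.toList
    require(shims.size == 1, s"expected exactly one SparkShim, found ${shims.size}")
    shims.head
  }
}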
---- VariantTestBase (new file) ----
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.paimon.spark.sql

import org.apache.paimon.spark.PaimonSparkTestBase

import org.apache.spark.sql.Row

abstract class VariantTestBase extends PaimonSparkTestBase {

test("Paimon Variant: read and write variant") {
sql("CREATE TABLE T (id INT, v VARIANT)")
sql("""
|INSERT INTO T VALUES
| (1, parse_json('{"age":26,"city":"Beijing"}')),
| (2, parse_json('{"age":27,"city":"Hangzhou"}'))
| """.stripMargin)

checkAnswer(
sql(
"SELECT id, variant_get(v, '$.age', 'int'), variant_get(v, '$.city', 'string') FROM T ORDER BY id"),
Seq(Row(1, 26, "Beijing"), Row(2, 27, "Hangzhou"))
)
checkAnswer(
sql(
"SELECT variant_get(v, '$.city', 'string') FROM T WHERE variant_get(v, '$.age', 'int') == 26"),
Seq(Row("Beijing"))
)
checkAnswer(
sql("SELECT * FROM T WHERE variant_get(v, '$.age', 'int') == 27"),
sql("""SELECT 2, parse_json('{"age":27,"city":"Hangzhou"}')""")
)
}

test("Paimon Variant: read and write array variant") {
sql("CREATE TABLE T (id INT, v ARRAY<VARIANT>)")
sql(
"""
|INSERT INTO T VALUES
| (1, array(parse_json('{"age":26,"city":"Beijing"}'), parse_json('{"age":27,"city":"Hangzhou"}'))),
| (2, array(parse_json('{"age":27,"city":"Shanghai"}')))
| """.stripMargin)

withSparkSQLConf("spark.sql.ansi.enabled" -> "false") {
checkAnswer(
sql(
"SELECT id, variant_get(v[1], '$.age', 'int'), variant_get(v[0], '$.city', 'string') FROM T ORDER BY id"),
Seq(Row(1, 27, "Beijing"), Row(2, null, "Shanghai"))
)
}
}

test("Paimon Variant: complex json") {
val json =
"""
|{
| "object" : {
| "name" : "Apache Paimon",
| "age" : 2,
| "address" : {
| "street" : "Main St",
| "city" : "Hangzhou"
| }
| },
| "array" : [ 1, 2, 3, 4, 5 ],
| "string" : "Hello, World!",
| "long" : 12345678901234,
| "double" : 1.0123456789012346,
| "decimal" : 100.99,
| "boolean1" : true,
| "boolean2" : false,
| "nullField" : null
|}
|""".stripMargin

sql("CREATE TABLE T (v VARIANT)")
sql(s"""
|INSERT INTO T VALUES parse_json('$json')
| """.stripMargin)

checkAnswer(
sql("""
|SELECT
| variant_get(v, '$.object', 'string'),
| variant_get(v, '$.object.name', 'string'),
| variant_get(v, '$.object.address.street', 'string'),
| variant_get(v, '$["object"]["address"].city', 'string'),
| variant_get(v, '$.array', 'string'),
| variant_get(v, '$.array[0]', 'int'),
| variant_get(v, '$.array[3]', 'int'),
| variant_get(v, '$.string', 'string'),
| variant_get(v, '$.double', 'double'),
| variant_get(v, '$.decimal', 'decimal(5,2)'),
| variant_get(v, '$.boolean1', 'boolean'),
| variant_get(v, '$.boolean2', 'boolean'),
| variant_get(v, '$.nullField', 'string')
|FROM T
|""".stripMargin),
Seq(
Row(
"""{"address":{"city":"Hangzhou","street":"Main St"},"age":2,"name":"Apache Paimon"}""",
"Apache Paimon",
"Main St",
"Hangzhou",
"[1,2,3,4,5]",
1,
4,
"Hello, World!",
1.0123456789012346,
100.99,
true,
false,
null
))
)
}
}
---- Spark3Shim ----
@@ -18,6 +18,7 @@

package org.apache.spark.sql.paimon.shims

+import org.apache.paimon.data.variant.Variant
import org.apache.paimon.spark.catalyst.analysis.Spark3ResolutionRules
import org.apache.paimon.spark.catalyst.parser.extensions.PaimonSpark3SqlExtensionsParser
import org.apache.paimon.spark.data.{Spark3ArrayData, Spark3InternalRow, SparkArrayData, SparkInternalRow}
@@ -71,4 +72,11 @@ class Spark3Shim extends SparkShim {

override def convertToExpression(spark: SparkSession, column: Column): Expression = column.expr

+  override def toPaimonVariant(o: Object): Variant = throw new UnsupportedOperationException()
+
+  override def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean =
+    throw new UnsupportedOperationException()
+
+  override def SparkVariantType(): org.apache.spark.sql.types.DataType =
+    throw new UnsupportedOperationException()
}
---- Spark4ArrayData ----
@@ -24,6 +24,8 @@ import org.apache.spark.unsafe.types.VariantVal

class Spark4ArrayData(override val elementType: DataType) extends AbstractSparkArrayData {

-  override def getVariant(ordinal: Int): VariantVal = throw new UnsupportedOperationException
+  override def getVariant(ordinal: Int): VariantVal = {
+    val v = paimonArray.getVariant(ordinal)
+    new VariantVal(v.value(), v.metadata())
+  }
}
---- Spark4InternalRow ----
@@ -24,5 +24,9 @@ import org.apache.paimon.types.RowType
import org.apache.spark.unsafe.types.VariantVal

class Spark4InternalRow(rowType: RowType) extends AbstractSparkInternalRow(rowType) {
-  override def getVariant(i: Int): VariantVal = throw new UnsupportedOperationException
+  override def getVariant(i: Int): VariantVal = {
+    val v = row.getVariant(i)
+    new VariantVal(v.value(), v.metadata())
+  }
}
---- Spark4Shim ----
@@ -18,6 +18,7 @@

package org.apache.spark.sql.paimon.shims

+import org.apache.paimon.data.variant.{GenericVariant, Variant}
import org.apache.paimon.spark.catalyst.analysis.Spark4ResolutionRules
import org.apache.paimon.spark.catalyst.parser.extensions.PaimonSpark4SqlExtensionsParser
import org.apache.paimon.spark.data.{Spark4ArrayData, Spark4InternalRow, SparkArrayData, SparkInternalRow}
@@ -31,7 +32,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, Table, TableCatalog}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.internal.ExpressionUtils
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataTypes, StructType, VariantType}
+import org.apache.spark.unsafe.types.VariantVal

import java.util.{Map => JMap}

@@ -73,4 +75,14 @@

def convertToExpression(spark: SparkSession, column: Column): Expression =
spark.expression(column)

+  override def toPaimonVariant(o: Object): Variant = {
+    val v = o.asInstanceOf[VariantVal]
+    new GenericVariant(v.getValue, v.getMetadata)
+  }
+
+  override def isSparkVariantType(dataType: org.apache.spark.sql.types.DataType): Boolean =
+    dataType.isInstanceOf[VariantType]
+
+  override def SparkVariantType(): org.apache.spark.sql.types.DataType = DataTypes.VariantType
}
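Both representations carry the same pair of binary fields, so each direction of the Spark 4 conversion is a plain byte-array copy. Restated as standalone helpers (object and method names are illustrative, not part of the diff):

import org.apache.paimon.data.variant.{GenericVariant, Variant}
import org.apache.spark.unsafe.types.VariantVal

object VariantConversionSketch {
  // Spark -> Paimon, mirroring Spark4Shim.toPaimonVariant above.
  def toPaimon(v: VariantVal): Variant = new GenericVariant(v.getValue, v.getMetadata)

  // Paimon -> Spark, mirroring Spark4InternalRow.getVariant / Spark4ArrayData.getVariant.
  def toSpark(v: Variant): VariantVal = new VariantVal(v.value(), v.metadata())
}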