@@ -26,7 +26,7 @@ import org.apache.paimon.table.FileStoreTable

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.analysis.{NamedRelation, ResolvedTable}
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateStruct, Expression, GetArrayItem, GetStructField, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.{Alias, ArrayTransform, Attribute, CreateNamedStruct, CreateStruct, Expression, GetArrayItem, GetStructField, If, IsNull, LambdaFunction, Literal, NamedExpression, NamedLambdaVariable}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
@@ -206,10 +206,7 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
val sourceField = source(sourceIndex)
castStructField(parent, sourceIndex, sourceField.name, targetField)
}
Alias(CreateStruct(fields), parent.name)(
parent.exprId,
parent.qualifier,
Option(parent.metadata))
structAlias(fields, parent)
}

private def addCastToStructByPosition(
@@ -234,10 +231,19 @@ class PaimonAnalysis(session: SparkSession) extends Rule[LogicalPlan] {
val sourceField = source(i)
castStructField(parent, i, sourceField.name, targetField)
}
Alias(CreateStruct(fields), parent.name)(
parent.exprId,
parent.qualifier,
Option(parent.metadata))
structAlias(fields, parent)
}

private def structAlias(
fields: Seq[NamedExpression],
parent: NamedExpression): NamedExpression = {
val struct = CreateStruct(fields)
val res = if (parent.nullable) {
If(IsNull(parent), Literal(null, struct.dataType), struct)
} else {
struct
}
Alias(res, parent.name)(parent.exprId, parent.qualifier, Option(parent.metadata))
}

private def castStructField(
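
The new structAlias keeps a NULL source struct as NULL instead of rebuilding it as a struct of NULL fields when struct columns are cast on write. A minimal standalone sketch of that behavior, assuming a plain local Spark session; the object name and SQL below are illustrative only and not part of this change:

import org.apache.spark.sql.SparkSession

object NullStructSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()

    // Rebuilding a struct field-by-field loses the top-level NULL: this query
    // returns STRUCT(NULL, NULL), not NULL, for the NULL input row.
    spark.sql(
      "SELECT named_struct('f1', s.f1, 'f2', s.f2) AS rebuilt " +
        "FROM VALUES (CAST(NULL AS STRUCT<f1: INT, f2: INT>)) AS t(s)").show(false)

    // Guarding on IS NULL keeps NULL as NULL, which is what structAlias does at the
    // expression level via If(IsNull(parent), Literal(null, ...), CreateStruct(fields)).
    spark.sql(
      "SELECT IF(s IS NULL, NULL, named_struct('f1', s.f1, 'f2', s.f2)) AS rebuilt " +
        "FROM VALUES (CAST(NULL AS STRUCT<f1: INT, f2: INT>)) AS t(s)").show(false)

    spark.stop()
  }
}

This same guard is why row 2 in the new insert test further below is expected to read back as null rather than Row(null, null).
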
@@ -19,6 +19,8 @@
package org.apache.paimon.spark

import org.apache.paimon.catalog.{Catalog, Identifier}
import org.apache.paimon.fs.FileIO
import org.apache.paimon.fs.local.LocalFileIO
import org.apache.paimon.spark.catalog.WithPaimonCatalog
import org.apache.paimon.spark.extensions.PaimonSparkSessionExtensions
import org.apache.paimon.spark.sql.{SparkVersionSupport, WithTableOptions}
@@ -46,6 +48,8 @@ class PaimonSparkTestBase
with WithTableOptions
with SparkVersionSupport {

protected lazy val fileIO: FileIO = LocalFileIO.create

protected lazy val tempDBDir: File = Utils.createTempDir

protected def paimonCatalog: Catalog = {
@@ -64,6 +68,7 @@ class PaimonSparkTestBase
"org.apache.spark.serializer.JavaSerializer"
}
super.sparkConf
.set("spark.sql.warehouse.dir", tempDBDir.getCanonicalPath)
.set("spark.sql.catalog.paimon", classOf[SparkCatalog].getName)
.set("spark.sql.catalog.paimon.warehouse", tempDBDir.getCanonicalPath)
.set("spark.sql.extensions", classOf[PaimonSparkSessionExtensions].getName)
@@ -152,8 +157,10 @@ class PaimonSparkTestBase

override def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit
pos: Position): Unit = {
println(testName)
super.test(testName, testTags: _*)(testFun)(pos)
super.test(testName, testTags: _*) {
println(testName)
testFun
}(pos)
}

def loadTable(tableName: String): FileStoreTable = {
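
A hypothetical use of the new fileIO helper from a suite extending PaimonSparkTestBase; the table name and assertion are illustrative and not taken from this change:

test("example: table files are visible through fileIO") {
  withTable("t") {
    sql("CREATE TABLE t (id INT)")
    sql("INSERT INTO t VALUES (1)")
    // LocalFileIO points at the local warehouse directory configured above via
    // spark.sql.warehouse.dir / spark.sql.catalog.paimon.warehouse.
    assert(fileIO.exists(loadTable("t").location()))
  }
}
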
@@ -560,4 +560,26 @@ abstract class InsertOverwriteTableTestBase extends PaimonSparkTestBase {
}
checkAnswer(sql("SELECT * FROM T ORDER BY name"), Row("g", null, "Shanghai"))
}

test("Paimon Insert: read and write struct with null") {
fileFormats {
format =>
withTable("t") {
sql(
s"CREATE TABLE t (i INT, s STRUCT<f1: INT, f2: INT>) TBLPROPERTIES ('file.format' = '$format')")
sql(
"INSERT INTO t VALUES (1, STRUCT(1, 1)), (2, null), (3, STRUCT(1, null)), (4, STRUCT(null, null))")
if (format.equals("parquet")) {
// todo: fix it, see https://github.com/apache/paimon/issues/4785
checkAnswer(
sql("SELECT * FROM t ORDER BY i"),
Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), Row(4, null)))
} else {
checkAnswer(
sql("SELECT * FROM t ORDER BY i"),
Seq(Row(1, Row(1, 1)), Row(2, null), Row(3, Row(1, null)), Row(4, Row(null, null))))
}
}
}
}
}
@@ -27,4 +27,5 @@ trait WithTableOptions {

protected val withPk: Seq[Boolean] = Seq(true, false)

protected def fileFormats(fn: String => Unit): Unit = Seq("parquet", "orc", "avro").foreach(fn)
}