From e0fa118ec38e4751077e5ce51e5aa63fe8e7d5a9 Mon Sep 17 00:00:00 2001
From: Uros Bojanic
Date: Fri, 31 Oct 2025 09:29:02 +0100
Subject: [PATCH 1/4] Initial commit

---
 .../resources/error/error-conditions.json    |   6 +
 .../scala/org/apache/spark/sql/Encoders.scala |  14 ++
 .../main/scala/org/apache/spark/sql/Row.scala |  16 ++
 .../org/apache/spark/sql/SQLImplicits.scala  |   8 +
 .../sql/catalyst/JavaTypeInference.scala     |   6 +-
 .../spark/sql/catalyst/ScalaReflection.scala |   4 +
 .../catalyst/encoders/AgnosticEncoder.scala  |   6 +
 .../sql/catalyst/encoders/RowEncoder.scala   |   4 +-
 .../spark/sql/types/GeographyType.scala      |  31 ++-
 .../apache/spark/sql/types/GeometryType.scala |  31 ++-
 .../spark/sql/catalyst/util/STUtils.java     |  32 ++++
 .../sql/catalyst/CatalystTypeConverters.scala |  42 +++-
 .../catalyst/DeserializerBuildHelper.scala   |  26 ++-
 .../sql/catalyst/SerializerBuildHelper.scala |  24 ++-
 .../spark/sql/GeographyDataFrameSuite.scala  | 180 +++++++++++++++++
 .../spark/sql/GeometryDataFrameSuite.scala   | 181 ++++++++++++++++++
 .../scala/org/apache/spark/sql/RowSuite.scala |  11 ++
 17 files changed, 613 insertions(+), 9 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GeographyDataFrameSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/GeometryDataFrameSuite.scala

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 2f8bda18ca3ac..db776768fce43 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1881,6 +1881,12 @@
     ],
     "sqlState" : "42623"
   },
+  "GEO_ENCODER_SRID_MISMATCH_ERROR" : {
+    "message" : [
+      "Failed to encode <type> value because provided SRID of a value to encode <valueSrid> does not match type SRID: <typeSrid>."
+    ],
+    "sqlState" : "42K09"
+  },
   "GET_TABLES_BY_TYPE_UNSUPPORTED_BY_HIVE_VERSION" : {
     "message" : [
       "Hive 2.2 and lower versions don't support getTablesByType. Please use Hive 2.3 or higher version."
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index cb1402e1b0f4a..7e698e58321ee 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -162,6 +162,20 @@ object Encoders {
    */
   def BINARY: Encoder[Array[Byte]] = BinaryEncoder
 
+  /**
+   * An encoder for Geometry data type.
+   *
+   * @since 4.1.0
+   */
+  def GEOMETRY(dt: GeometryType): Encoder[Geometry] = GeometryEncoder(dt)
+
+  /**
+   * An encoder for Geography data type.
+   *
+   * @since 4.1.0
+   */
+  def GEOGRAPHY(dt: GeographyType): Encoder[Geography] = GeographyEncoder(dt)
+
   /**
    * Creates an encoder that serializes instances of the `java.time.Duration` class to the
    * internal representation of nullable Catalyst's DayTimeIntervalType.
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
index 764bdb17b37e2..c99b9ef34f86a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
@@ -302,6 +302,22 @@ trait Row extends Serializable {
    */
   def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)
 
+  /**
+   * Returns the value at position i of geometry type as org.apache.spark.sql.types.Geometry.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getGeometry(i: Int): org.apache.spark.sql.types.Geometry =
+    getAs[org.apache.spark.sql.types.Geometry](i)
+
+  /**
+   * Returns the value at position i of geography type as org.apache.spark.sql.types.Geography.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getGeography(i: Int): org.apache.spark.sql.types.Geography =
+    getAs[org.apache.spark.sql.types.Geography](i)
+
   /**
    * Returns the value at position i of date type as java.sql.Date.
    *
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
index a5b1060ca03db..9d64225b96633 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -104,6 +104,14 @@ trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable {
   implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] =
     DEFAULT_SCALA_DECIMAL_ENCODER
 
+  /** @since 4.1.0 */
+  implicit def newGeometryEncoder: Encoder[org.apache.spark.sql.types.Geometry] =
+    DEFAULT_GEOMETRY_ENCODER
+
+  /** @since 4.1.0 */
+  implicit def newGeographyEncoder: Encoder[org.apache.spark.sql.types.Geography] =
+    DEFAULT_GEOGRAPHY_ENCODER
+
   /** @since 2.2.0 */
   implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 906e6419b3607..91947cf416fb6 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils}
 
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_GEOGRAPHY_ENCODER, DEFAULT_GEOMETRY_ENCODER, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.types._
 import org.apache.spark.util.ArrayImplicits._
@@ -86,6 +86,10 @@ object JavaTypeInference {
     case c: Class[_] if c == classOf[java.lang.String] => StringEncoder
     case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder
+    case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geometry] =>
+      DEFAULT_GEOMETRY_ENCODER
+    case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geography] =>
+      DEFAULT_GEOGRAPHY_ENCODER
     case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
     case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
     case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d2e0053597e4f..6f5c4be42bbd4 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -332,6 +332,10 @@ object ScalaReflection extends ScalaReflection {
     case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
     case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => LocalTimeEncoder
     case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
+    case t if isSubtype(t, localTypeOf[Geography]) =>
+      DEFAULT_GEOGRAPHY_ENCODER
+    case t if isSubtype(t, localTypeOf[Geometry]) =>
+      DEFAULT_GEOMETRY_ENCODER
     case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder
 
     // UDT encoders
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 0c5295176608f..20949c188cb81 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -246,6 +246,8 @@ object AgnosticEncoders {
   case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
   case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
   case object VariantEncoder extends LeafEncoder[VariantVal](VariantType)
+  case class GeographyEncoder(dt: GeographyType) extends LeafEncoder[Geography](dt)
+  case class GeometryEncoder(dt: GeometryType) extends LeafEncoder[Geometry](dt)
   case class DateEncoder(override val lenientSerialization: Boolean)
       extends LeafEncoder[jsql.Date](DateType)
   case class LocalDateEncoder(override val lenientSerialization: Boolean)
@@ -277,6 +279,10 @@ object AgnosticEncoders {
     ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
   val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
     JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
+  val DEFAULT_GEOMETRY_ENCODER: GeometryEncoder =
+    GeometryEncoder(GeometryType(Geometry.DEFAULT_SRID))
+  val DEFAULT_GEOGRAPHY_ENCODER: GeographyEncoder =
+    GeographyEncoder(GeographyType(Geography.DEFAULT_SRID))
 
   /**
    * Encoder that transforms external data into a representation that can be further processed by
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 620278c66d21d..73152017cf225 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 import scala.reflect.classTag
 
 import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.errors.DataTypeErrorsBase
 import org.apache.spark.sql.internal.SqlApiConf
 import org.apache.spark.sql.types._
@@ -120,6 +120,8 @@
             field.nullable,
             field.metadata)
         }.toImmutableArraySeq)
+    case g: GeographyType => GeographyEncoder(g)
+    case g: GeometryType => GeometryEncoder(g)
     case _ =>
       throw new AnalysisException(
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
index 638ae79351846..03618c6ddd61a 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.types
 
 import org.json4s.JsonAST.{JString, JValue}
 
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.internal.types.GeographicSpatialReferenceSystemMapper
 
@@ -133,6 +133,30 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
     // If the SRID is not mixed, we can only accept the same SRID.
     isMixedSrid || gt.srid == srid
   }
+
+  def assertSridAllowedForType(otherSrid: Int): Unit = {
+    // If SRID is not mixed, SRIDs must match.
+    if (!isMixedSrid && otherSrid != srid) {
+      throw new SparkRuntimeException(
+        errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+        messageParameters = Map(
+          "type"-> "GEOGRAPHY",
+          "valueSrid" -> otherSrid.toString,
+          "typeSrid" -> srid.toString,
+        )
+      )
+    } else if (isMixedSrid) {
+      // For fixed SRID geom types, we have a check that value matches the type srid.
+      // For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
+      // However it should accept only valid SRIDs.
+      if (!GeographyType.isSridSupported(otherSrid)) {
+        throw new SparkIllegalArgumentException(
+          errorClass = "ST_INVALID_SRID_VALUE",
+          messageParameters = Map("srid" -> otherSrid.toString)
+        )
+      }
+    }
+  }
 }
 
 @Experimental
@@ -157,6 +181,11 @@ object GeographyType extends SpatialType {
   private final val GEOGRAPHY_MIXED_TYPE: GeographyType =
     GeographyType(MIXED_CRS, GEOGRAPHY_DEFAULT_ALGORITHM)
 
+  /** Returns whether the given SRID is supported. */
+  def isSridSupported(srid: Int): Boolean = {
+    GeographicSpatialReferenceSystemMapper.getStringId(srid) != null
+  }
+
   /**
    * Constructors for GeographyType.
    */
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
index 77a6b365c042a..f4c4b8503c14b 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.types
 
 import org.json4s.JsonAST.{JString, JValue}
 
-import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.sql.internal.types.CartesianSpatialReferenceSystemMapper
 
@@ -130,6 +130,30 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
     // If the SRID is not mixed, we can only accept the same SRID.
     isMixedSrid || gt.srid == srid
   }
+
+  def assertSridAllowedForType(otherSrid: Int): Unit = {
+    // If SRID is not mixed, SRIDs must match.
+    if (!isMixedSrid && otherSrid != srid) {
+      throw new SparkRuntimeException(
+        errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+        messageParameters = Map(
+          "type"-> "GEOMETRY",
+          "valueSrid" -> otherSrid.toString,
+          "typeSrid" -> srid.toString,
+        )
+      )
+    } else if (isMixedSrid) {
+      // For fixed SRID geom types, we have a check that value matches the type srid.
+      // For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
+      // However it should accept only valid SRIDs.
+      if (!GeometryType.isSridSupported(otherSrid)) {
+        throw new SparkIllegalArgumentException(
+          errorClass = "ST_INVALID_SRID_VALUE",
+          messageParameters = Map("srid" -> otherSrid.toString)
+        )
+      }
+    }
+  }
 }
 
 @Experimental
@@ -149,6 +173,11 @@ object GeometryType extends SpatialType {
   private final val GEOMETRY_MIXED_TYPE: GeometryType =
     GeometryType(MIXED_CRS)
 
+  /** Returns whether the given SRID is supported. */
+  def isSridSupported(srid: Int): Boolean = {
+    CartesianSpatialReferenceSystemMapper.getStringId(srid) != null
+  }
+
   /**
    * Constructors for GeometryType.
    */
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
index aca3fdf1f1000..026cf7dfa304c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
@@ -16,6 +16,8 @@
  */
 package org.apache.spark.sql.catalyst.util;
 
+import org.apache.spark.sql.types.GeographyType;
+import org.apache.spark.sql.types.GeometryType;
 import org.apache.spark.unsafe.types.GeographyVal;
 import org.apache.spark.unsafe.types.GeometryVal;
 
@@ -46,6 +48,36 @@ static GeometryVal toPhysVal(Geometry g) {
     return g.getValue();
   }
 
+  /** Geospatial type encoder/decoder utilities. */
+
+  public static GeometryVal serializeGeomFromWKB(org.apache.spark.sql.types.Geometry geometry, GeometryType gt) {
+    int geometrySrid = geometry.getSrid();
+    gt.assertSridAllowedForType(geometrySrid);
+    return toPhysVal(Geometry.fromWkb(geometry.getBytes(), geometrySrid));
+  }
+
+  public static GeographyVal serializeGeogFromWKB(org.apache.spark.sql.types.Geography geography, GeographyType gt) {
+    int geographySrid = geography.getSrid();
+    gt.assertSridAllowedForType(geographySrid);
+    return toPhysVal(Geography.fromWkb(geography.getBytes(), geographySrid));
+  }
+
+  public static org.apache.spark.sql.types.Geometry deserializeGeom(
+      GeometryVal geometry, GeometryType gt) {
+    int geometrySrid = stSrid(geometry);
+    gt.assertSridAllowedForType(geometrySrid);
+    byte[] wkb = stAsBinary(geometry);
+    return org.apache.spark.sql.types.Geometry.fromWKB(wkb, geometrySrid);
+  }
+
+  public static org.apache.spark.sql.types.Geography deserializeGeog(
+      GeographyVal geography, GeographyType gt) {
+    int geographySrid = stSrid(geography);
+    gt.assertSridAllowedForType(geographySrid);
+    byte[] wkb = stAsBinary(geography);
+    return org.apache.spark.sql.types.Geography.fromWKB(wkb, geographySrid);
+  }
+
   /** Methods for implementing ST expressions. */
 
   // ST_AsBinary
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index c1e0674d391d2..b8eee5e1c7c6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.types.DayTimeIntervalType._
 import org.apache.spark.sql.types.YearMonthIntervalType._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String}
 import org.apache.spark.util.ArrayImplicits._
 import org.apache.spark.util.collection.Utils
 
@@ -69,6 +69,10 @@ object CatalystTypeConverters {
     case CharType(length) => new CharConverter(length)
     case VarcharType(length) => new VarcharConverter(length)
     case _: StringType => StringConverter
+    case g: GeographyType =>
+      new GeographyConverter(g)
+    case g: GeometryType =>
+      new GeometryConverter(g)
     case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
     case DateType => DateConverter
     case _: TimeType => TimeConverter
@@ -345,6 +349,42 @@
       row.getUTF8String(column).toString
   }
 
+  private class GeometryConverter(dataType: GeometryType)
+    extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geometry, GeometryVal] {
+    override def toCatalystImpl(scalaValue: Any): GeometryVal = scalaValue match {
+      case g: org.apache.spark.sql.types.Geometry => STUtils.serializeGeomFromWKB(g, dataType)
+      case other => throw new SparkIllegalArgumentException(
+        errorClass = "_LEGACY_ERROR_TEMP_3219",
+        messageParameters = scala.collection.immutable.Map(
+          "other" -> other.toString,
+          "otherClass" -> other.getClass.getCanonicalName,
+          "dataType" -> dataType.sql))
+    }
+    override def toScala(catalystValue: GeometryVal): org.apache.spark.sql.types.Geometry =
+      if (catalystValue == null) null
+      else STUtils.deserializeGeom(catalystValue, dataType)
+    override def toScalaImpl(row: InternalRow, column: Int): org.apache.spark.sql.types.Geometry =
+      STUtils.deserializeGeom(row.getGeometry(column), dataType)
+  }
+
+  private class GeographyConverter(dataType: GeographyType)
+    extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geography, GeographyVal] {
+    override def toCatalystImpl(scalaValue: Any): GeographyVal = scalaValue match {
+      case g: org.apache.spark.sql.types.Geography => STUtils.serializeGeogFromWKB(g, dataType)
+      case other => throw new SparkIllegalArgumentException(
+        errorClass = "_LEGACY_ERROR_TEMP_3219",
+        messageParameters = scala.collection.immutable.Map(
+          "other" -> other.toString,
+          "otherClass" -> other.getClass.getCanonicalName,
+          "dataType" -> dataType.sql))
+    }
+    override def toScala(catalystValue: GeographyVal): org.apache.spark.sql.types.Geography =
+      if (catalystValue == null) null
+      else STUtils.deserializeGeog(catalystValue, dataType)
+    override def toScalaImpl(row: InternalRow, column: Int): org.apache.spark.sql.types.Geography =
+      STUtils.deserializeGeog(row.getGeography(column), dataType)
+  }
+
   private object DateConverter extends CatalystTypeConverter[Any, Date, Any] {
     override def toCatalystImpl(scalaValue: Any): Int = scalaValue match {
       case d: Date => DateTimeUtils.fromJavaDate(d)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index a051205829a11..60de179edb799 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{dataTypeForClass, externalDataTypeFor, isNativeEncoder}
 import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, IsNull, Literal, MapKeys, MapValues, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke, NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap, UnresolvedMapObjects, WrapOption}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils, STUtils}
 import org.apache.spark.sql.types._
 
 object DeserializerBuildHelper {
@@ -80,6 +80,24 @@
       returnNullable = false)
   }
 
+  def createDeserializerForGeometryType(inputObject: Expression, gt: GeometryType): Expression = {
+    StaticInvoke(
+      classOf[STUtils],
+      ObjectType(classOf[Geometry]),
+      "deserializeGeom",
+      inputObject :: Literal.fromObject(gt) :: Nil,
+      returnNullable = false)
+  }
+
+  def createDeserializerForGeographyType(inputObject: Expression, gt: GeographyType): Expression = {
+    StaticInvoke(
+      classOf[STUtils],
+      ObjectType(classOf[Geography]),
+      "deserializeGeog",
+      inputObject :: Literal.fromObject(gt) :: Nil,
+      returnNullable = false)
+  }
+
   def createDeserializerForChar(
       path: Expression,
       returnNullable: Boolean,
@@ -290,6 +308,10 @@
         "withName",
         createDeserializerForString(path, returnNullable = false) :: Nil,
         returnNullable = false)
+    case g: GeographyEncoder =>
+      createDeserializerForGeographyType(path, g.dt)
+    case g: GeometryEncoder =>
+      createDeserializerForGeometryType(path, g.dt)
     case CharEncoder(length) =>
       createDeserializerForChar(path, returnNullable = false, length)
     case VarcharEncoder(length) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index 82b3cdc508bf9..06267bca02189 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -22,11 +22,11 @@ import scala.language.existentials
 
 import org.apache.spark.sql.catalyst.{expressions => exprs}
 import org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
 import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, AgnosticEncoders, AgnosticExpressionPathEncoder, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncoder, ScalaDecimalEncoder, ScalaEnumEncoder, StringEncoder, TimestampEncoder, TransformingEncoder, UDTEncoder, VarcharEncoder, YearMonthIntervalEncoder}
 import org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor, isNativeEncoder, lenientExternalDataTypeFor}
 import org.apache.spark.sql.catalyst.expressions.{BoundReference, CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal, UnsafeArrayData}
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateTimeUtils, GenericArrayData, IntervalUtils, STUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -63,6 +63,24 @@
     Invoke(inputObject, "doubleValue", DoubleType)
   }
 
+  def createSerializerForGeographyType(inputObject: Expression, gt: GeographyType): Expression = {
+    StaticInvoke(
+      classOf[STUtils],
+      gt,
+      "serializeGeogFromWKB",
+      inputObject :: Literal.fromObject(gt) :: Nil,
+      returnNullable = false)
+  }
+
+  def createSerializerForGeometryType(inputObject: Expression, gt: GeometryType): Expression = {
+    StaticInvoke(
+      classOf[STUtils],
+      gt,
+      "serializeGeomFromWKB",
+      inputObject :: Literal.fromObject(gt) :: Nil,
+      returnNullable = false)
+  }
+
   def createSerializerForChar(inputObject: Expression, length: Int): Expression = {
     StaticInvoke(
       classOf[CharVarcharCodegenUtils],
@@ -326,6 +344,8 @@
     case BoxedDoubleEncoder => createSerializerForDouble(input)
     case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
     case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
+    case g: GeographyEncoder => createSerializerForGeographyType(input, g.dt)
+    case g: GeometryEncoder => createSerializerForGeometryType(input, g.dt)
     case CharEncoder(length) => createSerializerForChar(input, length)
     case VarcharEncoder(length) => createSerializerForVarchar(input, length)
     case StringEncoder => createSerializerForString(input)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeographyDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeographyDataFrameSuite.scala
new file mode 100644
index 0000000000000..eeb1ba5ea9e25
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeographyDataFrameSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class GeographyDataFrameSuite extends QueryTest with SharedSparkSession {
+
+  val point1 = "010100000000000000000031400000000000001C40"
+    .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+  val point2 = "010100000000000000000035400000000000001E40"
+    .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+  test("decode geography value: SRID schema does not match input SRID data schema") {
+    val rdd = sparkContext.parallelize(Seq(Row(Geography.fromWKB(point1, 0))))
+    val schema = StructType(Seq(StructField("col1", GeographyType(4326), nullable = false)))
+    checkError(
+      // We look for cause, as all exception encoder errors are wrapped in
+      // EXPRESSION_ENCODING_FAILED.
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(rdd, schema).collect()
+      }.getCause.asInstanceOf[SparkRuntimeException],
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+    )
+
+    val javaRDD = sparkContext.parallelize(Seq(Row(Geography.fromWKB(point1, 0)))).toJavaRDD()
+    checkError(
+      // We look for cause, as all exception encoder errors are wrapped in
+      // EXPRESSION_ENCODING_FAILED.
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(javaRDD, schema).collect()
+      }.getCause.asInstanceOf[SparkRuntimeException],
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+    )
+
+    // For some reason this API does not use expression encoders,
+    // but CatalystTypeConverter, so we are not looking at cause.
+    val javaList = java.util.Arrays.asList(Row(Geography.fromWKB(point1, 0)))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(javaList, schema).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+    )
+
+    val geography1 = Geography.fromWKB(point1, 0)
+    val rdd2 = sparkContext.parallelize(Seq((geography1, 1)))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(rdd2).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+    )
+
+    // For some reason this API does not use expression encoders,
+    // but CatalystTypeConverter, so we are not looking at cause.
+    val seq = Seq((geography1, 1))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(seq).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+    )
+
+    import testImplicits._
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        Seq(geography1).toDF().collect()
+      }.getCause.asInstanceOf[SparkRuntimeException],
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOGRAPHY", "valueSrid" -> "0", "typeSrid" -> "4326")
+    )
+  }
+
+  test("decode geography value: mixed SRID schema is provided") {
+    val rdd = sparkContext.parallelize(
+      Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326))))
+    val schema = StructType(Seq(StructField("col1", GeographyType("ANY"), nullable = false)))
+    val expectedResult = Seq(
+      Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+
+    val resultDF = spark.createDataFrame(rdd, schema)
+    checkAnswer(resultDF, expectedResult)
+
+    val javaRDD = sparkContext.parallelize(
+      Seq(Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))).toJavaRDD()
+    val resultJavaDF = spark.createDataFrame(javaRDD, schema)
+    checkAnswer(resultJavaDF, expectedResult)
+
+    val javaList = java.util.Arrays.asList(
+      Row(Geography.fromWKB(point1, 4326)), Row(Geography.fromWKB(point2, 4326)))
+    val resultJavaListDF = spark.createDataFrame(javaList, schema)
+    checkAnswer(resultJavaListDF, expectedResult)
+
+    // Test that unsupported SRID with mixed schema will throw an error.
+    val rdd2 = sparkContext.parallelize(
+      Seq(Row(Geography.fromWKB(point1, 0)), Row(Geography.fromWKB(point2, 4326))))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(rdd2, schema).collect()
+      }.getCause.asInstanceOf[SparkIllegalArgumentException],
+      condition = "ST_INVALID_SRID_VALUE",
+      parameters = Map("srid" -> "0")
+    )
+  }
+
+  test("createDataFrame APIs with Geography.fromWKB") {
+    // 1. Test createDataFrame with RDD of Geography objects
+    val geography1 = Geography.fromWKB(point1, 4326)
+    val geography2 = Geography.fromWKB(point2)
+    val rdd = sparkContext.parallelize(Seq((geography1, 1), (geography2, 2), (null, 3)))
+    val dfFromRDD = spark.createDataFrame(rdd)
+    checkAnswer(dfFromRDD, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))
+
+    // 2. Test createDataFrame with Seq of Geography objects
+    val seq = Seq((geography1, 1), (geography2, 2), (null, 3))
+    val dfFromSeq = spark.createDataFrame(seq)
+    checkAnswer(dfFromSeq, Seq(Row(geography1, 1), Row(geography2, 2), Row(null, 3)))
+
+    // 3. Test createDataFrame with RDD of Rows and StructType schema
+    val geography3 = Geography.fromWKB(point1, 4326)
+    val geography4 = Geography.fromWKB(point2, 4326)
+    val rowRDD = sparkContext.parallelize(Seq(Row(geography3), Row(geography4), Row(null)))
+    val schema = StructType(Seq(
+      StructField("geography", GeographyType(4326), nullable = true)
+    ))
+    val dfFromRowRDD = spark.createDataFrame(rowRDD, schema)
+    checkAnswer(dfFromRowRDD, Seq(Row(geography3), Row(geography4), Row(null)))
+
+    // 4. Test createDataFrame with JavaRDD of Rows and StructType schema
+    val javaRDD = sparkContext.parallelize(Seq(Row(geography3), Row(geography4), Row(null)))
+      .toJavaRDD()
+    val dfFromJavaRDD = spark.createDataFrame(javaRDD, schema)
+    checkAnswer(dfFromJavaRDD, Seq(Row(geography3), Row(geography4), Row(null)))
+
+    // 5. Test createDataFrame with Java List of Rows and StructType schema
+    val javaList = java.util.Arrays.asList(Row(geography3), Row(geography4), Row(null))
+    val dfFromJavaList = spark.createDataFrame(javaList, schema)
+    checkAnswer(dfFromJavaList, Seq(Row(geography3), Row(geography4), Row(null)))
+
+    // 6. Implicit conversion from Seq to DF
+    import testImplicits._
+    val implicitDf = Seq(geography1, geography2, null).toDF()
+    checkAnswer(implicitDf, Seq(Row(geography1), Row(geography2), Row(null)))
+  }
+
+  test("encode geography type") {
+    // A test WKB value corresponding to: POINT (17 7).
+    val pointString: String = "010100000000000000000031400000000000001C40"
+    val pointBytes: Array[Byte] = pointString
+      .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val df = spark.sql(s"SELECT ST_GeogFromWKB(X'$pointString')")
+    val expectedGeog = Geography.fromWKB(pointBytes, 4326)
+    checkAnswer(df, Seq(Row(expectedGeog)))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeometryDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeometryDataFrameSuite.scala
new file mode 100644
index 0000000000000..bcc3cee7ebe39
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeometryDataFrameSuite.scala
@@ -0,0 +1,181 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.immutable.Seq
+
+import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class GeometryDataFrameSuite extends QueryTest with SharedSparkSession {
+
+  val point1 = "010100000000000000000031400000000000001C40"
+    .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+  val point2 = "010100000000000000000035400000000000001E40"
+    .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+  test("decode geometry value: SRID schema does not match input SRID data schema") {
+    val rdd = sparkContext.parallelize(Seq(Row(Geometry.fromWKB(point1, 0))))
+    val schema = StructType(Seq(StructField("col1", GeometryType(3857), nullable = false)))
+    checkError(
+      // We look for cause, as all exception encoder errors are wrapped in
+      // EXPRESSION_ENCODING_FAILED.
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(rdd, schema).collect()
+      }.getCause.asInstanceOf[SparkRuntimeException],
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "0", "typeSrid" -> "3857")
+    )
+
+    val schema2 = StructType(Seq(StructField("col1", GeometryType(0), nullable = false)))
+    val javaRDD = sparkContext.parallelize(Seq(Row(Geometry.fromWKB(point1, 4326)))).toJavaRDD()
+    checkError(
+      // We look for cause, as all exception encoder errors are wrapped in
+      // EXPRESSION_ENCODING_FAILED.
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(javaRDD, schema2).collect()
+      }.getCause.asInstanceOf[SparkRuntimeException],
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+    )
+
+    // For some reason this API does not use expression encoders,
+    // but CatalystTypeConverter, so we are not looking at cause.
+    val javaList = java.util.Arrays.asList(Row(Geometry.fromWKB(point1, 4326)))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(javaList, schema).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "3857")
+    )
+
+    val geometry1 = Geometry.fromWKB(point1, 4326)
+    val rdd2 = sparkContext.parallelize(Seq((geometry1, 1)))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(rdd2).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+    )
+
+    // For some reason this API does not use expression encoders,
+    // but CatalystTypeConverter, so we are not looking at cause.
+    val seq = Seq((geometry1, 1))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(seq).collect()
+      },
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+    )
+
+    import testImplicits._
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        Seq(geometry1).toDF().collect()
+      }.getCause.asInstanceOf[SparkRuntimeException],
+      condition = "GEO_ENCODER_SRID_MISMATCH_ERROR",
+      parameters = Map("type" -> "GEOMETRY", "valueSrid" -> "4326", "typeSrid" -> "0")
+    )
+  }
+
+  test("decode geometry value: mixed SRID schema is provided") {
+    val rdd = sparkContext.parallelize(
+      Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326))))
+    val schema = StructType(Seq(StructField("col1", GeometryType("ANY"), nullable = false)))
+    val expectedResult = Seq(
+      Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+
+    val resultDF = spark.createDataFrame(rdd, schema)
+    checkAnswer(resultDF, expectedResult)
+
+    val javaRDD = sparkContext.parallelize(
+      Seq(Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))).toJavaRDD()
+    val resultJavaDF = spark.createDataFrame(javaRDD, schema)
+    checkAnswer(resultJavaDF, expectedResult)
+
+    val javaList = java.util.Arrays.asList(
+      Row(Geometry.fromWKB(point1, 0)), Row(Geometry.fromWKB(point2, 4326)))
+    val resultJavaListDF = spark.createDataFrame(javaList, schema)
+    checkAnswer(resultJavaListDF, expectedResult)
+
+    // Test that unsupported SRID with mixed schema will throw an error.
+    val rdd2 = sparkContext.parallelize(
+      Seq(Row(Geometry.fromWKB(point1, 1)), Row(Geometry.fromWKB(point2, 4326))))
+    checkError(
+      exception = intercept[SparkRuntimeException] {
+        spark.createDataFrame(rdd2, schema).collect()
+      }.getCause.asInstanceOf[SparkIllegalArgumentException],
+      condition = "ST_INVALID_SRID_VALUE",
+      parameters = Map("srid" -> "1")
+    )
+  }
+
+  test("createDataFrame APIs with Geometry.fromWKB") {
+    // 1. Test createDataFrame with RDD of Geometry objects
+    val geometry1 = Geometry.fromWKB(point1, 0)
+    val geometry2 = Geometry.fromWKB(point2, 0)
+    val rdd = sparkContext.parallelize(Seq((geometry1, 1), (geometry2, 2), (null, 3)))
+    val dfFromRDD = spark.createDataFrame(rdd)
+    checkAnswer(dfFromRDD, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))
+
+    // 2. Test createDataFrame with Seq of Geometry objects
+    val seq = Seq((geometry1, 1), (geometry2, 2), (null, 3))
+    val dfFromSeq = spark.createDataFrame(seq)
+    checkAnswer(dfFromSeq, Seq(Row(geometry1, 1), Row(geometry2, 2), Row(null, 3)))
+
+    // 3. Test createDataFrame with RDD of Rows and StructType schema
+    val geometry3 = Geometry.fromWKB(point1, 4326)
+    val geometry4 = Geometry.fromWKB(point2, 4326)
+    val rowRDD = sparkContext.parallelize(Seq(Row(geometry3), Row(geometry4), Row(null)))
+    val schema = StructType(Seq(
+      StructField("geometry", GeometryType(4326), nullable = true)
+    ))
+    val dfFromRowRDD = spark.createDataFrame(rowRDD, schema)
+    checkAnswer(dfFromRowRDD, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+    // 4. Test createDataFrame with JavaRDD of Rows and StructType schema
+    val javaRDD = sparkContext.parallelize(Seq(Row(geometry3), Row(geometry4), Row(null)))
+      .toJavaRDD()
+    val dfFromJavaRDD = spark.createDataFrame(javaRDD, schema)
+    checkAnswer(dfFromJavaRDD, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+    // 5. Test createDataFrame with Java List of Rows and StructType schema
+    val javaList = java.util.Arrays.asList(Row(geometry3), Row(geometry4), Row(null))
+    val dfFromJavaList = spark.createDataFrame(javaList, schema)
+    checkAnswer(dfFromJavaList, Seq(Row(geometry3), Row(geometry4), Row(null)))
+
+    // 6. Implicit conversion from Seq to DF
+    import testImplicits._
+    val implicitDf = Seq(geometry1, geometry2, null).toDF()
+    checkAnswer(implicitDf, Seq(Row(geometry1), Row(geometry2), Row(null)))
+  }
+
+  test("encode geometry type") {
+    // A test WKB value corresponding to: POINT (17 7).
+    val pointString: String = "010100000000000000000031400000000000001C40"
+    val pointBytes: Array[Byte] = pointString
+      .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+    val df = spark.sql(s"SELECT ST_GeomFromWKB(X'$pointString')")
+    val expectedGeom = Geometry.fromWKB(pointBytes, 0)
+    checkAnswer(df, Seq(Row(expectedGeom)))
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 5de4170a1c112..eb36b68cd6171 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -136,4 +136,15 @@ class RowSuite extends SparkFunSuite with SharedSparkSession {
       parameters = Map("index" -> position.toString)
     )
   }
+
+  test("Geospatial row API - Geography and Geometry") {
+    // A test WKB value corresponding to: POINT (17 7).
+    val point = "010100000000000000000031400000000000001C40"
+      .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray
+
+    val row = Row(Geometry.fromWKB(point), Geography.fromWKB(point))
+
+    assert(row.getGeometry(0).getBytes().sameElements(point))
+    assert(row.getGeography(1).getBytes().sameElements(point))
+  }
 }

From 6cb114bc606164a8e53fa76751fd603b55808bb5 Mon Sep 17 00:00:00 2001
From: Uros Bojanic
Date: Sat, 1 Nov 2025 10:21:08 +0100
Subject: [PATCH 2/4] Address comments and fix scalastyle

---
 .../scala/org/apache/spark/sql/types/GeographyType.scala | 8 ++++----
 .../scala/org/apache/spark/sql/types/GeometryType.scala  | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
index 03618c6ddd61a..77ddfcc15d3fb 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
@@ -134,15 +134,15 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
     isMixedSrid || gt.srid == srid
   }
 
-  def assertSridAllowedForType(otherSrid: Int): Unit = {
+  private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
     // If SRID is not mixed, SRIDs must match.
     if (!isMixedSrid && otherSrid != srid) {
       throw new SparkRuntimeException(
         errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
         messageParameters = Map(
-          "type"-> "GEOGRAPHY",
+          "type" -> "GEOGRAPHY",
           "valueSrid" -> otherSrid.toString,
-          "typeSrid" -> srid.toString,
+          "typeSrid" -> srid.toString
         )
       )
     } else if (isMixedSrid) {
@@ -182,7 +182,7 @@ object GeographyType extends SpatialType {
     GeographyType(MIXED_CRS, GEOGRAPHY_DEFAULT_ALGORITHM)
 
   /** Returns whether the given SRID is supported. */
-  def isSridSupported(srid: Int): Boolean = {
+  private[types] def isSridSupported(srid: Int): Boolean = {
     GeographicSpatialReferenceSystemMapper.getStringId(srid) != null
   }
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
index f4c4b8503c14b..8116973be246f 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
@@ -131,15 +131,15 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
     isMixedSrid || gt.srid == srid
   }
 
-  def assertSridAllowedForType(otherSrid: Int): Unit = {
+  private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
     // If SRID is not mixed, SRIDs must match.
     if (!isMixedSrid && otherSrid != srid) {
       throw new SparkRuntimeException(
         errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
         messageParameters = Map(
-          "type"-> "GEOMETRY",
+          "type" -> "GEOMETRY",
           "valueSrid" -> otherSrid.toString,
-          "typeSrid" -> srid.toString,
+          "typeSrid" -> srid.toString
         )
       )
     } else if (isMixedSrid) {
@@ -174,7 +174,7 @@ object GeometryType extends SpatialType {
     GeometryType(MIXED_CRS)
 
   /** Returns whether the given SRID is supported. */
-  def isSridSupported(srid: Int): Boolean = {
+  private[types] def isSridSupported(srid: Int): Boolean = {
     CartesianSpatialReferenceSystemMapper.getStringId(srid) != null
   }
 

From 22b05e04824bef94f0cbcd34e3d8bf2e0a1ab0a0 Mon Sep 17 00:00:00 2001
From: Uros Bojanic
Date: Sat, 1 Nov 2025 22:57:47 +0100
Subject: [PATCH 3/4] Fix scalastyle issues

---
 sql/api/src/main/scala/org/apache/spark/sql/Row.scala    | 6 ++++--
 .../scala/org/apache/spark/sql/types/GeographyType.scala | 7 ++-----
 .../scala/org/apache/spark/sql/types/GeometryType.scala  | 7 ++-----
 3 files changed, 8 insertions(+), 12 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
index c99b9ef34f86a..1019d4c9a2276 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala
@@ -305,7 +305,8 @@ trait Row extends Serializable {
   /**
    * Returns the value at position i of geometry type as org.apache.spark.sql.types.Geometry.
    *
-   * @throws ClassCastException when data type does not match.
+   * @throws ClassCastException
+   *   when data type does not match.
    */
   def getGeometry(i: Int): org.apache.spark.sql.types.Geometry =
     getAs[org.apache.spark.sql.types.Geometry](i)
@@ -313,7 +314,8 @@ trait Row extends Serializable {
   /**
    * Returns the value at position i of geography type as org.apache.spark.sql.types.Geography.
    *
-   * @throws ClassCastException when data type does not match.
+   * @throws ClassCastException
+   *   when data type does not match.
    */
   def getGeography(i: Int): org.apache.spark.sql.types.Geography =
     getAs[org.apache.spark.sql.types.Geography](i)
 
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
index 77ddfcc15d3fb..d72e5987abebd 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeographyType.scala
@@ -142,9 +142,7 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
         messageParameters = Map(
           "type" -> "GEOGRAPHY",
           "valueSrid" -> otherSrid.toString,
-          "typeSrid" -> srid.toString
-        )
-      )
+          "typeSrid" -> srid.toString))
     } else if (isMixedSrid) {
       // For fixed SRID geom types, we have a check that value matches the type srid.
       // For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
@@ -152,8 +150,7 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
       if (!GeographyType.isSridSupported(otherSrid)) {
         throw new SparkIllegalArgumentException(
           errorClass = "ST_INVALID_SRID_VALUE",
-          messageParameters = Map("srid" -> otherSrid.toString)
-        )
+          messageParameters = Map("srid" -> otherSrid.toString))
       }
     }
   }
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
index 8116973be246f..f5bbbcba6706e 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/types/GeometryType.scala
@@ -139,9 +139,7 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
         messageParameters = Map(
           "type" -> "GEOMETRY",
           "valueSrid" -> otherSrid.toString,
-          "typeSrid" -> srid.toString
-        )
-      )
+          "typeSrid" -> srid.toString))
     } else if (isMixedSrid) {
       // For fixed SRID geom types, we have a check that value matches the type srid.
       // For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
@@ -149,8 +147,7 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
       if (!GeometryType.isSridSupported(otherSrid)) {
         throw new SparkIllegalArgumentException(
           errorClass = "ST_INVALID_SRID_VALUE",
-          messageParameters = Map("srid" -> otherSrid.toString)
-        )
+          messageParameters = Map("srid" -> otherSrid.toString))
       }
     }
   }

From 5e1529e8efef690097b9884ba47c8025c02dc9da Mon Sep 17 00:00:00 2001
From: Uros Bojanic
Date: Sun, 2 Nov 2025 16:51:22 +0100
Subject: [PATCH 4/4] Fix scalastyle issues

---
 .../java/org/apache/spark/sql/catalyst/util/STUtils.java | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
index 026cf7dfa304c..9edeee26eb98a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/STUtils.java
@@ -50,13 +50,15 @@ static GeometryVal toPhysVal(Geometry g) {
 
   /** Geospatial type encoder/decoder utilities. */
 
-  public static GeometryVal serializeGeomFromWKB(org.apache.spark.sql.types.Geometry geometry, GeometryType gt) {
+  public static GeometryVal serializeGeomFromWKB(org.apache.spark.sql.types.Geometry geometry,
+      GeometryType gt) {
     int geometrySrid = geometry.getSrid();
     gt.assertSridAllowedForType(geometrySrid);
     return toPhysVal(Geometry.fromWkb(geometry.getBytes(), geometrySrid));
   }
 
-  public static GeographyVal serializeGeogFromWKB(org.apache.spark.sql.types.Geography geography, GeographyType gt) {
+  public static GeographyVal serializeGeogFromWKB(org.apache.spark.sql.types.Geography geography,
+      GeographyType gt) {
     int geographySrid = geography.getSrid();
     gt.assertSridAllowedForType(geographySrid);
     return toPhysVal(Geography.fromWkb(geography.getBytes(), geographySrid));
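
A minimal usage sketch of the API this series introduces, assuming the patches above are applied and that `spark` is an active SparkSession; the variable names are illustrative, and the `wkb` hex is the same POINT (17 7) value used in the test suites:

import org.apache.spark.sql.{Encoders, Row}
import org.apache.spark.sql.types.{Geometry, GeometryType, StructField, StructType}

// Decode the POINT (17 7) WKB hex into raw bytes, as the tests above do.
val wkb = "010100000000000000000031400000000000001C40"
  .grouped(2).map(Integer.parseInt(_, 16).toByte).toArray

// Build a one-column DataFrame with a fixed-SRID geometry schema (SRID 4326).
// A value whose SRID differs from the schema SRID would raise
// GEO_ENCODER_SRID_MISMATCH_ERROR via the assertSridAllowedForType check.
val rows = java.util.Arrays.asList(Row(Geometry.fromWKB(wkb, 4326)))
val schema = StructType(Seq(StructField("geom", GeometryType(4326), nullable = false)))
val df = spark.createDataFrame(rows, schema)

// Read the value back through the new typed Row getter.
val geom: Geometry = df.collect().head.getGeometry(0)

// An explicit encoder for typed Dataset operations over geometry values.
val geomEncoder = Encoders.GEOMETRY(GeometryType(4326))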