Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1881,6 +1881,12 @@
],
"sqlState" : "42623"
},
"GEO_ENCODER_SRID_MISMATCH_ERROR" : {
"message" : [
"Failed to encode <type> value because provided SRID <valueSrid> of a value to encode does not match type SRID: <typeSrid>."
],
"sqlState" : "42K09"
},
"GET_TABLES_BY_TYPE_UNSUPPORTED_BY_HIVE_VERSION" : {
"message" : [
"Hive 2.2 and lower versions don't support getTablesByType. Please use Hive 2.3 or higher version."
Expand Down
14 changes: 14 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,20 @@ object Encoders {
*/
def BINARY: Encoder[Array[Byte]] = BinaryEncoder

/**
* An encoder for Geometry data type.
*
* @since 4.1.0
*/
def GEOMETRY(dt: GeometryType): Encoder[Geometry] = GeometryEncoder(dt)

/**
* An encoder for Geography data type.
*
* @since 4.1.0
*/
def GEOGRAPHY(dt: GeographyType): Encoder[Geography] = GeographyEncoder(dt)

/**
* Creates an encoder that serializes instances of the `java.time.Duration` class to the
* internal representation of nullable Catalyst's DayTimeIntervalType.
Expand Down
18 changes: 18 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,24 @@ trait Row extends Serializable {
*/
def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i)

/**
* Returns the value at position i of date type as org.apache.spark.sql.types.Geometry.
*
* @throws ClassCastException
* when data type does not match.
*/
def getGeometry(i: Int): org.apache.spark.sql.types.Geometry =
getAs[org.apache.spark.sql.types.Geometry](i)

/**
* Returns the value at position i of date type as org.apache.spark.sql.types.Geography.
*
* @throws ClassCastException
* when data type does not match.
*/
def getGeography(i: Int): org.apache.spark.sql.types.Geography =
getAs[org.apache.spark.sql.types.Geography](i)

/**
* Returns the value at position i of date type as java.sql.Date.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable {
implicit def newScalaDecimalEncoder: Encoder[scala.math.BigDecimal] =
DEFAULT_SCALA_DECIMAL_ENCODER

/** @since 4.1.0 */
implicit def newGeometryEncoder: Encoder[org.apache.spark.sql.types.Geometry] =
DEFAULT_GEOMETRY_ENCODER

/** @since 4.1.0 */
implicit def newGeographyEncoder: Encoder[org.apache.spark.sql.types.Geography] =
DEFAULT_GEOGRAPHY_ENCODER

/** @since 2.2.0 */
implicit def newDateEncoder: Encoder[java.sql.Date] = Encoders.DATE

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import scala.reflect.ClassTag
import org.apache.commons.lang3.reflect.{TypeUtils => JavaTypeUtils}

import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, DayTimeIntervalEncoder, DEFAULT_GEOGRAPHY_ENCODER, DEFAULT_GEOMETRY_ENCODER, DEFAULT_JAVA_DECIMAL_ENCODER, EncoderField, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaEnumEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder, PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder, PrimitiveLongEncoder, PrimitiveShortEncoder, STRICT_DATE_ENCODER, STRICT_INSTANT_ENCODER, STRICT_LOCAL_DATE_ENCODER, STRICT_TIMESTAMP_ENCODER, StringEncoder, UDTEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.util.ArrayImplicits._
Expand Down Expand Up @@ -86,6 +86,10 @@ object JavaTypeInference {

case c: Class[_] if c == classOf[java.lang.String] => StringEncoder
case c: Class[_] if c == classOf[Array[Byte]] => BinaryEncoder
case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geometry] =>
DEFAULT_GEOMETRY_ENCODER
case c: Class[_] if c == classOf[org.apache.spark.sql.types.Geography] =>
DEFAULT_GEOGRAPHY_ENCODER
case c: Class[_] if c == classOf[java.math.BigDecimal] => DEFAULT_JAVA_DECIMAL_ENCODER
case c: Class[_] if c == classOf[java.math.BigInteger] => JavaBigIntEncoder
case c: Class[_] if c == classOf[java.time.LocalDate] => STRICT_LOCAL_DATE_ENCODER
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,10 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.LocalDateTime]) => LocalDateTimeEncoder
case t if isSubtype(t, localTypeOf[java.time.LocalTime]) => LocalTimeEncoder
case t if isSubtype(t, localTypeOf[VariantVal]) => VariantEncoder
case t if isSubtype(t, localTypeOf[Geography]) =>
DEFAULT_GEOGRAPHY_ENCODER
case t if isSubtype(t, localTypeOf[Geometry]) =>
DEFAULT_GEOMETRY_ENCODER
case t if isSubtype(t, localTypeOf[Row]) => UnboundRowEncoder

// UDT encoders
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ object AgnosticEncoders {
case object DayTimeIntervalEncoder extends LeafEncoder[Duration](DayTimeIntervalType())
case object YearMonthIntervalEncoder extends LeafEncoder[Period](YearMonthIntervalType())
case object VariantEncoder extends LeafEncoder[VariantVal](VariantType)
case class GeographyEncoder(dt: GeographyType) extends LeafEncoder[Geography](dt)
case class GeometryEncoder(dt: GeometryType) extends LeafEncoder[Geometry](dt)
case class DateEncoder(override val lenientSerialization: Boolean)
extends LeafEncoder[jsql.Date](DateType)
case class LocalDateEncoder(override val lenientSerialization: Boolean)
Expand Down Expand Up @@ -277,6 +279,10 @@ object AgnosticEncoders {
ScalaDecimalEncoder(DecimalType.SYSTEM_DEFAULT)
val DEFAULT_JAVA_DECIMAL_ENCODER: JavaDecimalEncoder =
JavaDecimalEncoder(DecimalType.SYSTEM_DEFAULT, lenientSerialization = false)
val DEFAULT_GEOMETRY_ENCODER: GeometryEncoder =
GeometryEncoder(GeometryType(Geometry.DEFAULT_SRID))
val DEFAULT_GEOGRAPHY_ENCODER: GeographyEncoder =
GeographyEncoder(GeographyType(Geography.DEFAULT_SRID))

/**
* Encoder that transforms external data into a representation that can be further processed by
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.classTag

import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder}
import org.apache.spark.sql.errors.DataTypeErrorsBase
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -120,6 +120,8 @@ object RowEncoder extends DataTypeErrorsBase {
field.nullable,
field.metadata)
}.toImmutableArraySeq)
case g: GeographyType => GeographyEncoder(g)
case g: GeometryType => GeometryEncoder(g)

case _ =>
throw new AnalysisException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.types

import org.json4s.JsonAST.{JString, JValue}

import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.internal.types.GeographicSpatialReferenceSystemMapper

Expand Down Expand Up @@ -133,6 +133,27 @@ class GeographyType private (val crs: String, val algorithm: EdgeInterpolationAl
// If the SRID is not mixed, we can only accept the same SRID.
isMixedSrid || gt.srid == srid
}

private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
// If SRID is not mixed, SRIDs must match.
if (!isMixedSrid && otherSrid != srid) {
throw new SparkRuntimeException(
errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
messageParameters = Map(
"type" -> "GEOGRAPHY",
"valueSrid" -> otherSrid.toString,
"typeSrid" -> srid.toString))
} else if (isMixedSrid) {
// For fixed SRID geom types, we have a check that value matches the type srid.
// For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
// However it should accept only valid SRIDs.
if (!GeographyType.isSridSupported(otherSrid)) {
throw new SparkIllegalArgumentException(
errorClass = "ST_INVALID_SRID_VALUE",
messageParameters = Map("srid" -> otherSrid.toString))
}
}
}
}

@Experimental
Expand All @@ -157,6 +178,11 @@ object GeographyType extends SpatialType {
private final val GEOGRAPHY_MIXED_TYPE: GeographyType =
GeographyType(MIXED_CRS, GEOGRAPHY_DEFAULT_ALGORITHM)

/** Returns whether the given SRID is supported. */
private[types] def isSridSupported(srid: Int): Boolean = {
GeographicSpatialReferenceSystemMapper.getStringId(srid) != null
}

/**
* Constructors for GeographyType.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.types

import org.json4s.JsonAST.{JString, JValue}

import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.{SparkIllegalArgumentException, SparkRuntimeException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.internal.types.CartesianSpatialReferenceSystemMapper

Expand Down Expand Up @@ -130,6 +130,27 @@ class GeometryType private (val crs: String) extends AtomicType with Serializabl
// If the SRID is not mixed, we can only accept the same SRID.
isMixedSrid || gt.srid == srid
}

private[sql] def assertSridAllowedForType(otherSrid: Int): Unit = {
// If SRID is not mixed, SRIDs must match.
if (!isMixedSrid && otherSrid != srid) {
throw new SparkRuntimeException(
errorClass = "GEO_ENCODER_SRID_MISMATCH_ERROR",
messageParameters = Map(
"type" -> "GEOMETRY",
"valueSrid" -> otherSrid.toString,
"typeSrid" -> srid.toString))
} else if (isMixedSrid) {
// For fixed SRID geom types, we have a check that value matches the type srid.
// For mixed SRID we need to do that check explicitly, as MIXED SRID can accept any SRID.
// However it should accept only valid SRIDs.
if (!GeometryType.isSridSupported(otherSrid)) {
throw new SparkIllegalArgumentException(
errorClass = "ST_INVALID_SRID_VALUE",
messageParameters = Map("srid" -> otherSrid.toString))
}
}
}
}

@Experimental
Expand All @@ -149,6 +170,11 @@ object GeometryType extends SpatialType {
private final val GEOMETRY_MIXED_TYPE: GeometryType =
GeometryType(MIXED_CRS)

/** Returns whether the given SRID is supported. */
private[types] def isSridSupported(srid: Int): Boolean = {
CartesianSpatialReferenceSystemMapper.getStringId(srid) != null
}

/**
* Constructors for GeometryType.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.util;

import org.apache.spark.sql.types.GeographyType;
import org.apache.spark.sql.types.GeometryType;
import org.apache.spark.unsafe.types.GeographyVal;
import org.apache.spark.unsafe.types.GeometryVal;

Expand Down Expand Up @@ -46,6 +48,38 @@ static GeometryVal toPhysVal(Geometry g) {
return g.getValue();
}

/** Geospatial type encoder/decoder utilities. */

public static GeometryVal serializeGeomFromWKB(org.apache.spark.sql.types.Geometry geometry,
GeometryType gt) {
int geometrySrid = geometry.getSrid();
gt.assertSridAllowedForType(geometrySrid);
return toPhysVal(Geometry.fromWkb(geometry.getBytes(), geometrySrid));
}

public static GeographyVal serializeGeogFromWKB(org.apache.spark.sql.types.Geography geography,
GeographyType gt) {
int geographySrid = geography.getSrid();
gt.assertSridAllowedForType(geographySrid);
return toPhysVal(Geography.fromWkb(geography.getBytes(), geographySrid));
}

public static org.apache.spark.sql.types.Geometry deserializeGeom(
GeometryVal geometry, GeometryType gt) {
int geometrySrid = stSrid(geometry);
gt.assertSridAllowedForType(geometrySrid);
byte[] wkb = stAsBinary(geometry);
return org.apache.spark.sql.types.Geometry.fromWKB(wkb, geometrySrid);
}

public static org.apache.spark.sql.types.Geography deserializeGeog(
GeographyVal geography, GeographyType gt) {
int geographySrid = stSrid(geography);
gt.assertSridAllowedForType(geographySrid);
byte[] wkb = stAsBinary(geography);
return org.apache.spark.sql.types.Geography.fromWKB(wkb, geographySrid);
}

/** Methods for implementing ST expressions. */

// ST_AsBinary
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
import org.apache.spark.sql.types.YearMonthIntervalType._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{GeographyVal, GeometryVal, UTF8String}
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.collection.Utils

Expand Down Expand Up @@ -69,6 +69,10 @@ object CatalystTypeConverters {
case CharType(length) => new CharConverter(length)
case VarcharType(length) => new VarcharConverter(length)
case _: StringType => StringConverter
case g: GeographyType =>
new GeographyConverter(g)
case g: GeometryType =>
new GeometryConverter(g)
case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
case DateType => DateConverter
case _: TimeType => TimeConverter
Expand Down Expand Up @@ -345,6 +349,42 @@ object CatalystTypeConverters {
row.getUTF8String(column).toString
}

private class GeometryConverter(dataType: GeometryType)
extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geometry, GeometryVal] {
override def toCatalystImpl(scalaValue: Any): GeometryVal = scalaValue match {
case g: org.apache.spark.sql.types.Geometry => STUtils.serializeGeomFromWKB(g, dataType)
case other => throw new SparkIllegalArgumentException(
errorClass = "_LEGACY_ERROR_TEMP_3219",
messageParameters = scala.collection.immutable.Map(
"other" -> other.toString,
"otherClass" -> other.getClass.getCanonicalName,
"dataType" -> StringType.sql))
}
override def toScala(catalystValue: GeometryVal): org.apache.spark.sql.types.Geometry =
if (catalystValue == null) null
else STUtils.deserializeGeom(catalystValue, dataType)
override def toScalaImpl(row: InternalRow, column: Int): org.apache.spark.sql.types.Geometry =
STUtils.deserializeGeom(row.getGeometry(0), dataType)
}

private class GeographyConverter(dataType: GeographyType)
extends CatalystTypeConverter[Any, org.apache.spark.sql.types.Geography, GeographyVal] {
override def toCatalystImpl(scalaValue: Any): GeographyVal = scalaValue match {
case g: org.apache.spark.sql.types.Geography => STUtils.serializeGeogFromWKB(g, dataType)
case other => throw new SparkIllegalArgumentException(
errorClass = "_LEGACY_ERROR_TEMP_3219",
messageParameters = scala.collection.immutable.Map(
"other" -> other.toString,
"otherClass" -> other.getClass.getCanonicalName,
"dataType" -> StringType.sql))
}
override def toScala(catalystValue: GeographyVal): org.apache.spark.sql.types.Geography =
if (catalystValue == null) null
else STUtils.deserializeGeog(catalystValue, dataType)
override def toScalaImpl(row: InternalRow, column: Int): org.apache.spark.sql.types.Geography =
STUtils.deserializeGeog(row.getGeography(0), dataType)
}

private object DateConverter extends CatalystTypeConverter[Any, Date, Any] {
override def toCatalystImpl(scalaValue: Any): Int = scalaValue match {
case d: Date => DateTimeUtils.fromJavaDate(d)
Expand Down
Loading