From f62a07e21a23fc635df6bc6f2f778769c93cee7f Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Wed, 15 Oct 2025 23:47:56 +0200 Subject: [PATCH 1/5] Initial commit --- python/pyspark/errors/error-conditions.json | 18 ++ .../pyspark/sql/tests/test_geographytype.py | 97 +++++++++ python/pyspark/sql/tests/test_geometrytype.py | 97 +++++++++ python/pyspark/sql/tests/test_types.py | 98 +++++++++ python/pyspark/sql/types.py | 195 ++++++++++++++++++ 5 files changed, 505 insertions(+) create mode 100644 python/pyspark/sql/tests/test_geographytype.py create mode 100644 python/pyspark/sql/tests/test_geometrytype.py diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json index 30c6efd4e32ad..d169e6293a1ba 100644 --- a/python/pyspark/errors/error-conditions.json +++ b/python/pyspark/errors/error-conditions.json @@ -1134,6 +1134,24 @@ "Cannot serialize the function ``. If you accessed the Spark session, or a DataFrame defined outside of the function, or any object that contains a Spark session, please be aware that they are not allowed in Spark Connect. For `foreachBatch`, please access the Spark session using `df.sparkSession`, where `df` is the first parameter in your `foreachBatch` function. For `StreamingQueryListener`, please access the Spark session using `self.spark`. For details please check out the PySpark doc for `foreachBatch` and `StreamingQueryListener`." ] }, + "ST_INVALID_ALGORITHM_VALUE" : { + "message" : [ + "Invalid or unsupported edge interpolation algorithm value: ''." + ], + "sqlState" : "22023" + }, + "ST_INVALID_CRS_VALUE" : { + "message" : [ + "Invalid or unsupported CRS (coordinate reference system) value: ''." + ], + "sqlState" : "22023" + }, + "ST_INVALID_SRID_VALUE" : { + "message" : [ + "Invalid or unsupported SRID (spatial reference identifier) value: ." + ], + "sqlState" : "22023" + }, "TEST_CLASS_NOT_COMPILED": { "message": [ " doesn't exist. Spark sql test classes are not compiled." diff --git a/python/pyspark/sql/tests/test_geographytype.py b/python/pyspark/sql/tests/test_geographytype.py new file mode 100644 index 0000000000000..22a33100b2665 --- /dev/null +++ b/python/pyspark/sql/tests/test_geographytype.py @@ -0,0 +1,97 @@ +# -*- encoding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.types import GeographyType +from pyspark.sql.utils import IllegalArgumentException +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class GeographyTypeTestMixin: + + # Test cases for GeographyType construction based on SRID. + + def test_geographytype_specified_valid_srid(self): + """Test that GeographyType is constructed correctly when a valid SRID is specified.""" + + supported_srid = {4326: "OGC:CRS84"} + + for srid, crs in supported_srid.items(): + geography_type = GeographyType(srid) + self.assertEqual(geography_type.srid, srid) + self.assertEqual(geography_type.typeName(), "geography") + self.assertEqual(geography_type.simpleString(), f"geography({srid})") + self.assertEqual(geography_type.jsonValue(), f"geography({crs}, SPHERICAL)") + self.assertEqual(repr(geography_type), f"GeographyType({srid})") + + def test_geographytype_specified_invalid_srid(self): + """Test that the correct error is returned when an invalid SRID value is specified.""" + + for srid in [-4612, -4326, -2, -1, 1, 2]: + with self.assertRaises(IllegalArgumentException) as error_context: + GeographyType(srid) + srid_header = "[ST_INVALID_SRID_VALUE] Invalid or unsupported SRID" + self.assertEqual( + str(error_context.exception), + f"{srid_header} (spatial reference identifier) value: {srid}." + ) + + # Special string value "ANY" in place of SRID is used to denote a mixed GEOGRAPHY type. + + def test_geographytype_any_specifier(self): + """Test that GeographyType is constructed correctly with ANY specifier for mixed SRID.""" + + geography_type = GeographyType("ANY") + self.assertEqual(geography_type.srid, GeographyType.MIXED_SRID) + self.assertEqual(geography_type.typeName(), "geography") + self.assertEqual(geography_type.simpleString(), "geography(any)") + self.assertEqual(repr(geography_type), "GeographyType(ANY)") + + # The tests below verify GEOGRAPHY type object equality based on SRID values. + + def test_geographytype_same_srid_values(self): + """Test that two GeographyTypes with specified SRIDs have the same SRID values.""" + + for srid in [4326]: + geography_type_1 = GeographyType(srid) + geography_type_2 = GeographyType(srid) + self.assertEqual(geography_type_1.srid, geography_type_2.srid) + + def test_geographytype_different_srid_values(self): + """Test that two GeographyTypes with specified SRIDs have different SRID values.""" + + for srid in [4326]: + geography_type_1 = GeographyType(srid) + geography_type_2 = GeographyType("ANY") + self.assertNotEqual(geography_type_1.srid, geography_type_2.srid) + + +class GeographyTypeTest(GeographyTypeTestMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_types import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_geometrytype.py b/python/pyspark/sql/tests/test_geometrytype.py new file mode 100644 index 0000000000000..36c165b7e4709 --- /dev/null +++ b/python/pyspark/sql/tests/test_geometrytype.py @@ -0,0 +1,97 @@ +# -*- encoding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.sql.types import GeometryType +from pyspark.sql.utils import IllegalArgumentException +from pyspark.testing.sqlutils import ReusedSQLTestCase + + +class GeometryTypeTestMixin: + + # Test cases for GeometryType construction based on SRID. + + def test_geometrytype_specified_valid_srid(self): + """Test that GeometryType is constructed correctly when a valid SRID is specified.""" + + supported_srid = {4326: "OGC:CRS84"} + + for srid, crs in supported_srid.items(): + geometry_type = GeometryType(srid) + self.assertEqual(geometry_type.srid, srid) + self.assertEqual(geometry_type.typeName(), "geometry") + self.assertEqual(geometry_type.simpleString(), f"geometry({srid})") + self.assertEqual(geometry_type.jsonValue(), f"geometry({crs})") + self.assertEqual(repr(geometry_type), f"GeometryType({srid})") + + def test_geometrytype_specified_invalid_srid(self): + """Test that the correct error is returned when an invalid SRID value is specified.""" + + for srid in [-4612, -4326, -2, -1, 1, 2]: + with self.assertRaises(IllegalArgumentException) as error_context: + GeometryType(srid) + srid_header = "[ST_INVALID_SRID_VALUE] Invalid or unsupported SRID" + self.assertEqual( + str(error_context.exception), + f"{srid_header} (spatial reference identifier) value: {srid}." + ) + + # Special string value "ANY" in place of SRID is used to denote a mixed GEOMETRY type. + + def test_geometrytype_any_specifier(self): + """Test that GeometryType is constructed correctly with ANY specifier for mixed SRID.""" + + geometry_type = GeometryType("ANY") + self.assertEqual(geometry_type.srid, GeometryType.MIXED_SRID) + self.assertEqual(geometry_type.typeName(), "geometry") + self.assertEqual(geometry_type.simpleString(), "geometry(any)") + self.assertEqual(repr(geometry_type), "GeometryType(ANY)") + + # The tests below verify GEOMETRY type object equality based on SRID values. + + def test_geometrytype_same_srid_values(self): + """Test that two GeometryTypes with specified SRIDs have the same SRID values.""" + + for srid in [4326]: + geometry_type_1 = GeometryType(srid) + geometry_type_2 = GeometryType(srid) + self.assertEqual(geometry_type_1.srid, geometry_type_2.srid) + + def test_geometrytype_different_srid_values(self): + """Test that two GeometryTypes with specified SRIDs have different SRID values.""" + + for srid in [4326]: + geometry_type_1 = GeometryType(srid) + geometry_type_2 = GeometryType("ANY") + self.assertNotEqual(geometry_type_1.srid, geometry_type_2.srid) + + +class GeometryTypeTest(GeometryTypeTestMixin, ReusedSQLTestCase): + pass + + +if __name__ == "__main__": + import unittest + from pyspark.sql.tests.test_types import * # noqa: F401 + + try: + import xmlrunner + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 319ff92dd362d..20a6e34a26a43 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -60,6 +60,8 @@ DecimalType, BinaryType, BooleanType, + GeographyType, + GeometryType, NullType, VariantType, VariantVal, @@ -921,6 +923,98 @@ def test_schema_with_bad_collations_provider(self): self.assertRaises(PySparkValueError, lambda: _parse_datatype_json_string(schema_json)) + def test_geography_json_serde(self): + from pyspark.sql.types import _parse_datatype_json_value, _parse_datatype_json_string + + valid_test_cases = [ + ("geography", GeographyType(4326)), + ("geography(OGC:CRS84)", GeographyType(4326)), + ("geography(OGC:CRS84, SPHERICAL)", GeographyType(4326)), + ("geography(SPHERICAL)", GeographyType(4326)), + ("geography(SRID:ANY)", GeographyType("ANY")), + ("geography(srid:any)", GeographyType("ANY")), + ] + for json, expected in valid_test_cases: + python_datatype = _parse_datatype_json_value(json) + self.assertEqual(python_datatype, expected) + self.assertEqual(expected, _parse_datatype_json_string(expected.json())) + + invalid_test_cases = [ + "geography()", + "geography(())", + "geography(0)", + "geography(1)", + "geography(3857)", + "geography(4326)", + "geography(ANY)", + "geography(any)", + "geography(SRID)", + "geography(srid)", + "geography(CRS)", + "geography(crs)", + "geography(asdf)", + "geography(asdf:fdsa)", + "geography(123:123)", + "geography(srid:srid)", + "geography(SRID:0)", + "geography(SRID:1)", + "geography(SRID:123)", + "geography(EPSG:123)", + "geography(ESRI:123)", + "geography(OCG:123)", + "geography(OCG:CRS123)", + "geography(SRID:0,)", + "geography(SRID0)", + "geography(SRID:4326, ALG)" + ] + for json in invalid_test_cases: + with self.assertRaises(Exception): + _parse_datatype_json_value(json) + + def test_geometry_json_serde(self): + from pyspark.sql.types import _parse_datatype_json_value, _parse_datatype_json_string + + valid_test_cases = [ + ("geometry", GeometryType(4326)), + ("geometry(OGC:CRS84)", GeometryType(4326)), + ("geometry(SRID:ANY)", GeometryType("ANY")), + ("geometry(srid:any)", GeometryType("ANY")), + ] + for json, expected in valid_test_cases: + python_datatype = _parse_datatype_json_value(json) + self.assertEqual(python_datatype, expected) + self.assertEqual(expected, _parse_datatype_json_string(expected.json())) + + invalid_test_cases = [ + "geometry()", + "geometry(())", + "geometry(0)", + "geometry(1)", + "geometry(3857)", + "geometry(4326)", + "geometry(ANY)", + "geometry(any)", + "geometry(SRID)", + "geometry(srid)", + "geometry(CRS)", + "geometry(crs)", + "geometry(asdf)", + "geometry(asdf:fdsa)", + "geometry(123:123)", + "geometry(srid:srid)", + "geometry(SRID:1)", + "geometry(SRID:123)", + "geometry(EPSG:123)", + "geometry(ESRI:123)", + "geometry(OCG:123)", + "geometry(OCG:CRS123)", + "geometry(SRID:0,)", + "geometry(SRID0)" + ] + for json in invalid_test_cases: + with self.assertRaises(Exception): + _parse_datatype_json_value(json) + def test_udt(self): from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier @@ -1268,6 +1362,10 @@ def test_parse_datatype_json_string(self): TimestampType(), TimestampNTZType(), NullType(), + GeographyType(4326), + GeographyType("ANY"), + GeometryType(4326), + GeometryType("ANY"), VariantType(), YearMonthIntervalType(), YearMonthIntervalType(YearMonthIntervalType.YEAR), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index db162e8b1c521..6efb0367f132c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -51,6 +51,7 @@ from pyspark.sql.utils import ( get_active_spark_context, escape_meta_characters, + IllegalArgumentException, StringConcat, ) from pyspark.sql.variant_utils import VariantUtils @@ -518,6 +519,178 @@ class FloatType(FractionalType, metaclass=DataTypeSingleton): pass +class SpatialType(AtomicType): + """Super class of all spatial data types: GeographyType and GeometryType.""" + + # Mixed SRID value and the corresponding CRS for geospatial types (Geometry and Geography). + # These values represent a geospatial type that can hold different SRID values per row. + MIXED_SRID = -1 + MIXED_CRS = "SRID:ANY" + + +class GeographyType(SpatialType): + """ + The data type representing GEOGRAPHY values which are spatial objects, as defined in the Open + Geospatial Consortium (OGC) Simple Feature Access specification + (https://portal.ogc.org/files/?artifact_id=25355), with a geographic coordinate system. + + .. versionadded:: 4.1.0 + """ + + # The default coordinate reference system (CRS) value and the default edge interpolation + # algorithm used for geographies, as specified by the Parquet, Delta, and Iceberg + # specifications. If CRS or algorithm values are omitted, they should default to these. + DEFAULT_CRS = "OGC:CRS84" + DEFAULT_ALG = "SPHERICAL" + + def __init__(self, srid): + # Special string value "ANY" is used to represent the mixed SRID GEOGRAPHY type. + if srid == "ANY": + self.srid = GeographyType.MIXED_SRID + self._crs = GeographyType.MIXED_CRS + # Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value. + elif not isinstance(srid, int) or srid != 4326: + raise IllegalArgumentException( + errorClass="ST_INVALID_SRID_VALUE", + messageParameters={ + "srid": str(srid), + }, + ) + else: + self.srid = srid + self._crs = GeographyType.DEFAULT_CRS + self._alg = GeographyType.DEFAULT_ALG + + @classmethod + def _from_crs(cls, crs, alg) -> "GeographyType": + # Algorithm value must be validated, although only SPHERICAL is supported currently. + if alg != cls.DEFAULT_ALG: + raise IllegalArgumentException( + errorClass="INVALID_ALGORITHM_VALUE", + messageParameters={ + "alg": str(alg), + }, + ) + # Special CRS value "SRID:ANY" is used to represent the mixed SRID GEOGRAPHY type. + # Note: unlike the actual CRS values, the special "SRID:ANY" value is case-insensitive. + if crs.lower() == cls.MIXED_CRS.lower(): + return GeographyType("ANY") + # Otherwise, JSON parsing for the GEOGRAPHY type requires a valid CRS value. + srid = 4326 if crs == "OGC:CRS84" else None + if srid is None: + raise IllegalArgumentException( + errorClass="ST_INVALID_CRS_VALUE", + messageParameters={ + "crs": str(crs), + }, + ) + geography = GeographyType(srid) + geography._crs = crs + geography._alg = alg + return geography + + def simpleString(self) -> str: + if self.srid == GeographyType.MIXED_SRID: + # The mixed SRID type is displayed with a special string value "ANY". + return "geography(any)" + else: + # The fixed SRID type is displayed with the appropriate SRID value. + return f"geography({self.srid})" + + def __repr__(self) -> str: + if self.srid == GeographyType.MIXED_SRID: + # The mixed SRID type is displayed with a special string value "ANY". + return "GeographyType(ANY)" + else: + # The fixed SRID type is displayed with the appropriate SRID value. + return f"GeographyType({self.srid})" + + def jsonValue(self) -> Union[str, Dict[str, Any]]: + # The JSON representation always uses the CRS and algorithm value. + return f"geography({self._crs}, {self._alg})" + + +class GeometryType(SpatialType): + """ + The data type representing GEOMETRY values which are spatial objects, as defined in the Open + Geospatial Consortium (OGC) Simple Feature Access specification + (https://portal.ogc.org/files/?artifact_id=25355), with a Cartesian coordinate system. + + Parameters + ---------- + srid : int or str + The Spatial Reference System Identifier (SRID) value for the GEOMETRY. + + .. versionadded:: 4.1.0 + """ + + # The default coordinate reference system (CRS) value used for geometries, as specified by the + # Parquet, Delta, and Iceberg specifications. If CRS is omitted, it should default to this. + DEFAULT_CRS = "OGC:CRS84" + + """ The constructor for the GEOMETRY type can accept either a single valid geometric integer + SRID value, or a special string value "ANY" used to represent a mixed SRID GEOMETRY type.""" + def __init__(self, srid): + # Special string value "ANY" is used to represent the mixed SRID GEOMETRY type. + if srid == "ANY": + self.srid = GeometryType.MIXED_SRID + self._crs = GeometryType.MIXED_CRS + # Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value. + elif not isinstance(srid, int) or srid != 4326: + raise IllegalArgumentException( + errorClass="ST_INVALID_SRID_VALUE", + messageParameters={ + "srid": str(srid), + }, + ) + # If the SRID is valid, initialize the GEOMETRY type with the corresponding CRS value. + else: + self.srid = srid + self._crs = GeometryType.DEFAULT_CRS + + """ JSON parsing logic for the GEOMETRY type relies on the CRS value, instead of the SRID. + The method can accept either a single valid geometric string CRS value, or a special case + insensitive string value "SRID:ANY" used to represent a mixed SRID GEOMETRY type.""" + @classmethod + def _from_crs(cls, crs) -> "GeometryType": + # Special CRS value "SRID:ANY" is used to represent the mixed SRID GEOMETRY type. + # Note: unlike the actual CRS values, the special "SRID:ANY" value is case-insensitive. + if crs.lower() == cls.MIXED_CRS.lower(): + return GeometryType("ANY") + # Otherwise, JSON parsing for the GEOMETRY type requires a valid CRS value. + srid = 4326 if crs == "OGC:CRS84" else None + if srid is None: + raise IllegalArgumentException( + errorClass="ST_INVALID_CRS_VALUE", + messageParameters={ + "crs": str(crs), + }, + ) + geometry = GeometryType(srid) + geometry._crs = crs + return geometry + + def simpleString(self) -> str: + if self.srid == GeometryType.MIXED_SRID: + # The mixed SRID type is displayed with a special string value "ANY". + return "geometry(any)" + else: + # The fixed SRID type is displayed with the appropriate SRID value. + return f"geometry({self.srid})" + + def __repr__(self) -> str: + if self.srid == GeometryType.MIXED_SRID: + # The mixed SRID type is displayed with a special string value "ANY". + return "GeometryType(ANY)" + else: + # The fixed SRID type is displayed with the appropriate SRID value. + return f"GeometryType({self.srid})" + + def jsonValue(self) -> Union[str, Dict[str, Any]]: + # The JSON representation always uses the CRS value. + return f"geometry({self._crs})" + + class ByteType(IntegralType): """Byte data type, representing signed 8-bit integers.""" @@ -1921,6 +2094,12 @@ def parseJson(cls, json_str: str) -> "VariantVal": _INTERVAL_DAYTIME = re.compile(r"interval (day|hour|minute|second)( to (day|hour|minute|second))?") _INTERVAL_YEARMONTH = re.compile(r"interval (year|month)( to (year|month))?") _TIME = re.compile(r"time\(\s*(\d+)\s*\)") +_GEOMETRY = re.compile(r"^geometry$") +_GEOMETRY_CRS = re.compile(r"geometry\(\s*([\w]+:-?[\w]+)\s*\)") +_GEOGRAPHY = re.compile(r"^geography$") +_GEOGRAPHY_CRS = re.compile(r"geography\(\s*([\w]+:-?[\w]+)\s*\)") +_GEOGRAPHY_CRS_ALG = re.compile(r"geography\(\s*([\w]+:-?[\w]+)\s*,\s*(\w+)\s*\)") +_GEOGRAPHY_ALG = re.compile(r"geography\(\s*(\w+)\s*\)") _COLLATIONS_METADATA_KEY = "__COLLATIONS" @@ -2090,6 +2269,22 @@ def _parse_datatype_json_value( elif _LENGTH_VARCHAR.match(json_value): m = _LENGTH_VARCHAR.match(json_value) return VarcharType(int(m.group(1))) # type: ignore[union-attr] + elif _GEOMETRY.match(json_value): + return GeometryType._from_crs(GeometryType.DEFAULT_CRS) + elif _GEOMETRY_CRS.match(json_value): + crs = _GEOMETRY_CRS.match(json_value) + return GeometryType._from_crs(crs.group(1)) + elif _GEOGRAPHY.match(json_value): + return GeographyType._from_crs(GeographyType.DEFAULT_CRS, GeographyType.DEFAULT_ALG) + elif _GEOGRAPHY_CRS.match(json_value): + crs = _GEOGRAPHY_CRS.match(json_value) + return GeographyType._from_crs(crs.group(1), GeographyType.DEFAULT_ALG) + elif _GEOGRAPHY_CRS_ALG.match(json_value): + crs_alg = _GEOGRAPHY_CRS_ALG.match(json_value) + return GeographyType._from_crs(crs_alg.group(1), crs_alg.group(2)) + elif _GEOGRAPHY_ALG.match(json_value): + alg = _GEOGRAPHY_ALG.match(json_value) + return GeographyType._from_crs(GeographyType.DEFAULT_CRS, alg.group(1)) else: raise PySparkValueError( errorClass="CANNOT_PARSE_DATATYPE", From df438bc724b3eaf8656bca45eadbb6b455b43f88 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Thu, 16 Oct 2025 01:37:03 +0200 Subject: [PATCH 2/5] Fix lint issues --- .../pyspark/sql/tests/test_geographytype.py | 3 +-- python/pyspark/sql/tests/test_geometrytype.py | 3 +-- python/pyspark/sql/tests/test_types.py | 24 +++++++++---------- python/pyspark/sql/types.py | 2 ++ 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/python/pyspark/sql/tests/test_geographytype.py b/python/pyspark/sql/tests/test_geographytype.py index 22a33100b2665..ac37e3bb45e7c 100644 --- a/python/pyspark/sql/tests/test_geographytype.py +++ b/python/pyspark/sql/tests/test_geographytype.py @@ -22,7 +22,6 @@ class GeographyTypeTestMixin: - # Test cases for GeographyType construction based on SRID. def test_geographytype_specified_valid_srid(self): @@ -47,7 +46,7 @@ def test_geographytype_specified_invalid_srid(self): srid_header = "[ST_INVALID_SRID_VALUE] Invalid or unsupported SRID" self.assertEqual( str(error_context.exception), - f"{srid_header} (spatial reference identifier) value: {srid}." + f"{srid_header} (spatial reference identifier) value: {srid}.", ) # Special string value "ANY" in place of SRID is used to denote a mixed GEOGRAPHY type. diff --git a/python/pyspark/sql/tests/test_geometrytype.py b/python/pyspark/sql/tests/test_geometrytype.py index 36c165b7e4709..f2ad03b78c44d 100644 --- a/python/pyspark/sql/tests/test_geometrytype.py +++ b/python/pyspark/sql/tests/test_geometrytype.py @@ -22,7 +22,6 @@ class GeometryTypeTestMixin: - # Test cases for GeometryType construction based on SRID. def test_geometrytype_specified_valid_srid(self): @@ -47,7 +46,7 @@ def test_geometrytype_specified_invalid_srid(self): srid_header = "[ST_INVALID_SRID_VALUE] Invalid or unsupported SRID" self.assertEqual( str(error_context.exception), - f"{srid_header} (spatial reference identifier) value: {srid}." + f"{srid_header} (spatial reference identifier) value: {srid}.", ) # Special string value "ANY" in place of SRID is used to denote a mixed GEOMETRY type. diff --git a/python/pyspark/sql/tests/test_types.py b/python/pyspark/sql/tests/test_types.py index 20a6e34a26a43..d72aff20dbd83 100644 --- a/python/pyspark/sql/tests/test_types.py +++ b/python/pyspark/sql/tests/test_types.py @@ -927,12 +927,12 @@ def test_geography_json_serde(self): from pyspark.sql.types import _parse_datatype_json_value, _parse_datatype_json_string valid_test_cases = [ - ("geography", GeographyType(4326)), - ("geography(OGC:CRS84)", GeographyType(4326)), - ("geography(OGC:CRS84, SPHERICAL)", GeographyType(4326)), - ("geography(SPHERICAL)", GeographyType(4326)), - ("geography(SRID:ANY)", GeographyType("ANY")), - ("geography(srid:any)", GeographyType("ANY")), + ("geography", GeographyType(4326)), + ("geography(OGC:CRS84)", GeographyType(4326)), + ("geography(OGC:CRS84, SPHERICAL)", GeographyType(4326)), + ("geography(SPHERICAL)", GeographyType(4326)), + ("geography(SRID:ANY)", GeographyType("ANY")), + ("geography(srid:any)", GeographyType("ANY")), ] for json, expected in valid_test_cases: python_datatype = _parse_datatype_json_value(json) @@ -965,7 +965,7 @@ def test_geography_json_serde(self): "geography(OCG:CRS123)", "geography(SRID:0,)", "geography(SRID0)", - "geography(SRID:4326, ALG)" + "geography(SRID:4326, ALG)", ] for json in invalid_test_cases: with self.assertRaises(Exception): @@ -975,10 +975,10 @@ def test_geometry_json_serde(self): from pyspark.sql.types import _parse_datatype_json_value, _parse_datatype_json_string valid_test_cases = [ - ("geometry", GeometryType(4326)), - ("geometry(OGC:CRS84)", GeometryType(4326)), - ("geometry(SRID:ANY)", GeometryType("ANY")), - ("geometry(srid:any)", GeometryType("ANY")), + ("geometry", GeometryType(4326)), + ("geometry(OGC:CRS84)", GeometryType(4326)), + ("geometry(SRID:ANY)", GeometryType("ANY")), + ("geometry(srid:any)", GeometryType("ANY")), ] for json, expected in valid_test_cases: python_datatype = _parse_datatype_json_value(json) @@ -1009,7 +1009,7 @@ def test_geometry_json_serde(self): "geometry(OCG:123)", "geometry(OCG:CRS123)", "geometry(SRID:0,)", - "geometry(SRID0)" + "geometry(SRID0)", ] for json in invalid_test_cases: with self.assertRaises(Exception): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 6efb0367f132c..d2380b616f43b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -630,6 +630,7 @@ class GeometryType(SpatialType): """ The constructor for the GEOMETRY type can accept either a single valid geometric integer SRID value, or a special string value "ANY" used to represent a mixed SRID GEOMETRY type.""" + def __init__(self, srid): # Special string value "ANY" is used to represent the mixed SRID GEOMETRY type. if srid == "ANY": @@ -651,6 +652,7 @@ def __init__(self, srid): """ JSON parsing logic for the GEOMETRY type relies on the CRS value, instead of the SRID. The method can accept either a single valid geometric string CRS value, or a special case insensitive string value "SRID:ANY" used to represent a mixed SRID GEOMETRY type.""" + @classmethod def _from_crs(cls, crs) -> "GeometryType": # Special CRS value "SRID:ANY" is used to represent the mixed SRID GEOMETRY type. From 9e9c2ceab8a41426aaa9820d3498076fe1c0cff3 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Mon, 20 Oct 2025 14:20:32 +0200 Subject: [PATCH 3/5] Address comments --- dev/sparktestsupport/modules.py | 2 ++ python/docs/source/reference/pyspark.sql/data_types.rst | 2 ++ python/pyspark/sql/tests/test_geographytype.py | 2 +- python/pyspark/sql/tests/test_geometrytype.py | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index 945a2ac9189b0..8a0c9391e88aa 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -574,6 +574,8 @@ def __hash__(self): "pyspark.sql.tests.test_session", "pyspark.sql.tests.test_subquery", "pyspark.sql.tests.test_types", + "pyspark.sql.tests.test_geographytype", + "pyspark.sql.tests.test_geometrytype", "pyspark.sql.tests.test_udf", "pyspark.sql.tests.test_udf_combinations", "pyspark.sql.tests.test_udf_profiler", diff --git a/python/docs/source/reference/pyspark.sql/data_types.rst b/python/docs/source/reference/pyspark.sql/data_types.rst index 99f8c5bb87ef9..ffb24c68445dd 100644 --- a/python/docs/source/reference/pyspark.sql/data_types.rst +++ b/python/docs/source/reference/pyspark.sql/data_types.rst @@ -35,6 +35,8 @@ Data Types DecimalType DoubleType FloatType + GeographyType + GeometryType IntegerType LongType MapType diff --git a/python/pyspark/sql/tests/test_geographytype.py b/python/pyspark/sql/tests/test_geographytype.py index ac37e3bb45e7c..701ee78cd7ad9 100644 --- a/python/pyspark/sql/tests/test_geographytype.py +++ b/python/pyspark/sql/tests/test_geographytype.py @@ -85,7 +85,7 @@ class GeographyTypeTest(GeographyTypeTestMixin, ReusedSQLTestCase): if __name__ == "__main__": import unittest - from pyspark.sql.tests.test_types import * # noqa: F401 + from pyspark.sql.tests.test_geographytype import * # noqa: F401 try: import xmlrunner diff --git a/python/pyspark/sql/tests/test_geometrytype.py b/python/pyspark/sql/tests/test_geometrytype.py index f2ad03b78c44d..8647404d4f886 100644 --- a/python/pyspark/sql/tests/test_geometrytype.py +++ b/python/pyspark/sql/tests/test_geometrytype.py @@ -85,7 +85,7 @@ class GeometryTypeTest(GeometryTypeTestMixin, ReusedSQLTestCase): if __name__ == "__main__": import unittest - from pyspark.sql.tests.test_types import * # noqa: F401 + from pyspark.sql.tests.test_geometrytype import * # noqa: F401 try: import xmlrunner From 0a1828e613cddbe93468d563200f670e500318f1 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Mon, 20 Oct 2025 14:29:08 +0200 Subject: [PATCH 4/5] Fix Python linter issues --- python/pyspark/sql/types.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index d2380b616f43b..df7858ba7721b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -543,7 +543,7 @@ class GeographyType(SpatialType): DEFAULT_CRS = "OGC:CRS84" DEFAULT_ALG = "SPHERICAL" - def __init__(self, srid): + def __init__(self, srid: int | str): # Special string value "ANY" is used to represent the mixed SRID GEOGRAPHY type. if srid == "ANY": self.srid = GeographyType.MIXED_SRID @@ -562,7 +562,7 @@ def __init__(self, srid): self._alg = GeographyType.DEFAULT_ALG @classmethod - def _from_crs(cls, crs, alg) -> "GeographyType": + def _from_crs(cls, crs: str, alg: str) -> "GeographyType": # Algorithm value must be validated, although only SPHERICAL is supported currently. if alg != cls.DEFAULT_ALG: raise IllegalArgumentException( @@ -631,7 +631,7 @@ class GeometryType(SpatialType): """ The constructor for the GEOMETRY type can accept either a single valid geometric integer SRID value, or a special string value "ANY" used to represent a mixed SRID GEOMETRY type.""" - def __init__(self, srid): + def __init__(self, srid: int | str): # Special string value "ANY" is used to represent the mixed SRID GEOMETRY type. if srid == "ANY": self.srid = GeometryType.MIXED_SRID @@ -654,7 +654,7 @@ def __init__(self, srid): insensitive string value "SRID:ANY" used to represent a mixed SRID GEOMETRY type.""" @classmethod - def _from_crs(cls, crs) -> "GeometryType": + def _from_crs(cls, crs: str) -> "GeometryType": # Special CRS value "SRID:ANY" is used to represent the mixed SRID GEOMETRY type. # Note: unlike the actual CRS values, the special "SRID:ANY" value is case-insensitive. if crs.lower() == cls.MIXED_CRS.lower(): @@ -2275,18 +2275,22 @@ def _parse_datatype_json_value( return GeometryType._from_crs(GeometryType.DEFAULT_CRS) elif _GEOMETRY_CRS.match(json_value): crs = _GEOMETRY_CRS.match(json_value) - return GeometryType._from_crs(crs.group(1)) + if crs is not None: + return GeometryType._from_crs(crs.group(1)) elif _GEOGRAPHY.match(json_value): return GeographyType._from_crs(GeographyType.DEFAULT_CRS, GeographyType.DEFAULT_ALG) elif _GEOGRAPHY_CRS.match(json_value): crs = _GEOGRAPHY_CRS.match(json_value) - return GeographyType._from_crs(crs.group(1), GeographyType.DEFAULT_ALG) + if crs is not None: + return GeographyType._from_crs(crs.group(1), GeographyType.DEFAULT_ALG) elif _GEOGRAPHY_CRS_ALG.match(json_value): crs_alg = _GEOGRAPHY_CRS_ALG.match(json_value) - return GeographyType._from_crs(crs_alg.group(1), crs_alg.group(2)) + if crs_alg is not None: + return GeographyType._from_crs(crs_alg.group(1), crs_alg.group(2)) elif _GEOGRAPHY_ALG.match(json_value): alg = _GEOGRAPHY_ALG.match(json_value) - return GeographyType._from_crs(GeographyType.DEFAULT_CRS, alg.group(1)) + if alg is not None: + return GeographyType._from_crs(GeographyType.DEFAULT_CRS, alg.group(1)) else: raise PySparkValueError( errorClass="CANNOT_PARSE_DATATYPE", From 04a1393eac3a3bb716ce9320283cd069ae30d493 Mon Sep 17 00:00:00 2001 From: Uros Bojanic Date: Tue, 21 Oct 2025 18:30:14 +0200 Subject: [PATCH 5/5] Address comments --- python/pyspark/sql/types.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index df7858ba7721b..1295936be33b8 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -542,6 +542,7 @@ class GeographyType(SpatialType): # specifications. If CRS or algorithm values are omitted, they should default to these. DEFAULT_CRS = "OGC:CRS84" DEFAULT_ALG = "SPHERICAL" + DEFAULT_SRID = 4326 def __init__(self, srid: int | str): # Special string value "ANY" is used to represent the mixed SRID GEOGRAPHY type. @@ -549,7 +550,7 @@ def __init__(self, srid: int | str): self.srid = GeographyType.MIXED_SRID self._crs = GeographyType.MIXED_CRS # Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value. - elif not isinstance(srid, int) or srid != 4326: + elif not isinstance(srid, int) or srid != GeographyType.DEFAULT_SRID: raise IllegalArgumentException( errorClass="ST_INVALID_SRID_VALUE", messageParameters={ @@ -576,7 +577,7 @@ def _from_crs(cls, crs: str, alg: str) -> "GeographyType": if crs.lower() == cls.MIXED_CRS.lower(): return GeographyType("ANY") # Otherwise, JSON parsing for the GEOGRAPHY type requires a valid CRS value. - srid = 4326 if crs == "OGC:CRS84" else None + srid = GeographyType.DEFAULT_SRID if crs == "OGC:CRS84" else None if srid is None: raise IllegalArgumentException( errorClass="ST_INVALID_CRS_VALUE", @@ -627,6 +628,7 @@ class GeometryType(SpatialType): # The default coordinate reference system (CRS) value used for geometries, as specified by the # Parquet, Delta, and Iceberg specifications. If CRS is omitted, it should default to this. DEFAULT_CRS = "OGC:CRS84" + DEFAULT_SRID = 4326 """ The constructor for the GEOMETRY type can accept either a single valid geometric integer SRID value, or a special string value "ANY" used to represent a mixed SRID GEOMETRY type.""" @@ -637,7 +639,7 @@ def __init__(self, srid: int | str): self.srid = GeometryType.MIXED_SRID self._crs = GeometryType.MIXED_CRS # Otherwise, the parameterized GEOMETRY type syntax requires a valid SRID value. - elif not isinstance(srid, int) or srid != 4326: + elif not isinstance(srid, int) or srid != GeometryType.DEFAULT_SRID: raise IllegalArgumentException( errorClass="ST_INVALID_SRID_VALUE", messageParameters={ @@ -660,7 +662,7 @@ def _from_crs(cls, crs: str) -> "GeometryType": if crs.lower() == cls.MIXED_CRS.lower(): return GeometryType("ANY") # Otherwise, JSON parsing for the GEOMETRY type requires a valid CRS value. - srid = 4326 if crs == "OGC:CRS84" else None + srid = GeometryType.DEFAULT_SRID if crs == "OGC:CRS84" else None if srid is None: raise IllegalArgumentException( errorClass="ST_INVALID_CRS_VALUE", @@ -2228,7 +2230,7 @@ def _parse_datatype_json_string(json_string: str) -> DataType: return _parse_datatype_json_value(json.loads(json_string)) -def _parse_datatype_json_value( +def _parse_datatype_json_value( # type: ignore[return] json_value: Union[dict, str], fieldPath: str = "", collationsMap: Optional[Dict[str, str]] = None,