diff --git a/pyiceberg/catalog/glue.py b/pyiceberg/catalog/glue.py
index 6cf9462b71..bccbfa4f0a 100644
--- a/pyiceberg/catalog/glue.py
+++ b/pyiceberg/catalog/glue.py
@@ -18,6 +18,7 @@
 from typing import (
     Any,
+    Dict,
     List,
     Optional,
     Set,
@@ -28,6 +29,7 @@
 import boto3
 from mypy_boto3_glue.client import GlueClient
 from mypy_boto3_glue.type_defs import (
+    ColumnTypeDef,
     DatabaseInputTypeDef,
     DatabaseTypeDef,
     StorageDescriptorTypeDef,
@@ -59,12 +61,32 @@
 )
 from pyiceberg.io import load_file_io
 from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
-from pyiceberg.schema import Schema
+from pyiceberg.schema import Schema, SchemaVisitor, visit
 from pyiceberg.serializers import FromInputFile
 from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
-from pyiceberg.table.metadata import new_table_metadata
+from pyiceberg.table.metadata import TableMetadata, new_table_metadata
 from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
 from pyiceberg.typedef import EMPTY_DICT
+from pyiceberg.types import (
+    BinaryType,
+    BooleanType,
+    DateType,
+    DecimalType,
+    DoubleType,
+    FixedType,
+    FloatType,
+    IntegerType,
+    ListType,
+    LongType,
+    MapType,
+    NestedField,
+    PrimitiveType,
+    StringType,
+    StructType,
+    TimestampType,
+    TimeType,
+    UUIDType,
+)
 
 # If Glue should skip archiving an old table version when creating a new version in a commit. By
 # default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
@@ -73,6 +95,10 @@
 GLUE_SKIP_ARCHIVE = "glue.skip-archive"
 GLUE_SKIP_ARCHIVE_DEFAULT = True
 
+ICEBERG_FIELD_ID = "iceberg.field.id"
+ICEBERG_FIELD_OPTIONAL = "iceberg.field.optional"
+ICEBERG_FIELD_CURRENT = "iceberg.field.current"
+
 
 def _construct_parameters(
     metadata_location: str, glue_table: Optional[TableTypeDef] = None, prev_metadata_location: Optional[str] = None
@@ -84,10 +110,86 @@ def _construct_parameters(
     return new_parameters
 
 
+GLUE_PRIMITIVE_TYPES = {
+    BooleanType: "boolean",
+    IntegerType: "int",
+    LongType: "bigint",
+    FloatType: "float",
+    DoubleType: "double",
+    DateType: "date",
+    TimeType: "string",
+    StringType: "string",
+    UUIDType: "string",
+    TimestampType: "timestamp",
+    FixedType: "binary",
+    BinaryType: "binary",
+}
+
+
+class _IcebergSchemaToGlueType(SchemaVisitor[str]):
+    def schema(self, schema: Schema, struct_result: str) -> str:
+        return struct_result
+
+    def struct(self, struct: StructType, field_results: List[str]) -> str:
+        return f"struct<{','.join(field_results)}>"
+
+    def field(self, field: NestedField, field_result: str) -> str:
+        return f"{field.name}:{field_result}"
+
+    def list(self, list_type: ListType, element_result: str) -> str:
+        return f"array<{element_result}>"
+
+    def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
+        return f"map<{key_result},{value_result}>"
+
+    def primitive(self, primitive: PrimitiveType) -> str:
+        if isinstance(primitive, DecimalType):
+            return f"decimal({primitive.precision},{primitive.scale})"
+        if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
+            raise ValueError(f"Unknown primitive type: {primitive}")
+        return GLUE_PRIMITIVE_TYPES[primitive_type]
+
+
+def _to_columns(metadata: TableMetadata) -> List[ColumnTypeDef]:
+    results: Dict[str, ColumnTypeDef] = {}
+
+    def _append_to_results(field: NestedField, is_current: bool) -> None:
+        if field.name in results:
+            return
+
+        results[field.name] = cast(
+            ColumnTypeDef,
+            {
+                "Name": field.name,
+                "Type": visit(field.field_type, _IcebergSchemaToGlueType()),
+                "Parameters": {
+                    ICEBERG_FIELD_ID: str(field.field_id),
+                    ICEBERG_FIELD_OPTIONAL: str(field.optional).lower(),
+                    ICEBERG_FIELD_CURRENT: str(is_current).lower(),
+                },
+            },
+        )
+        if field.doc:
+            results[field.name]["Comment"] = field.doc
+
+    if current_schema := metadata.schema_by_id(metadata.current_schema_id):
+        for field in current_schema.columns:
+            _append_to_results(field, True)
+
+    for schema in metadata.schemas:
+        if schema.schema_id == metadata.current_schema_id:
+            continue
+        for field in schema.columns:
+            _append_to_results(field, False)
+
+    return list(results.values())
+
+
 def _construct_table_input(
     table_name: str,
     metadata_location: str,
     properties: Properties,
+    metadata: TableMetadata,
     glue_table: Optional[TableTypeDef] = None,
     prev_metadata_location: Optional[str] = None,
 ) -> TableInputTypeDef:
@@ -95,6 +197,10 @@ def _construct_table_input(
         "Name": table_name,
         "TableType": EXTERNAL_TABLE,
         "Parameters": _construct_parameters(metadata_location, glue_table, prev_metadata_location),
+        "StorageDescriptor": {
+            "Columns": _to_columns(metadata),
+            "Location": metadata.location,
+        },
     }
 
     if "Description" in properties:
@@ -258,7 +364,7 @@ def create_table(
         io = load_file_io(properties=self.properties, location=metadata_location)
         self._write_metadata(metadata, io, metadata_location)
 
-        table_input = _construct_table_input(table_name, metadata_location, properties)
+        table_input = _construct_table_input(table_name, metadata_location, properties, metadata)
         database_name, table_name = self.identifier_to_database_and_table(identifier)
         self._create_glue_table(database_name=database_name, table_name=table_name, table_input=table_input)
 
@@ -322,6 +428,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableRespons
                 table_name=table_name,
                 metadata_location=new_metadata_location,
                 properties=current_table.properties,
+                metadata=updated_metadata,
                 glue_table=current_glue_table,
                 prev_metadata_location=current_table.metadata_location,
             )
diff --git a/tests/catalog/integration_test_glue.py b/tests/catalog/integration_test_glue.py
index 24401cae39..a56e4c6aaa 100644
--- a/tests/catalog/integration_test_glue.py
+++ b/tests/catalog/integration_test_glue.py
@@ -15,9 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
-from typing import Generator, List
+import time
+from typing import Any, Dict, Generator, List
+from uuid import uuid4
 
 import boto3
+import pyarrow as pa
 import pytest
 from botocore.exceptions import ClientError
 
@@ -30,6 +33,7 @@
     NoSuchTableError,
     TableAlreadyExistsError,
 )
+from pyiceberg.io.pyarrow import schema_to_pyarrow
 from pyiceberg.schema import Schema
 from pyiceberg.types import IntegerType
 from tests.conftest import clean_up, get_bucket_name, get_s3_path
@@ -52,8 +56,62 @@ def fixture_test_catalog() -> Generator[Catalog, None, None]:
     clean_up(test_catalog)
 
 
+class AthenaQueryHelper:
+    """Runs Athena queries against the integration-test tables and returns their raw results."""
+
+    _athena_client: boto3.client
+    _s3_resource: boto3.resource
+    _output_bucket: str
+    _output_path: str
+
+    def __init__(self) -> None:
+        self._s3_resource = boto3.resource("s3")
+        self._athena_client = boto3.client("athena")
+        self._output_bucket = get_bucket_name()
+        self._output_path = f"athena_results_{uuid4()}"
+
+    def get_query_results(self, query: str) -> List[Dict[str, Any]]:
+        query_execution_id = self._athena_client.start_query_execution(
+            QueryString=query, ResultConfiguration={"OutputLocation": f"s3://{self._output_bucket}/{self._output_path}"}
+        )["QueryExecutionId"]
+
+        # Poll until the query finishes, failing fast if it errors or is cancelled
+        while True:
+            result = self._athena_client.get_query_execution(QueryExecutionId=query_execution_id)["QueryExecution"]["Status"]
+            query_status = result["State"]
+            assert query_status not in [
+                "FAILED",
+                "CANCELLED",
+            ], f"""
+            Athena query failed or was cancelled:
+            Query: {query}
+            Status: {query_status}
+            Reason: {result["StateChangeReason"]}"""
+
+            if query_status not in ["QUEUED", "RUNNING"]:
+                break
+            time.sleep(0.5)
+
+        # No pagination for now; assume that we are not running large queries
+        return self._athena_client.get_query_results(QueryExecutionId=query_execution_id)["ResultSet"]["Rows"]
+
+    def clean_up(self) -> None:
+        bucket = self._s3_resource.Bucket(self._output_bucket)
+        for obj in bucket.objects.filter(Prefix=f"{self._output_path}/"):
+            self._s3_resource.Object(bucket.name, obj.key).delete()
+
+
+@pytest.fixture(name="athena", scope="module")
+def fixture_athena_helper() -> Generator[AthenaQueryHelper, None, None]:
+    query_helper = AthenaQueryHelper()
+    yield query_helper
+    query_helper.clean_up()
+
+
 def test_create_table(
-    test_catalog: Catalog, s3: boto3.client, table_schema_nested: Schema, table_name: str, database_name: str
+    test_catalog: Catalog,
+    s3: boto3.client,
+    table_schema_nested: Schema,
+    table_name: str,
+    database_name: str,
+    athena: AthenaQueryHelper,
 ) -> None:
     identifier = (database_name, table_name)
     test_catalog.create_namespace(database_name)
@@ -64,6 +122,48 @@
     s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
 
+    table.append(
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "foo_val",
+                    "bar": 1,
+                    "baz": False,
+                    "qux": ["x", "y"],
+                    "quux": {"key": {"subkey": 2}},
+                    "location": [{"latitude": 1.1}],
+                    "person": {"name": "some_name", "age": 23},
+                }
+            ],
+            schema=schema_to_pyarrow(table.schema()),
+        ),
+    )
+
+    assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+        {
+            "Data": [
+                {"VarCharValue": "foo"},
+                {"VarCharValue": "bar"},
+                {"VarCharValue": "baz"},
+                {"VarCharValue": "qux"},
+                {"VarCharValue": "quux"},
+                {"VarCharValue": "location"},
+                {"VarCharValue": "person"},
+            ]
+        },
+        {
+            "Data": [
+                {"VarCharValue": "foo_val"},
+                {"VarCharValue": "1"},
+                {"VarCharValue": "false"},
+                {"VarCharValue": "[x, y]"},
+                {"VarCharValue": "{key={subkey=2}}"},
+                {"VarCharValue": "[{latitude=1.1, longitude=null}]"},
+                {"VarCharValue": "{name=some_name, age=23}"},
+            ]
+        },
+    ]
+
 
 def test_create_table_with_invalid_location(table_schema_nested: Schema, table_name: str, database_name: str) -> None:
     identifier = (database_name, table_name)
@@ -269,7 +369,7 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)
 
 def test_commit_table_update_schema(
-    test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str
+    test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str, athena: AthenaQueryHelper
 ) -> None:
     identifier = (database_name, table_name)
     test_catalog.create_namespace(namespace=database_name)
@@ -279,6 +379,20 @@
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
     assert original_table_metadata.current_schema_id == 0
 
+    assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+        {
+            "Data": [
+                {"VarCharValue": "foo"},
+                {"VarCharValue": "bar"},
+                {"VarCharValue": "baz"},
+                {"VarCharValue": "qux"},
+                {"VarCharValue": "quux"},
+                {"VarCharValue": "location"},
+                {"VarCharValue": "person"},
+            ]
+        }
+    ]
+
     transaction = table.transaction()
     update = transaction.update_schema()
     update.add_column(path="b", field_type=IntegerType())
@@ -295,6 +409,48 @@
     assert new_schema == update._apply()
     assert new_schema.find_field("b").field_type == IntegerType()
 
+    table.append(
+        pa.Table.from_pylist(
+            [
+                {
+                    "foo": "foo_val",
+                    "bar": 1,
+                    "location": [{"latitude": 1.1}],
+                    "person": {"name": "some_name", "age": 23},
+                    "b": 2,
+                }
+            ],
+            schema=schema_to_pyarrow(new_schema),
+        ),
+    )
+
+    assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
+        {
+            "Data": [
+                {"VarCharValue": "foo"},
+                {"VarCharValue": "bar"},
+                {"VarCharValue": "baz"},
+                {"VarCharValue": "qux"},
+                {"VarCharValue": "quux"},
+                {"VarCharValue": "location"},
+                {"VarCharValue": "person"},
+                {"VarCharValue": "b"},
+            ]
+        },
+        {
+            "Data": [
+                {"VarCharValue": "foo_val"},
+                {"VarCharValue": "1"},
+                {},
+                {"VarCharValue": "[]"},
+                {"VarCharValue": "{}"},
+                {"VarCharValue": "[{latitude=1.1, longitude=null}]"},
+                {"VarCharValue": "{name=some_name, age=23}"},
+                {"VarCharValue": "2"},
+            ]
+        },
+    ]
+
 
 def test_commit_table_properties(test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str) -> None:
     identifier = (database_name, table_name)
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index bf6d11784f..b1f1371a04 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -38,7 +38,12 @@
 
 @mock_glue
 def test_create_table_with_database_location(
-    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
+    _glue: boto3.client,
+    _bucket_initialize: None,
+    moto_endpoint_url: str,
+    table_schema_nested: Schema,
+    database_name: str,
+    table_name: str,
 ) -> None:
     catalog_name = "glue"
     identifier = (database_name, table_name)
@@ -49,6 +54,22 @@
     assert TABLE_METADATA_LOCATION_REGEX.match(table.metadata_location)
     assert test_catalog._parse_metadata_version(table.metadata_location) == 0
 
+    # Ensure the schema is also pushed to Glue
+    table_info = _glue.get_table(
+        DatabaseName=database_name,
+        Name=table_name,
+    )
+    storage_descriptor = table_info["Table"]["StorageDescriptor"]
+    columns = storage_descriptor["Columns"]
+    assert len(columns) == len(table_schema_nested.fields)
+    assert columns[0] == {
+        "Name": "foo",
+        "Type": "string",
+        "Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
+    }
+
+    assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+
 
 @mock_glue
 def test_create_table_with_default_warehouse(
@@ -524,7 +545,12 @@ def test_passing_profile_name() -> None:
 
 @mock_glue
 def test_commit_table_update_schema(
-    _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
+    _glue: boto3.client,
+    _bucket_initialize: None,
+    moto_endpoint_url: str,
+    table_schema_nested: Schema,
+    database_name: str,
+    table_name: str,
 ) -> None:
     catalog_name = "glue"
     identifier = (database_name, table_name)
@@ -554,6 +580,21 @@
     assert new_schema == update._apply()
     assert new_schema.find_field("b").field_type == IntegerType()
 
+    # Ensure the schema is also pushed to Glue
+    table_info = _glue.get_table(
+        DatabaseName=database_name,
+        Name=table_name,
+    )
+    storage_descriptor = table_info["Table"]["StorageDescriptor"]
+    columns = storage_descriptor["Columns"]
+    assert len(columns) == len(table_schema_nested.fields) + 1
+    assert columns[-1] == {
+        "Name": "b",
+        "Type": "int",
+        "Parameters": {"iceberg.field.id": "18", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
+    }
+    assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+
 
 @mock_glue
 def test_commit_table_properties(
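
A minimal sketch of how the new pieces fit together, assuming the patch above is applied (the private helpers are then importable from pyiceberg.catalog.glue). The schema and the S3 location below are made-up illustration values: _IcebergSchemaToGlueType renders Iceberg types as Glue/Hive type strings, and _to_columns turns table metadata into the StorageDescriptor columns that the tests above assert on.

# Runnable sketch, assuming this diff is applied; schema and location are made up.
from pyiceberg.catalog.glue import _IcebergSchemaToGlueType, _to_columns
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema, visit
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER
from pyiceberg.types import IntegerType, ListType, NestedField, StringType

schema = Schema(
    NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
    NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
    NestedField(
        field_id=3,
        name="qux",
        field_type=ListType(element_id=4, element_type=StringType(), element_required=True),
        required=True,
    ),
)

# Visiting a whole schema yields the top-level struct; _to_columns visits each
# field's type individually, so the "qux" column becomes just "array<string>".
assert visit(schema, _IcebergSchemaToGlueType()) == "struct<foo:string,bar:int,qux:array<string>>"

metadata = new_table_metadata(
    schema=schema,
    partition_spec=UNPARTITIONED_PARTITION_SPEC,
    sort_order=UNSORTED_SORT_ORDER,
    location="s3://some-bucket/some/path",  # made-up location
    properties={},
)
# Columns from the current schema carry iceberg.field.current=true; fields that
# exist only in historical schemas are appended afterwards with current=false.
assert _to_columns(metadata)[0] == {
    "Name": "foo",
    "Type": "string",
    "Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
}

Two design points worth noting from the mapping table: TimeType and UUIDType deliberately fall back to "string" because Glue has no matching primitive, and columns are deduplicated by name with the current schema taking precedence, so a renamed field keeps only its current definition.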