113 changes: 110 additions & 3 deletions pyiceberg/catalog/glue.py
@@ -18,6 +18,7 @@

from typing import (
Any,
Dict,
List,
Optional,
Set,
@@ -28,6 +29,7 @@
import boto3
from mypy_boto3_glue.client import GlueClient
from mypy_boto3_glue.type_defs import (
ColumnTypeDef,
DatabaseInputTypeDef,
DatabaseTypeDef,
StorageDescriptorTypeDef,
@@ -59,12 +61,32 @@
)
from pyiceberg.io import load_file_io
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.schema import Schema, SchemaVisitor, visit
from pyiceberg.serializers import FromInputFile
from pyiceberg.table import CommitTableRequest, CommitTableResponse, Table, update_table_metadata
from pyiceberg.table.metadata import new_table_metadata
from pyiceberg.table.metadata import TableMetadata, new_table_metadata
from pyiceberg.table.sorting import UNSORTED_SORT_ORDER, SortOrder
from pyiceberg.typedef import EMPTY_DICT
from pyiceberg.types import (
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
IntegerType,
ListType,
LongType,
MapType,
NestedField,
PrimitiveType,
StringType,
StructType,
TimestampType,
TimeType,
UUIDType,
)

# If Glue should skip archiving an old table version when creating a new version in a commit. By
# default, Glue archives all old table versions after an UpdateTable call, but Glue has a default
@@ -73,6 +95,10 @@
GLUE_SKIP_ARCHIVE = "glue.skip-archive"
GLUE_SKIP_ARCHIVE_DEFAULT = True
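The skip-archive behavior is driven by a catalog property. As a minimal sketch (not part of this change; the catalog name and property values below are illustrative assumptions), it could be set when loading the Glue catalog:

```python
# Illustrative sketch, not part of this change: controlling Glue version
# archiving through the "glue.skip-archive" catalog property. The catalog
# name and other property values are placeholders.
from pyiceberg.catalog import load_catalog

catalog = load_catalog(
    "my_glue_catalog",
    **{
        "type": "glue",
        # Set to "false" to keep Glue's default archiving of old table versions.
        "glue.skip-archive": "false",
    },
)
```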

ICEBERG_FIELD_ID = "iceberg.field.id"
ICEBERG_FIELD_OPTIONAL = "iceberg.field.optional"
ICEBERG_FIELD_CURRENT = "iceberg.field.current"


def _construct_parameters(
metadata_location: str, glue_table: Optional[TableTypeDef] = None, prev_metadata_location: Optional[str] = None
@@ -84,17 +110,97 @@ def _construct_parameters(
return new_parameters


GLUE_PRIMITIVE_TYPES = {
BooleanType: "boolean",
IntegerType: "int",
LongType: "bigint",
FloatType: "float",
DoubleType: "double",
DateType: "date",
TimeType: "string",
StringType: "string",
UUIDType: "string",
TimestampType: "timestamp",
FixedType: "binary",
BinaryType: "binary",
}


class _IcebergSchemaToGlueType(SchemaVisitor[str]):
def schema(self, schema: Schema, struct_result: str) -> str:
return struct_result

def struct(self, struct: StructType, field_results: List[str]) -> str:
return f"struct<{','.join(field_results)}>"

def field(self, field: NestedField, field_result: str) -> str:
return f"{field.name}:{field_result}"

def list(self, list_type: ListType, element_result: str) -> str:
return f"array<{element_result}>"

def map(self, map_type: MapType, key_result: str, value_result: str) -> str:
return f"map<{key_result},{value_result}>"

def primitive(self, primitive: PrimitiveType) -> str:
if isinstance(primitive, DecimalType):
return f"decimal({primitive.precision},{primitive.scale})"
if (primitive_type := type(primitive)) not in GLUE_PRIMITIVE_TYPES:
raise ValueError(f"Unknown primitive type: {primitive}")
return GLUE_PRIMITIVE_TYPES[primitive_type]
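For illustration (not part of the diff), this is roughly how the visitor above renders a nested Iceberg type as a Glue/Hive-style type string; the field names and the expected output are assumptions:

```python
# Illustrative sketch: rendering an Iceberg struct type with the visitor above.
# Field names and the expected output string are assumptions for the example.
from pyiceberg.schema import visit
from pyiceberg.types import DoubleType, NestedField, StringType, StructType

point = StructType(
    NestedField(field_id=1, name="label", field_type=StringType(), required=True),
    NestedField(field_id=2, name="x", field_type=DoubleType(), required=False),
    NestedField(field_id=3, name="y", field_type=DoubleType(), required=False),
)

# Expected to print something like: struct<label:string,x:double,y:double>
print(visit(point, _IcebergSchemaToGlueType()))
```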


def _to_columns(metadata: TableMetadata) -> List[ColumnTypeDef]:
results: Dict[str, ColumnTypeDef] = {}

def _append_to_results(field: NestedField, is_current: bool) -> None:
if field.name in results:
return

results[field.name] = cast(
ColumnTypeDef,
{
"Name": field.name,
"Type": visit(field.field_type, _IcebergSchemaToGlueType()),
"Parameters": {
ICEBERG_FIELD_ID: str(field.field_id),
ICEBERG_FIELD_OPTIONAL: str(field.optional).lower(),
ICEBERG_FIELD_CURRENT: str(is_current).lower(),
},
},
)
if field.doc:
results[field.name]["Comment"] = field.doc

if current_schema := metadata.schema_by_id(metadata.current_schema_id):
for field in current_schema.columns:
_append_to_results(field, True)

for schema in metadata.schemas:
if schema.schema_id == metadata.current_schema_id:
continue
for field in schema.columns:
_append_to_results(field, False)

return list(results.values())
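To make the output shape concrete, a hypothetical entry produced by `_to_columns` for a required `id` column of type `long` in the current schema would look roughly like this (the concrete values are assumptions):

```python
# Hypothetical example of one entry returned by _to_columns; values are assumptions.
example_column = {
    "Name": "id",
    "Type": "bigint",
    "Parameters": {
        "iceberg.field.id": "1",
        "iceberg.field.optional": "false",
        "iceberg.field.current": "true",
    },
}
```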


def _construct_table_input(
table_name: str,
metadata_location: str,
properties: Properties,
metadata: TableMetadata,
glue_table: Optional[TableTypeDef] = None,
prev_metadata_location: Optional[str] = None,
) -> TableInputTypeDef:
table_input: TableInputTypeDef = {
"Name": table_name,
"TableType": EXTERNAL_TABLE,
"Parameters": _construct_parameters(metadata_location, glue_table, prev_metadata_location),
"StorageDescriptor": {
"Columns": _to_columns(metadata),
"Location": metadata.location,
},
}

if "Description" in properties:
@@ -258,7 +364,7 @@ def create_table(
io = load_file_io(properties=self.properties, location=metadata_location)
self._write_metadata(metadata, io, metadata_location)

table_input = _construct_table_input(table_name, metadata_location, properties)
table_input = _construct_table_input(table_name, metadata_location, properties, metadata)
database_name, table_name = self.identifier_to_database_and_table(identifier)
self._create_glue_table(database_name=database_name, table_name=table_name, table_input=table_input)

@@ -322,6 +428,7 @@ def _commit_table(self, table_request: CommitTableRequest) -> CommitTableResponse:
table_name=table_name,
metadata_location=new_metadata_location,
properties=current_table.properties,
metadata=updated_metadata,
glue_table=current_glue_table,
prev_metadata_location=current_table.metadata_location,
)
162 changes: 159 additions & 3 deletions tests/catalog/integration_test_glue.py
@@ -15,9 +15,12 @@
# specific language governing permissions and limitations
# under the License.

from typing import Generator, List
import time
from typing import Any, Dict, Generator, List
from uuid import uuid4

import boto3
import pyarrow as pa
import pytest
from botocore.exceptions import ClientError

@@ -30,6 +33,7 @@
NoSuchTableError,
TableAlreadyExistsError,
)
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.schema import Schema
from pyiceberg.types import IntegerType
from tests.conftest import clean_up, get_bucket_name, get_s3_path
@@ -52,8 +56,62 @@ def fixture_test_catalog() -> Generator[Catalog, None, None]:
clean_up(test_catalog)


class AthenaQueryHelper:
_athena_client: boto3.client
_s3_resource: boto3.resource
_output_bucket: str
_output_path: str

def __init__(self) -> None:
self._s3_resource = boto3.resource("s3")
self._athena_client = boto3.client("athena")
self._output_bucket = get_bucket_name()
self._output_path = f"athena_results_{uuid4()}"

def get_query_results(self, query: str) -> List[Dict[str, Any]]:
query_execution_id = self._athena_client.start_query_execution(
QueryString=query, ResultConfiguration={"OutputLocation": f"s3://{self._output_bucket}/{self._output_path}"}
)["QueryExecutionId"]

while True:
result = self._athena_client.get_query_execution(QueryExecutionId=query_execution_id)["QueryExecution"]["Status"]
query_status = result["State"]
assert query_status not in [
"FAILED",
"CANCELLED",
], f"""
Athena query failed or was cancelled:
Query: {query}
Status: {query_status}
Reason: {result["StateChangeReason"]}"""

if query_status not in ["QUEUED", "RUNNING"]:
break
time.sleep(0.5)

# No pagination for now, assume that we are not doing large queries
return self._athena_client.get_query_results(QueryExecutionId=query_execution_id)["ResultSet"]["Rows"]

def clean_up(self) -> None:
bucket = self._s3_resource.Bucket(self._output_bucket)
for obj in bucket.objects.filter(Prefix=f"{self._output_path}/"):
self._s3_resource.Object(bucket.name, obj.key).delete()
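Outside the pytest fixture below, the helper could be used directly along these lines (database and table names are placeholders, and AWS credentials plus an existing Athena-visible table are assumed):

```python
# Hypothetical direct use of AthenaQueryHelper; names are placeholders and the
# call assumes AWS credentials and an existing Glue/Athena table.
helper = AthenaQueryHelper()
try:
    rows = helper.get_query_results('SELECT * FROM "my_database"."my_table"')
    header_row, *data_rows = rows  # first row carries the column names
    print(header_row, data_rows)
finally:
    helper.clean_up()
```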


@pytest.fixture(name="athena", scope="module")
def fixture_athena_helper() -> Generator[AthenaQueryHelper, None, None]:
query_helper = AthenaQueryHelper()
yield query_helper
query_helper.clean_up()


def test_create_table(
test_catalog: Catalog, s3: boto3.client, table_schema_nested: Schema, table_name: str, database_name: str
test_catalog: Catalog,
s3: boto3.client,
table_schema_nested: Schema,
table_name: str,
database_name: str,
athena: AthenaQueryHelper,
) -> None:
identifier = (database_name, table_name)
test_catalog.create_namespace(database_name)
@@ -64,6 +122,48 @@ def test_create_table(
s3.head_object(Bucket=get_bucket_name(), Key=metadata_location)
assert test_catalog._parse_metadata_version(table.metadata_location) == 0

table.append(
pa.Table.from_pylist(
[
{
"foo": "foo_val",
"bar": 1,
"baz": False,
"qux": ["x", "y"],
"quux": {"key": {"subkey": 2}},
"location": [{"latitude": 1.1}],
"person": {"name": "some_name", "age": 23},
}
],
schema=schema_to_pyarrow(table.schema()),
),
)

assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
{
"Data": [
{"VarCharValue": "foo"},
{"VarCharValue": "bar"},
{"VarCharValue": "baz"},
{"VarCharValue": "qux"},
{"VarCharValue": "quux"},
{"VarCharValue": "location"},
{"VarCharValue": "person"},
]
},
{
"Data": [
{"VarCharValue": "foo_val"},
{"VarCharValue": "1"},
{"VarCharValue": "false"},
{"VarCharValue": "[x, y]"},
{"VarCharValue": "{key={subkey=2}}"},
{"VarCharValue": "[{latitude=1.1, longitude=null}]"},
{"VarCharValue": "{name=some_name, age=23}"},
]
},
]


def test_create_table_with_invalid_location(table_schema_nested: Schema, table_name: str, database_name: str) -> None:
identifier = (database_name, table_name)
@@ -269,7 +369,7 @@ def test_update_namespace_properties(test_catalog: Catalog, database_name: str)


def test_commit_table_update_schema(
test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str
test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str, athena: AthenaQueryHelper
) -> None:
identifier = (database_name, table_name)
test_catalog.create_namespace(namespace=database_name)
@@ -279,6 +379,20 @@
assert test_catalog._parse_metadata_version(table.metadata_location) == 0
assert original_table_metadata.current_schema_id == 0

assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
{
"Data": [
{"VarCharValue": "foo"},
{"VarCharValue": "bar"},
{"VarCharValue": "baz"},
{"VarCharValue": "qux"},
{"VarCharValue": "quux"},
{"VarCharValue": "location"},
{"VarCharValue": "person"},
]
}
]

transaction = table.transaction()
update = transaction.update_schema()
update.add_column(path="b", field_type=IntegerType())
@@ -295,6 +409,48 @@
assert new_schema == update._apply()
assert new_schema.find_field("b").field_type == IntegerType()

table.append(
pa.Table.from_pylist(
[
{
"foo": "foo_val",
"bar": 1,
"location": [{"latitude": 1.1}],
"person": {"name": "some_name", "age": 23},
"b": 2,
}
],
schema=schema_to_pyarrow(new_schema),
),
)

assert athena.get_query_results(f'SELECT * FROM "{database_name}"."{table_name}"') == [
{
"Data": [
{"VarCharValue": "foo"},
{"VarCharValue": "bar"},
{"VarCharValue": "baz"},
{"VarCharValue": "qux"},
{"VarCharValue": "quux"},
{"VarCharValue": "location"},
{"VarCharValue": "person"},
{"VarCharValue": "b"},
]
},
{
"Data": [
{"VarCharValue": "foo_val"},
{"VarCharValue": "1"},
{},
{"VarCharValue": "[]"},
{"VarCharValue": "{}"},
{"VarCharValue": "[{latitude=1.1, longitude=null}]"},
{"VarCharValue": "{name=some_name, age=23}"},
{"VarCharValue": "2"},
]
},
]


def test_commit_table_properties(test_catalog: Catalog, table_schema_nested: Schema, database_name: str, table_name: str) -> None:
identifier = (database_name, table_name)