Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 168 additions & 29 deletions sqlmesh/core/engine_adapter/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import typing as t
from collections import defaultdict
from collections import defaultdict, OrderedDict

from sqlglot import exp, parse_one
from sqlglot.transforms import remove_precision_parameterized_types
Expand Down Expand Up @@ -169,18 +169,17 @@ def _df_to_source_queries(
)

def query_factory() -> Query:
ordered_df = df[list(source_columns_to_types)]
if bigframes_pd and isinstance(ordered_df, bigframes_pd.DataFrame):
ordered_df.to_gbq(
if bigframes_pd and isinstance(df, bigframes_pd.DataFrame):
df.to_gbq(
f"{temp_bq_table.project}.{temp_bq_table.dataset_id}.{temp_bq_table.table_id}",
if_exists="replace",
)
elif not self.table_exists(temp_table):
# Make mypy happy
assert isinstance(ordered_df, pd.DataFrame)
assert isinstance(df, pd.DataFrame)
self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False)
result = self.__load_pandas_to_table(
temp_bq_table, ordered_df, source_columns_to_types, replace=False
temp_bq_table, df, source_columns_to_types, replace=False
)
if result.errors:
raise SQLMeshError(result.errors)
Expand Down Expand Up @@ -755,28 +754,6 @@ def table_exists(self, table_name: TableName) -> bool:
except NotFound:
return False

def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]:
    """Fetch last-modified timestamps for the given tables.

    Tables are grouped by dataset and one query is issued against each
    dataset's ``__TABLES__`` view. The returned list follows the order in
    which rows come back per dataset, which is not necessarily the order of
    ``table_names``.

    Args:
        table_names: Tables whose last-modified time should be looked up.

    Returns:
        One ``to_timestamp`` result per row returned by the queries.
    """
    from sqlmesh.utils.date import to_timestamp

    # Bucket the table ids under their dataset so each dataset is queried once.
    tables_by_dataset: t.DefaultDict[str, t.List[str]] = defaultdict(list)
    for name in table_names:
        parsed = exp.to_table(name)
        tables_by_dataset[parsed.db].append(parsed.name)

    rows = []
    for dataset, table_ids in tables_by_dataset.items():
        predicate = " OR ".join(f"TABLE_ID = '{table_id}'" for table_id in table_ids)
        rows.extend(
            self.fetchall(
                f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE "
                f"{predicate}"
            )
        )

    return [to_timestamp(row[0]) for row in rows]

def _get_table(self, table_name: TableName) -> BigQueryTable:
"""
Returns a BigQueryTable object for the given table name.
Expand Down Expand Up @@ -891,6 +868,60 @@ def _build_partitioned_by_exp(

return exp.PartitionedByProperty(this=this)

def _create_table(
    self,
    table_name_or_schema: t.Union[exp.Schema, TableName],
    expression: t.Optional[exp.Expression],
    exists: bool = True,
    replace: bool = False,
    target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
    table_description: t.Optional[str] = None,
    column_descriptions: t.Optional[t.Dict[str, str]] = None,
    table_kind: t.Optional[str] = None,
    track_rows_processed: bool = True,
    **kwargs: t.Any,
) -> None:
    """Create a table, adding a ``WITH CONNECTION`` clause when one is requested.

    The ``table_properties``, ``table_format`` and ``storage_format`` entries
    of ``kwargs`` are passed through ``_prepare_create_table_properties``; the
    normalized properties replace ``kwargs["table_properties"]``. When no
    connection property comes back, creation is delegated to the parent
    class. Otherwise the CREATE statement is rendered to SQL, the connection
    clause is spliced in, and the resulting string is executed directly.
    """
    prepared_properties, connection = self._prepare_create_table_properties(
        kwargs.get("table_properties"),
        kwargs.get("table_format"),
        kwargs.get("storage_format"),
    )
    kwargs["table_properties"] = prepared_properties

    if connection is not None:
        # The table comment is only forwarded when comments may be registered
        # in the schema definition and commenting is enabled.
        if self.COMMENT_CREATION_TABLE.supports_schema_def and self.comments_enabled:
            schema_description = table_description
        else:
            schema_description = None

        create_exp = self._build_create_table_exp(
            table_name_or_schema,
            expression=expression,
            exists=exists,
            replace=replace,
            target_columns_to_types=target_columns_to_types,
            table_description=schema_description,
            table_kind=table_kind,
            **kwargs,
        )
        statement = self._inject_connection_clause(
            self._to_sql(create_exp),
            self._connection_clause_sql(connection),
        )
        self.execute(statement, track_rows_processed=track_rows_processed)
        return

    # No connection requested: the regular creation path applies.
    super()._create_table(
        table_name_or_schema,
        expression,
        exists=exists,
        replace=replace,
        target_columns_to_types=target_columns_to_types,
        table_description=table_description,
        column_descriptions=column_descriptions,
        table_kind=table_kind,
        track_rows_processed=track_rows_processed,
        **kwargs,
    )

def _build_table_properties_exp(
self,
catalog_name: t.Optional[str] = None,
Expand Down Expand Up @@ -926,12 +957,120 @@ def _build_table_properties_exp(
),
)

properties.extend(self._table_or_view_properties_to_expressions(table_properties))
if table_properties:
for key, value in table_properties.items():
properties.append(exp.Property(this=key, value=value.copy()))

if properties:
return exp.Properties(expressions=properties)
return None

def _prepare_create_table_properties(
self,
table_properties: t.Optional[t.Dict[str, exp.Expression]],
table_format: t.Optional[str],
storage_format: t.Optional[str],
) -> t.Tuple[OrderedDict[str, exp.Expression], t.Optional[exp.Expression]]:
normalized_properties: OrderedDict[str, exp.Expression] = OrderedDict()
connection_property: t.Optional[exp.Expression] = None

if table_properties:
for key, value in table_properties.items():
if value is None:
continue
key_lower = key.lower()
if key_lower in {"connection", "with_connection"}:
connection_property = value
continue
# Reinsert properties with the latest casing while preserving order
for existing_key in list(normalized_properties.keys()):
if existing_key.lower() == key_lower:
normalized_properties.pop(existing_key)
break
normalized_properties[key] = value.copy()

def _get_property(name: str) -> t.Optional[exp.Expression]:
for existing_key, value in normalized_properties.items():
if existing_key.lower() == name:
return value
return None

def _set_property(name: str, expression: exp.Expression) -> None:
for existing_key in list(normalized_properties.keys()):
if existing_key.lower() == name:
normalized_properties.pop(existing_key)
break
normalized_properties[name] = expression

def _has_property(name: str) -> bool:
return any(existing_key.lower() == name for existing_key in normalized_properties)

normalized_table_format = table_format.lower() if table_format else None
if not normalized_table_format:
existing_table_format = _get_property("table_format")
if isinstance(existing_table_format, exp.Literal) and existing_table_format.is_string:
normalized_table_format = existing_table_format.this.lower()
is_iceberg = normalized_table_format == "iceberg"

if is_iceberg:
table_format_expression = self._ensure_upper_string_literal(
_get_property("table_format"),
default=normalized_table_format or "iceberg",
)
_set_property("table_format", table_format_expression)

file_format_expression = self._ensure_upper_string_literal(
_get_property("file_format"),
default=storage_format or "PARQUET",
)
_set_property("file_format", file_format_expression)

if not _has_property("storage_uri"):
raise SQLMeshError(
"BigQuery Iceberg tables require `storage_uri` to be set in physical_properties."
)

if connection_property is None:
raise SQLMeshError(
"BigQuery Iceberg tables require a `connection` entry in physical_properties."
)

return normalized_properties, connection_property

def _ensure_upper_string_literal(
    self,
    expression: t.Optional[exp.Expression],
    default: str,
) -> exp.Expression:
    """Return an upper-cased string literal for a property value.

    Args:
        expression: The existing property value, if any.
        default: Text used (upper-cased) when ``expression`` is ``None``.

    Returns:
        A new string literal with upper-cased text when ``expression`` is
        ``None`` or a string literal; otherwise a copy of ``expression``
        unchanged.
    """
    if expression is not None:
        candidate = expression.copy()
        if not (isinstance(candidate, exp.Literal) and candidate.is_string):
            # Non-string expressions are passed through as a copy.
            return candidate
        return exp.Literal.string(candidate.this.upper())

    return exp.Literal.string(default.upper())

def _connection_clause_sql(self, connection_expression: exp.Expression) -> str:
    """Render the value that follows ``WITH CONNECTION``.

    A string literal equal to ``DEFAULT`` (ignoring case and surrounding
    whitespace) yields the bare keyword ``DEFAULT``; any other string literal
    is rendered as a quoted identifier in this adapter's dialect. Non-string
    expressions are rendered with ``_to_sql``.
    """
    copied = connection_expression.copy()
    if not (isinstance(copied, exp.Literal) and copied.is_string):
        return self._to_sql(copied)

    connection_name = copied.this.strip()
    if connection_name.upper() == "DEFAULT":
        return "DEFAULT"

    quoted_identifier = exp.to_identifier(connection_name, quoted=True)
    return quoted_identifier.sql(dialect=self.dialect)

@staticmethod
def _inject_connection_clause(create_sql: str, connection_sql: str) -> str:
parts = create_sql.split("OPTIONS", 1)
if len(parts) == 2:
prefix, suffix = parts
if not prefix.endswith(" "):
prefix = f"{prefix} "
return f"{prefix}WITH CONNECTION {connection_sql} OPTIONS{suffix}"
separator = " " if not create_sql.endswith(" ") else ""
return f"{create_sql}{separator}WITH CONNECTION {connection_sql}"

def _build_column_def(
self,
col_name: str,
Expand Down