diff --git a/docs/integrations/engines/trino.md b/docs/integrations/engines/trino.md index ec1139e20d..db732f0cc1 100644 --- a/docs/integrations/engines/trino.md +++ b/docs/integrations/engines/trino.md @@ -90,6 +90,7 @@ hive.metastore.glue.default-warehouse-dir=s3://my-bucket/ | `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N | | `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N | | `roles` | Mapping of catalog name to a role | dict | N | +| `source` | Value to send as Trino's `source` field for query attribution / auditing. Default: `sqlmesh`. | string | N | | `http_headers` | Additional HTTP headers to send with each request. | dict | N | | `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N | | `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N | diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 638f0c28c8..4e11fc626f 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -1888,6 +1888,7 @@ class TrinoConnectionConfig(ConnectionConfig): client_certificate: t.Optional[str] = None client_private_key: t.Optional[str] = None cert: t.Optional[str] = None + source: str = "sqlmesh" # SQLMesh options schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None @@ -1984,6 +1985,7 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "port", "catalog", "roles", + "source", "http_scheme", "http_headers", "session_properties", @@ -2041,7 +2043,7 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: "user": self.impersonation_user or self.user, "max_attempts": self.retries, "verify": self.cert if self.cert is not None else self.verify, - "source": "sqlmesh", + "source": self.source, } @property diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py index a3c67eb023..1bfe82b858 100644 --- a/tests/core/engine_adapter/test_trino.py +++ b/tests/core/engine_adapter/test_trino.py @@ -412,6 +412,8 @@ def test_timestamp_mapping(): catalog="catalog", ) + assert config._connection_factory_with_kwargs.keywords["source"] == "sqlmesh" + adapter = config.create_engine_adapter() assert adapter.timestamp_mapping is None @@ -419,11 +421,13 @@ def test_timestamp_mapping(): user="user", host="host", catalog="catalog", + source="my_source", timestamp_mapping={ "TIMESTAMP": "TIMESTAMP(6)", "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE", }, ) + assert config._connection_factory_with_kwargs.keywords["source"] == "my_source" adapter = config.create_engine_adapter() assert adapter.timestamp_mapping is not None assert adapter.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build( diff --git a/tests/core/test_config.py b/tests/core/test_config.py index d0fad16e76..f3a0de6672 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -862,6 +862,39 @@ def test_trino_schema_location_mapping_syntax(tmp_path): assert len(conn.schema_location_mapping) == 2 +def test_trino_source_option(tmp_path): + config_path = tmp_path / "config_trino_source.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ + gateways: + trino: + connection: + type: trino + user: trino + host: trino + catalog: trino + source: my_sqlmesh_source + + default_gateway: trino + + model_defaults: + dialect: trino + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + from sqlmesh.core.config.connection import TrinoConnectionConfig + + conn = config.gateways["trino"].connection + assert isinstance(conn, TrinoConnectionConfig) + assert conn.source == "my_sqlmesh_source" + + def test_gcp_postgres_ip_and_scopes(tmp_path): config_path = tmp_path / "config_gcp_postgres.yaml" with open(config_path, "w", encoding="utf-8") as fd: