From e88c5ad51f8edc7872098070f28fc15b7c5f5b20 Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Thu, 12 Jun 2025 12:06:40 +0300 Subject: [PATCH] fix(postgres): Preserve quoting for UDT --- sqlglot/dialects/postgres.py | 12 ++++++++++++ sqlglot/expressions.py | 8 +++++--- sqlglot/parser.py | 15 +++++++++------ tests/dialects/test_postgres.py | 11 +++++++++++ 4 files changed, 37 insertions(+), 9 deletions(-) diff --git a/sqlglot/dialects/postgres.py b/sqlglot/dialects/postgres.py index 8ddfa22946..037916162c 100644 --- a/sqlglot/dialects/postgres.py +++ b/sqlglot/dialects/postgres.py @@ -512,6 +512,18 @@ def _parse_generated_as_identity( return this + def _parse_user_defined_type( + self, identifier: exp.Identifier + ) -> t.Optional[exp.Expression]: + udt_type: exp.Identifier | exp.Dot = identifier + + while self._match(TokenType.DOT): + part = self._parse_id_var() + if part: + udt_type = exp.Dot(this=udt_type, expression=part) + + return exp.DataType.build(udt_type, udt=True) + class Generator(generator.Generator): SINGLE_STRING_INTERVAL = True RENAME_TABLE_WITH_DB = False diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index fbe483970e..187ca48182 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -4753,6 +4753,8 @@ def build( if udt: return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs) raise + elif isinstance(dtype, (Identifier, Dot)) and udt: + return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs) elif isinstance(dtype, DataType.Type): data_type_exp = DataType(this=dtype) elif isinstance(dtype, DataType): @@ -4794,9 +4796,6 @@ def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool: return False -DATA_TYPE = t.Union[str, DataType, DataType.Type] - - # https://www.postgresql.org/docs/15/datatype-pseudo.html class PseudoType(DataType): arg_types = {"this": True} @@ -5030,6 +5029,9 @@ def parts(self) -> t.List[Expression]: return parts +DATA_TYPE = t.Union[str, Identifier, Dot, DataType, DataType.Type] + + class DPipe(Binary): arg_types = {"this": True, "expression": True, "safe": False} diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 405fd62c43..3d36cc0e6d 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -5121,6 +5121,14 @@ def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]: exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True) ) + def _parse_user_defined_type(self, identifier: exp.Identifier) -> t.Optional[exp.Expression]: + type_name = identifier.name + + while self._match(TokenType.DOT): + type_name = f"{type_name}.{self._advance_any() and self._prev.text}" + + return exp.DataType.build(type_name, udt=True) + def _parse_types( self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True ) -> t.Optional[exp.Expression]: @@ -5142,12 +5150,7 @@ def _parse_types( if tokens[0].token_type in self.TYPE_TOKENS: self._prev = tokens[0] elif self.dialect.SUPPORTS_USER_DEFINED_TYPES: - type_name = identifier.name - - while self._match(TokenType.DOT): - type_name = f"{type_name}.{self._advance_any() and self._prev.text}" - - this = exp.DataType.build(type_name, udt=True) + this = self._parse_user_defined_type(identifier) else: self._retreat(self._index - 1) return None diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index e5827fd4eb..54e65dd62f 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -1449,3 +1449,14 @@ def test_json_extract(self): "clickhouse": "SELECT JSONExtractString(foo, '12')", }, ) + + def test_udt(self): + def _validate_udt(sql: str): + self.validate_identity(sql).to.assert_is(exp.DataType) + + _validate_udt("CAST(5 AS MyType)") + _validate_udt('CAST(5 AS "MyType")') + _validate_udt("CAST(5 AS MySchema.MyType)") + _validate_udt('CAST(5 AS "MySchema"."MyType")') + _validate_udt('CAST(5 AS MySchema."MyType")') + _validate_udt('CAST(5 AS "MyCatalog"."MySchema"."MyType")')