Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,18 @@ def _parse_generated_as_identity(

return this

def _parse_user_defined_type(
self, identifier: exp.Identifier
) -> t.Optional[exp.Expression]:
udt_type: exp.Identifier | exp.Dot = identifier

while self._match(TokenType.DOT):
part = self._parse_id_var()
if part:
udt_type = exp.Dot(this=udt_type, expression=part)

return exp.DataType.build(udt_type, udt=True)

class Generator(generator.Generator):
SINGLE_STRING_INTERVAL = True
RENAME_TABLE_WITH_DB = False
Expand Down
8 changes: 5 additions & 3 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4753,6 +4753,8 @@ def build(
if udt:
return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
raise
elif isinstance(dtype, (Identifier, Dot)) and udt:
return DataType(this=DataType.Type.USERDEFINED, kind=dtype, **kwargs)
elif isinstance(dtype, DataType.Type):
data_type_exp = DataType(this=dtype)
elif isinstance(dtype, DataType):
Expand Down Expand Up @@ -4794,9 +4796,6 @@ def is_type(self, *dtypes: DATA_TYPE, check_nullable: bool = False) -> bool:
return False


DATA_TYPE = t.Union[str, DataType, DataType.Type]


# https://www.postgresql.org/docs/15/datatype-pseudo.html
class PseudoType(DataType):
arg_types = {"this": True}
Expand Down Expand Up @@ -5030,6 +5029,9 @@ def parts(self) -> t.List[Expression]:
return parts


DATA_TYPE = t.Union[str, Identifier, Dot, DataType, DataType.Type]


class DPipe(Binary):
arg_types = {"this": True, "expression": True, "safe": False}

Expand Down
15 changes: 9 additions & 6 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5121,6 +5121,14 @@ def _parse_type_size(self) -> t.Optional[exp.DataTypeParam]:
exp.DataTypeParam, this=this, expression=self._parse_var(any_token=True)
)

def _parse_user_defined_type(self, identifier: exp.Identifier) -> t.Optional[exp.Expression]:
type_name = identifier.name

while self._match(TokenType.DOT):
type_name = f"{type_name}.{self._advance_any() and self._prev.text}"

return exp.DataType.build(type_name, udt=True)

def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
Expand All @@ -5142,12 +5150,7 @@ def _parse_types(
if tokens[0].token_type in self.TYPE_TOKENS:
self._prev = tokens[0]
elif self.dialect.SUPPORTS_USER_DEFINED_TYPES:
type_name = identifier.name

while self._match(TokenType.DOT):
type_name = f"{type_name}.{self._advance_any() and self._prev.text}"

this = exp.DataType.build(type_name, udt=True)
this = self._parse_user_defined_type(identifier)
else:
self._retreat(self._index - 1)
return None
Expand Down
11 changes: 11 additions & 0 deletions tests/dialects/test_postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,3 +1449,14 @@ def test_json_extract(self):
"clickhouse": "SELECT JSONExtractString(foo, '12')",
},
)

def test_udt(self):
def _validate_udt(sql: str):
self.validate_identity(sql).to.assert_is(exp.DataType)

_validate_udt("CAST(5 AS MyType)")
_validate_udt('CAST(5 AS "MyType")')
_validate_udt("CAST(5 AS MySchema.MyType)")
_validate_udt('CAST(5 AS "MySchema"."MyType")')
_validate_udt('CAST(5 AS MySchema."MyType")')
_validate_udt('CAST(5 AS "MyCatalog"."MySchema"."MyType")')