diff --git a/sqlglot/dialects/tsql.py b/sqlglot/dialects/tsql.py index e40f46650e..10a2862bcb 100644 --- a/sqlglot/dialects/tsql.py +++ b/sqlglot/dialects/tsql.py @@ -244,7 +244,7 @@ def _string_agg_sql(self: TSQL.Generator, expression: exp.GroupConcat) -> str: def _build_date_delta( - exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None + exp_class: t.Type[E], unit_mapping: t.Optional[t.Dict[str, str]] = None, big_int: bool = False ) -> t.Callable[[t.List], E]: def _builder(args: t.List) -> E: unit = seq_get(args, 0) @@ -260,12 +260,15 @@ def _builder(args: t.List) -> E: else: # We currently don't handle float values, i.e. they're not converted to equivalent DATETIMEs. # This is not a problem when generating T-SQL code, it is when transpiling to other dialects. - return exp_class(this=seq_get(args, 2), expression=start_date, unit=unit) + return exp_class( + this=seq_get(args, 2), expression=start_date, unit=unit, big_int=big_int + ) return exp_class( this=exp.TimeStrToTime(this=seq_get(args, 2)), expression=exp.TimeStrToTime(this=start_date), unit=unit, + big_int=big_int, ) return _builder @@ -597,6 +600,9 @@ class Parser(parser.Parser): ), "DATEADD": build_date_delta(exp.DateAdd, unit_mapping=DATE_DELTA_INTERVAL), "DATEDIFF": _build_date_delta(exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL), + "DATEDIFF_BIG": _build_date_delta( + exp.DateDiff, unit_mapping=DATE_DELTA_INTERVAL, big_int=True + ), "DATENAME": _build_formatted_time(exp.TimeToStr, full_format_mapping=True), "DATEPART": _build_formatted_time(exp.TimeToStr), "DATETIMEFROMPARTS": _build_datetimefromparts, @@ -1033,7 +1039,6 @@ class Generator(generator.Generator): exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY", exp.Chr: rename_func("CHAR"), exp.DateAdd: date_delta_sql("DATEADD"), - exp.DateDiff: date_delta_sql("DATEDIFF"), exp.CTE: transforms.preprocess([qualify_derived_table_outputs]), exp.CurrentDate: rename_func("GETDATE"), exp.CurrentTimestamp: rename_func("GETDATE"), @@ -1299,6 +1304,10 @@ def count_sql(self, expression: exp.Count) -> str: func_name = "COUNT_BIG" if expression.args.get("big_int") else "COUNT" return rename_func(func_name)(self, expression) + def datediff_sql(self, expression: exp.DateDiff) -> str: + func_name = "DATEDIFF_BIG" if expression.args.get("big_int") else "DATEDIFF" + return date_delta_sql(func_name)(self, expression) + def offset_sql(self, expression: exp.Offset) -> str: return f"{super().offset_sql(expression)} ROWS" diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index c09adbfb65..372501b576 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -6267,7 +6267,7 @@ class DateSub(Func, IntervalOp): class DateDiff(Func, TimeUnit): _sql_names = ["DATEDIFF", "DATE_DIFF"] - arg_types = {"this": True, "expression": True, "unit": False, "zone": False} + arg_types = {"this": True, "expression": True, "unit": False, "zone": False, "big_int": False} class DateTrunc(Func): diff --git a/sqlglot/typing/__init__.py b/sqlglot/typing/__init__.py index 6aab788a8a..b216590d55 100644 --- a/sqlglot/typing/__init__.py +++ b/sqlglot/typing/__init__.py @@ -111,7 +111,6 @@ exp.Ascii, exp.Ceil, exp.DatetimeDiff, - exp.DateDiff, exp.TimestampDiff, exp.TimeDiff, exp.Unicode, @@ -273,6 +272,11 @@ e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT ) }, + exp.DateDiff: { + "annotator": lambda self, e: self._annotate_with_type( + e, exp.DataType.Type.BIGINT if e.args.get("big_int") else exp.DataType.Type.INT + ) + }, exp.DataType: {"annotator": lambda self, e: self._annotate_with_type(e, e.copy())}, exp.Div: {"annotator": lambda self, e: self._annotate_div(e)}, exp.Distinct: {"annotator": lambda self, e: self._annotate_by_args(e, "expressions")}, diff --git a/tests/dialects/test_tsql.py b/tests/dialects/test_tsql.py index 683bd4da8e..afecf99e56 100644 --- a/tests/dialects/test_tsql.py +++ b/tests/dialects/test_tsql.py @@ -1784,62 +1784,64 @@ def test_add_date(self): def test_date_diff(self): self.validate_identity("SELECT DATEDIFF(HOUR, 1.5, '2021-01-01')") + self.validate_identity("SELECT DATEDIFF_BIG(HOUR, 1.5, '2021-01-01')") - self.validate_all( - "SELECT DATEDIFF(quarter, 0, '2021-01-01')", - write={ - "tsql": "SELECT DATEDIFF(QUARTER, CAST('1900-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", - "spark": "SELECT DATEDIFF(QUARTER, CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", - "duckdb": "SELECT DATE_DIFF('QUARTER', CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", - }, - ) - self.validate_all( - "SELECT DATEDIFF(day, 1, '2021-01-01')", - write={ - "tsql": "SELECT DATEDIFF(DAY, CAST('1900-01-02' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", - "spark": "SELECT DATEDIFF(DAY, CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", - "duckdb": "SELECT DATE_DIFF('DAY', CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", - }, - ) - self.validate_all( - "SELECT DATEDIFF(year, '2020-01-01', '2021-01-01')", - write={ - "tsql": "SELECT DATEDIFF(YEAR, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", - "spark": "SELECT DATEDIFF(YEAR, CAST('2020-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", - "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) / 12 AS INT)", - }, - ) - self.validate_all( - "SELECT DATEDIFF(mm, 'start', 'end')", - write={ - "databricks": "SELECT DATEDIFF(MONTH, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", - "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT)", - "tsql": "SELECT DATEDIFF(MONTH, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", - }, - ) - self.validate_all( - "SELECT DATEDIFF(quarter, 'start', 'end')", - write={ - "databricks": "SELECT DATEDIFF(QUARTER, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", - "spark": "SELECT DATEDIFF(QUARTER, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", - "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3 AS INT)", - "tsql": "SELECT DATEDIFF(QUARTER, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", - }, - ) + for fnc in ["DATEDIFF", "DATEDIFF_BIG"]: + self.validate_all( + f"SELECT {fnc}(quarter, 0, '2021-01-01')", + write={ + "tsql": f"SELECT {fnc}(QUARTER, CAST('1900-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(QUARTER, CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "duckdb": "SELECT DATE_DIFF('QUARTER', CAST('1900-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + }, + ) + self.validate_all( + f"SELECT {fnc}(day, 1, '2021-01-01')", + write={ + "tsql": f"SELECT {fnc}(DAY, CAST('1900-01-02' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(DAY, CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "duckdb": "SELECT DATE_DIFF('DAY', CAST('1900-01-02' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + }, + ) + self.validate_all( + f"SELECT {fnc}(year, '2020-01-01', '2021-01-01')", + write={ + "tsql": f"SELECT {fnc}(YEAR, CAST('2020-01-01' AS DATETIME2), CAST('2021-01-01' AS DATETIME2))", + "spark": "SELECT DATEDIFF(YEAR, CAST('2020-01-01' AS TIMESTAMP), CAST('2021-01-01' AS TIMESTAMP))", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('2021-01-01' AS TIMESTAMP), CAST('2020-01-01' AS TIMESTAMP)) / 12 AS INT)", + }, + ) + self.validate_all( + f"SELECT {fnc}(mm, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(MONTH, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) AS INT)", + "tsql": f"SELECT {fnc}(MONTH, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", + }, + ) + self.validate_all( + f"SELECT {fnc}(quarter, 'start', 'end')", + write={ + "databricks": "SELECT DATEDIFF(QUARTER, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark": "SELECT DATEDIFF(QUARTER, CAST('start' AS TIMESTAMP), CAST('end' AS TIMESTAMP))", + "spark2": "SELECT CAST(MONTHS_BETWEEN(CAST('end' AS TIMESTAMP), CAST('start' AS TIMESTAMP)) / 3 AS INT)", + "tsql": f"SELECT {fnc}(QUARTER, CAST('start' AS DATETIME2), CAST('end' AS DATETIME2))", + }, + ) - # Check superfluous casts arent added. ref: https://github.com/TobikoData/sqlmesh/issues/2672 - self.validate_all( - "SELECT DATEDIFF(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo", - write={ - "tsql": "SELECT DATEDIFF(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo", - "clickhouse": "SELECT DATE_DIFF(DAY, CAST(CAST(a AS Nullable(DateTime)) AS DateTime64(6)), CAST(CAST(b AS Nullable(DateTime)) AS DateTime64(6))) AS x FROM foo", - }, - ) + # Check superfluous casts arent added. ref: https://github.com/TobikoData/sqlmesh/issues/2672 + self.validate_all( + f"SELECT {fnc}(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo", + write={ + "tsql": f"SELECT {fnc}(DAY, CAST(a AS DATETIME2), CAST(b AS DATETIME2)) AS x FROM foo", + "clickhouse": "SELECT DATE_DIFF(DAY, CAST(CAST(a AS Nullable(DateTime)) AS DateTime64(6)), CAST(CAST(b AS Nullable(DateTime)) AS DateTime64(6))) AS x FROM foo", + }, + ) - self.validate_identity( - "SELECT DATEADD(DAY, DATEDIFF(DAY, -3, GETDATE()), '08:00:00')", - "SELECT DATEADD(DAY, DATEDIFF(DAY, CAST('1899-12-29' AS DATETIME2), CAST(GETDATE() AS DATETIME2)), '08:00:00')", - ) + self.validate_identity( + f"SELECT DATEADD(DAY, {fnc}(DAY, -3, GETDATE()), '08:00:00')", + f"SELECT DATEADD(DAY, {fnc}(DAY, CAST('1899-12-29' AS DATETIME2), CAST(GETDATE() AS DATETIME2)), '08:00:00')", + ) def test_lateral_subquery(self): self.validate_all(