diff --git a/sqeleton/databases/snowflake.py b/sqeleton/databases/snowflake.py index 5edafc0..f843737 100644 --- a/sqeleton/databases/snowflake.py +++ b/sqeleton/databases/snowflake.py @@ -48,9 +48,9 @@ def md5_as_int(self, s: str) -> str: class Mixin_NormalizeValue(AbstractMixin_NormalizeValue): def normalize_timestamp(self, value: str, coltype: TemporalType) -> str: if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" + timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, convert_timezone('UTC', {value})::timestamp(9))/1000000000, {coltype.precision}))" else: - timestamp = f"cast({value} as timestamp({coltype.precision}))" + timestamp = f"cast(convert_timezone('UTC', {value}) as timestamp({coltype.precision}))" return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" diff --git a/tests/common.py b/tests/common.py index e3b460e..2be979d 100644 --- a/tests/common.py +++ b/tests/common.py @@ -7,10 +7,12 @@ import logging import subprocess +import sqeleton from parameterized import parameterized_class from sqeleton import databases as db from sqeleton import connect +from sqeleton.abcs.mixins import AbstractMixin_NormalizeValue from sqeleton.queries import table from sqeleton.databases import Database from sqeleton.query_utils import drop_table @@ -83,7 +85,8 @@ def get_conn(cls: type, shared: bool = True) -> Database: _database_instances[cls] = get_conn(cls, shared=False) return _database_instances[cls] - return connect(CONN_STRINGS[cls], N_THREADS) + con = sqeleton.connect.load_mixins(AbstractMixin_NormalizeValue) + return con(CONN_STRINGS[cls], N_THREADS) def _print_used_dbs(): diff --git a/tests/test_database.py b/tests/test_database.py index e42b92b..16be6cb 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -1,15 +1,14 @@ -from typing import Callable, List -from datetime import datetime import unittest +from datetime import datetime +from typing 
import Callable, List, Tuple -from .common import str_to_checksum, TEST_MYSQL_CONN_STRING -from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix - -from sqeleton.queries import table, current_timestamp +import pytz -from sqeleton import databases as dbs from sqeleton import connect - +from sqeleton import databases as dbs +from sqeleton.queries import table, current_timestamp, NormalizeAsString +from .common import TEST_MYSQL_CONN_STRING +from .common import str_to_checksum, test_each_database_in_list, get_conn, random_table_suffix TEST_DATABASES = { dbs.MySQL, @@ -81,6 +80,37 @@ def test_current_timestamp(self): res = db.query(current_timestamp(), datetime) assert isinstance(res, datetime), (res, type(res)) + def test_correct_timezone(self): + name = "tbl_" + random_table_suffix() + db = get_conn(self.db_cls) + tbl = table(db.parse_table_name(name), schema={ + "id": int, "created_at": "timestamp_tz(9)", "updated_at": "timestamp_tz(9)" + }) + + db.query(tbl.create()) + + tz = pytz.timezone('Europe/Berlin') + + now = datetime.now(tz) + db.query(table(db.parse_table_name(name)).insert_row("1", now, now)) + db.query(db.dialect.set_timezone_to_utc()) + + t = db.table(name).query_schema() + t.schema["created_at"] = t.schema["created_at"].replace(precision=t.schema["created_at"].precision, rounds=True) + + tbl = table(db.parse_table_name(name), schema=t.schema) + + results = db.query(tbl.select(NormalizeAsString(tbl[c]) for c in ["created_at", "updated_at"]), List[Tuple]) + + created_at = results[0][0] + updated_at = results[0][1] + + utc = now.astimezone(pytz.UTC) + + self.assertEqual(created_at, utc.__format__("%Y-%m-%d %H:%M:%S.%f")) + self.assertEqual(updated_at, utc.__format__("%Y-%m-%d %H:%M:%S.%f")) + + db.query(tbl.drop()) @test_each_database class TestThreePartIds(unittest.TestCase): @@ -104,3 +134,4 @@ def test_three_part_support(self): d = db.query_table_schema(part.path) assert len(d) == 1 db.query(part.drop()) +