From a6bfde343ba62e3aa40a4be121cc09330730e06d Mon Sep 17 00:00:00 2001 From: Yagna Chilukuri Date: Mon, 2 Mar 2026 09:26:15 -0800 Subject: [PATCH 1/3] add catalog provider for review --- catalog.py | 105 ++++++++++++++++++++++++++++++++++++++++++++++ catalog_test.py | 63 ++++++++++++++++++++++++++++ xarray_context.py | 13 ++++++ 3 files changed, 181 insertions(+) create mode 100644 catalog.py create mode 100644 catalog_test.py create mode 100644 xarray_context.py diff --git a/catalog.py b/catalog.py new file mode 100644 index 0000000..a190f8a --- /dev/null +++ b/catalog.py @@ -0,0 +1,105 @@ +import xarray as xr +import datafusion as dfn +from xarray_sql.reader import read_xarray_table + + +def group_vars_by_dims(ds): + """ + Group variables in the dataset based on shared dims + + ("time", "lat", "lon"): ["temperature_2m", "wind_speed"], + ("time", "lat", "lon", "level"): ["pressure", "humidity"] + """ + groups = {} + + for var_name, var in ds.data_vars.items(): + dims = var.dims + + if dims not in groups: + groups[dims] = [] + + groups[dims].append(var_name) + + return groups + + +def dims_to_table_name(dims): + """ + "time", "lat", "lon" -> "time_lat_lon" + """ + return "_".join(dims) + + +class XarraySchemaProvider(dfn.catalog.SchemaProvider): + """ + Custom datafusion schema that holds the tables + """ + + def __init__(self, ds, groups, chunks): + # dictionary to store the tables + self.tables = {} + + # create a table for for each group of vars + for dims, var_names in groups.items(): + table_name = dims_to_table_name(dims) + subset = ds[var_names] + self.tables[table_name] = read_xarray_table(subset, chunks) + + def table_names(self): + return set(self.tables.keys()) + + def table(self, name): + return self.tables.get(name) + + def table_exist(self, name): + return name in self.tables + + def register_table(self, name, table): + self.tables[name] = table + + def deregister_table(self, name, cascade=True): + del self.tables[name] + + +class 
XarrayCatalogProvider(dfn.catalog.CatalogProvider): + """ + Custom datafusion catalog that holds the schemas + """ + #Constructor + def __init__(self, ds, schema_name, chunks): + groups = group_vars_by_dims(ds) + + # dictionary to store schemas using previous schema class + """ + "data": { + "time_lat_lon": [temperature_2m, wind_speed], + "time_lat_lon_level": [pressure, humidity] + } + """ + self.schemas = { + schema_name: XarraySchemaProvider(ds, groups, chunks) + } + + """ + Other methods from test_catalog.py + """ + def schema_names(self): + return set(self.schemas.keys()) + + def schema(self, name): + return self.schemas.get(name) + + def register_schema(self, name, schema): + self.schemas[name] = schema + + def deregister_schema(self, name, cascade=True): + del self.schemas[name] + + +def register_catalog_from_dataset(ctx, ds, catalog_name="xarray", schema_name="data", chunks=None): + """ + Main function. Takes an xarray dataset and registers it + with DataFusion so you can query it with SQL. 
+ """ + catalog = XarrayCatalogProvider(ds, schema_name, chunks) + ctx.register_catalog_provider(catalog_name, catalog) \ No newline at end of file diff --git a/catalog_test.py b/catalog_test.py new file mode 100644 index 0000000..1ec8bbb --- /dev/null +++ b/catalog_test.py @@ -0,0 +1,63 @@ +import numpy as np +import xarray as xr +from context import XarrayContext + + +# create a fake era5 dataset for testing +times = np.array(["2020-01-01", "2020-01-02", "2020-01-03"], dtype="datetime64") +lats = np.array([0.0, 1.0, 2.0]) +lons = np.array([0.0, 1.0, 2.0]) +levels = np.array([500, 850]) + +ds = xr.Dataset( + { + "temperature_2m": (["time", "lat", "lon"], np.random.rand(3, 3, 3)), + "wind_speed": (["time", "lat", "lon"], np.random.rand(3, 3, 3)), + "pressure": (["time", "lat", "lon", "level"], np.random.rand(3, 3, 3, 2)), + "humidity": (["time", "lat", "lon", "level"], np.random.rand(3, 3, 3, 2)), + }, + coords={ + "time": times, + "lat": lats, + "lon": lons, + "level": levels, + } +).chunk({"time": 1}) + +print("Variables:", list(ds.data_vars)) +print("Dimensions:", list(ds.dims)) + +ctx = XarrayContext() +ctx.register_catalog_from_dataset(ds) + +print("\nCatalogs:", ctx.catalog_names()) +print("Schemas:", ctx.catalog("xarray").schema_names()) +print("Tables:", ctx.catalog("xarray").schema("data").table_names()) + +print("\n--- Surface variables (time, lat, lon) ---") +result = ctx.sql("SELECT * FROM xarray.data.time_lat_lon LIMIT 5").collect() +for batch in result: + print(batch.to_pandas()) + + +print("\n--- Atmospheric variables (time, lat, lon, level) ---") +result = ctx.sql("SELECT * FROM xarray.data.time_lat_lon_level LIMIT 5").collect() +for batch in result: + print(batch.to_pandas()) + +print("\n--- Joined surface + atmospheric on shared dims ---") +result = ctx.sql(""" + SELECT + s.time, s.lat, s.lon, + s.temperature_2m, + a.level, + a.pressure + FROM xarray.data.time_lat_lon s + JOIN xarray.data.time_lat_lon_level a + ON s.time = a.time + AND s.lat = 
a.lat + AND s.lon = a.lon + LIMIT 10 +""").collect() +for batch in result: + print(batch.to_pandas()) \ No newline at end of file diff --git a/xarray_context.py b/xarray_context.py new file mode 100644 index 0000000..126a332 --- /dev/null +++ b/xarray_context.py @@ -0,0 +1,13 @@ +import xarray as xr +from datafusion import SessionContext +from catalog import register_catalog_from_dataset + + +class XarrayContext(SessionContext): + """ + A regular DataFusion SessionContext but with an extra method + for registering xarray datasets. + """ + + def register_catalog_from_dataset(self, ds, catalog_name="xarray", schema_name="data", chunks=None): + register_catalog_from_dataset(self, ds, catalog_name, schema_name, chunks) \ No newline at end of file From 5fed884ab73f1672f7ef1c2ae8ab291c84ed53f3 Mon Sep 17 00:00:00 2001 From: Yagna Chilukuri Date: Thu, 5 Mar 2026 17:59:00 -0800 Subject: [PATCH 2/3] Remove mistakenly added catalog files --- catalog.py | 105 ---------------------------------------------- catalog_test.py | 63 ---------------------------- xarray_context.py | 13 ------ 3 files changed, 181 deletions(-) delete mode 100644 catalog.py delete mode 100644 catalog_test.py delete mode 100644 xarray_context.py diff --git a/catalog.py b/catalog.py deleted file mode 100644 index a190f8a..0000000 --- a/catalog.py +++ /dev/null @@ -1,105 +0,0 @@ -import xarray as xr -import datafusion as dfn -from xarray_sql.reader import read_xarray_table - - -def group_vars_by_dims(ds): - """ - Group variables in the dataset based on shared dims - - ("time", "lat", "lon"): ["temperature_2m", "wind_speed"], - ("time", "lat", "lon", "level"): ["pressure", "humidity"] - """ - groups = {} - - for var_name, var in ds.data_vars.items(): - dims = var.dims - - if dims not in groups: - groups[dims] = [] - - groups[dims].append(var_name) - - return groups - - -def dims_to_table_name(dims): - """ - "time", "lat", "lon" -> "time_lat_lon" - """ - return "_".join(dims) - - -class 
XarraySchemaProvider(dfn.catalog.SchemaProvider): - """ - Custom datafusion schema that holds the tables - """ - - def __init__(self, ds, groups, chunks): - # dictionary to store the tables - self.tables = {} - - # create a table for for each group of vars - for dims, var_names in groups.items(): - table_name = dims_to_table_name(dims) - subset = ds[var_names] - self.tables[table_name] = read_xarray_table(subset, chunks) - - def table_names(self): - return set(self.tables.keys()) - - def table(self, name): - return self.tables.get(name) - - def table_exist(self, name): - return name in self.tables - - def register_table(self, name, table): - self.tables[name] = table - - def deregister_table(self, name, cascade=True): - del self.tables[name] - - -class XarrayCatalogProvider(dfn.catalog.CatalogProvider): - """ - Custom datafusion catalog that holds the schemas - """ - #Constructor - def __init__(self, ds, schema_name, chunks): - groups = group_vars_by_dims(ds) - - # dictionary to store schemas using previous schema class - """ - "data": { - "time_lat_lon": [temperature_2m, wind_speed], - "time_lat_lon_level": [pressure, humidity] - } - """ - self.schemas = { - schema_name: XarraySchemaProvider(ds, groups, chunks) - } - - """ - Other methods from test_catalog.py - """ - def schema_names(self): - return set(self.schemas.keys()) - - def schema(self, name): - return self.schemas.get(name) - - def register_schema(self, name, schema): - self.schemas[name] = schema - - def deregister_schema(self, name, cascade=True): - del self.schemas[name] - - -def register_catalog_from_dataset(ctx, ds, catalog_name="xarray", schema_name="data", chunks=None): - """ - Main function. Takes an xarray dataset and registers it - with DataFusion so you can query it with SQL. 
- """ - catalog = XarrayCatalogProvider(ds, schema_name, chunks) - ctx.register_catalog_provider(catalog_name, catalog) \ No newline at end of file diff --git a/catalog_test.py b/catalog_test.py deleted file mode 100644 index 1ec8bbb..0000000 --- a/catalog_test.py +++ /dev/null @@ -1,63 +0,0 @@ -import numpy as np -import xarray as xr -from context import XarrayContext - - -# create a fake era5 dataset for testing -times = np.array(["2020-01-01", "2020-01-02", "2020-01-03"], dtype="datetime64") -lats = np.array([0.0, 1.0, 2.0]) -lons = np.array([0.0, 1.0, 2.0]) -levels = np.array([500, 850]) - -ds = xr.Dataset( - { - "temperature_2m": (["time", "lat", "lon"], np.random.rand(3, 3, 3)), - "wind_speed": (["time", "lat", "lon"], np.random.rand(3, 3, 3)), - "pressure": (["time", "lat", "lon", "level"], np.random.rand(3, 3, 3, 2)), - "humidity": (["time", "lat", "lon", "level"], np.random.rand(3, 3, 3, 2)), - }, - coords={ - "time": times, - "lat": lats, - "lon": lons, - "level": levels, - } -).chunk({"time": 1}) - -print("Variables:", list(ds.data_vars)) -print("Dimensions:", list(ds.dims)) - -ctx = XarrayContext() -ctx.register_catalog_from_dataset(ds) - -print("\nCatalogs:", ctx.catalog_names()) -print("Schemas:", ctx.catalog("xarray").schema_names()) -print("Tables:", ctx.catalog("xarray").schema("data").table_names()) - -print("\n--- Surface variables (time, lat, lon) ---") -result = ctx.sql("SELECT * FROM xarray.data.time_lat_lon LIMIT 5").collect() -for batch in result: - print(batch.to_pandas()) - - -print("\n--- Atmospheric variables (time, lat, lon, level) ---") -result = ctx.sql("SELECT * FROM xarray.data.time_lat_lon_level LIMIT 5").collect() -for batch in result: - print(batch.to_pandas()) - -print("\n--- Joined surface + atmospheric on shared dims ---") -result = ctx.sql(""" - SELECT - s.time, s.lat, s.lon, - s.temperature_2m, - a.level, - a.pressure - FROM xarray.data.time_lat_lon s - JOIN xarray.data.time_lat_lon_level a - ON s.time = a.time - AND s.lat = 
a.lat - AND s.lon = a.lon - LIMIT 10 -""").collect() -for batch in result: - print(batch.to_pandas()) \ No newline at end of file diff --git a/xarray_context.py b/xarray_context.py deleted file mode 100644 index 126a332..0000000 --- a/xarray_context.py +++ /dev/null @@ -1,13 +0,0 @@ -import xarray as xr -from datafusion import SessionContext -from catalog import register_catalog_from_dataset - - -class XarrayContext(SessionContext): - """ - A regular DataFusion SessionContext but with an extra method - for registering xarray datasets. - """ - - def register_catalog_from_dataset(self, ds, catalog_name="xarray", schema_name="data", chunks=None): - register_catalog_from_dataset(self, ds, catalog_name, schema_name, chunks) \ No newline at end of file From bf03125d91a2a5452a41cea4f6c2cc42ce37705b Mon Sep 17 00:00:00 2001 From: Yagna Chilukuri Date: Thu, 5 Mar 2026 18:10:36 -0800 Subject: [PATCH 3/3] move catalog functionality into existing modules --- pyproject.toml | 1 + uv.lock | 92 ++++++++++++++++++++++++++++++++++++++ xarray_sql/reader.py | 104 +++++++++++++++++++++++++++++++++++++++++++ xarray_sql/sql.py | 12 ++++- 4 files changed, 207 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ae1b3ff..19f83bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ dev = [ "py-spy>=0.4.0", "pyink>=24.10.1", "maturin>=1.9.1", + "pre-commit>=4.5.1", ] [tool.uv] diff --git a/uv.lock b/uv.lock index 14d56c2..00f1d9d 100644 --- a/uv.lock +++ b/uv.lock @@ -219,6 +219,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/ae/320161bd181fc06471eed047ecce67b693fd7515b16d495d8932db763426/certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057", size = 157650, upload-time = "2025-06-15T02:45:49.977Z" }, ] +[[package]] +name = "cfgv" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" }, +] + [[package]] name = "cftime" version = "1.6.4.post1" @@ -459,6 +468,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + [[package]] name = "donfig" version = "0.8.1.post1" @@ -492,6 +510,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/bf/fd60001b3abc5222d8eaa4a204cd8c0ae78e75adc688f33ce4bf25b7fafa/fasteners-0.19-py3-none-any.whl", hash = "sha256:758819cb5d94cdedf4e836988b74de396ceacb8e2794d21f82d131fd9ee77237", size = 18679, upload-time = 
"2023-09-19T17:11:18.725Z" }, ] +[[package]] +name = "filelock" +version = "3.25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/77/18/a1fd2231c679dcb9726204645721b12498aeac28e1ad0601038f94b42556/filelock-3.25.0.tar.gz", hash = "sha256:8f00faf3abf9dc730a1ffe9c354ae5c04e079ab7d3a683b7c32da5dd05f26af3", size = 40158, upload-time = "2026-03-01T15:08:45.916Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/0b/de6f54d4a8bedfe8645c41497f3c18d749f0bd3218170c667bf4b81d0cdd/filelock-3.25.0-py3-none-any.whl", hash = "sha256:5ccf8069f7948f494968fc0713c10e5c182a9c9d9eef3a636307a20c2490f047", size = 26427, upload-time = "2026-03-01T15:08:44.593Z" }, +] + [[package]] name = "frozenlist" version = "1.7.0" @@ -790,6 +817,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/6d/0084ed0b78d4fd3e7530c32491f2884140d9b06365dac8a08de726421d4a/h5py-3.14.0-cp313-cp313-win_amd64.whl", hash = "sha256:ae18e3de237a7a830adb76aaa68ad438d85fe6e19e0d99944a3ce46b772c69b3", size = 2852929, upload-time = "2025-06-06T14:05:47.659Z" }, ] +[[package]] +name = "identify" +version = "2.6.17" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/84/376a3b96e5a8d33a7aa2c5b3b31a4b3c364117184bf0b17418055f6ace66/identify-2.6.17.tar.gz", hash = "sha256:f816b0b596b204c9fdf076ded172322f2723cf958d02f9c3587504834c8ff04d", size = 99579, upload-time = "2026-03-01T20:04:12.702Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/66/71c1227dff78aaeb942fed29dd5651f2aec166cc7c9aeea3e8b26a539b7d/identify-2.6.17-py2.py3-none-any.whl", hash = "sha256:be5f8412d5ed4b20f2bd41a65f920990bdccaa6a4a18a08f1eefdcd0bdd885f0", size = 99382, upload-time = "2026-03-01T20:04:11.439Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -1091,6 +1127,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/cd/10/c52f12297965938d9b9be666ea1f9d8340c2aea31d6909d90aa650847248/netcdf4-1.7.2-cp311-abi3-win_amd64.whl", hash = "sha256:999bfc4acebf400ed724d5e7329e2e768accc7ee1fa1d82d505da782f730301b", size = 7148514, upload-time = "2025-10-13T18:32:33.121Z" }, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, +] + [[package]] name = "numcodecs" version = "0.13.1" @@ -1409,6 +1454,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = "sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574, upload-time = "2024-06-06T16:53:44.343Z" }, ] +[[package]] +name = "pre-commit" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/f1/6d86a29246dfd2e9b6237f0b5823717f60cad94d47ddc26afa916d21f525/pre_commit-4.5.1.tar.gz", hash = "sha256:eb545fcff725875197837263e977ea257a402056661f09dae08e4b149b030a61", size = 198232, upload-time = "2025-12-16T21:14:33.552Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" }, +] + [[package]] name = "propcache" version = "0.3.2" @@ -1693,6 +1754,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-discovery" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/bb/93a3e83bdf9322c7e21cafd092e56a4a17c4d8ef4277b6eb01af1a540a6f/python_discovery-1.1.0.tar.gz", hash = "sha256:447941ba1aed8cc2ab7ee3cb91be5fc137c5bdbb05b7e6ea62fbdcb66e50b268", size = 55674, upload-time = "2026-02-26T09:42:49.668Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/54/82a6e2ef37f0f23dccac604b9585bdcbd0698604feb64807dcb72853693e/python_discovery-1.1.0-py3-none-any.whl", hash = "sha256:a162893b8809727f54594a99ad2179d2ede4bf953e12d4c7abc3cc9cdbd1437b", size = 30687, upload-time = "2026-02-26T09:42:48.548Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -2020,6 +2094,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "virtualenv" +version = "21.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" 
}, + { name = "filelock" }, + { name = "platformdirs" }, + { name = "python-discovery" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/c9/18d4b36606d6091844daa3bd93cf7dc78e6f5da21d9f21d06c221104b684/virtualenv-21.1.0.tar.gz", hash = "sha256:1990a0188c8f16b6b9cf65c9183049007375b26aad415514d377ccacf1e4fb44", size = 5840471, upload-time = "2026-02-27T08:49:29.702Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/55/896b06bf93a49bec0f4ae2a6f1ed12bd05c8860744ac3a70eda041064e4d/virtualenv-21.1.0-py3-none-any.whl", hash = "sha256:164f5e14c5587d170cf98e60378eb91ea35bf037be313811905d3a24ea33cc07", size = 5825072, upload-time = "2026-02-27T08:49:27.516Z" }, +] + [[package]] name = "webob" version = "1.8.9" @@ -2113,6 +2203,7 @@ test = [ [package.dev-dependencies] dev = [ { name = "maturin" }, + { name = "pre-commit" }, { name = "py-spy" }, { name = "pyink" }, { name = "xarray-sql", extra = ["test"] }, @@ -2132,6 +2223,7 @@ provides-extras = ["test"] [package.metadata.requires-dev] dev = [ { name = "maturin", specifier = ">=1.9.1" }, + { name = "pre-commit", specifier = ">=4.5.1" }, { name = "py-spy", specifier = ">=0.4.0" }, { name = "pyink", specifier = ">=24.10.1" }, { name = "xarray-sql", extras = ["test"] }, diff --git a/xarray_sql/reader.py b/xarray_sql/reader.py index 8777bea..2e29be0 100644 --- a/xarray_sql/reader.py +++ b/xarray_sql/reader.py @@ -15,6 +15,7 @@ import pyarrow as pa import xarray as xr +import datafusion as dfn from .df import ( Block, @@ -272,3 +273,106 @@ def partition_pairs(): yield make_partition_factory(block), _block_metadata(coord_arrays, block) return LazyArrowStreamTable(partition_pairs(), schema) + + def group_vars_by_dims(ds): + """ + Group variables in the dataset based on shared dims + + ("time", "lat", "lon"): ["temperature_2m", "wind_speed"], + ("time", "lat", "lon", "level"): ["pressure", "humidity"] + """ + groups = {} + + 
# NOTE(review): reconstructed from a whitespace-mangled `git format-patch`
# stream (patch 3/3, "move catalog functionality into existing modules").
# In that diff the added lines for `group_vars_by_dims` carry one extra
# indent level ("+    def ..."), which nests the function inside
# `read_xarray_table` *after* its return statement (dead code) and makes
# the module-level call in XarrayCatalogProvider.__init__ fail with
# NameError. It is reconstructed at top level here -- confirm against the
# un-mangled patch.


def group_vars_by_dims(ds):
    """Group the data variables of ``ds`` by their dimension tuple.

    Example result::

        ("time", "lat", "lon"): ["temperature_2m", "wind_speed"],
        ("time", "lat", "lon", "level"): ["pressure", "humidity"]

    Args:
        ds: an ``xarray.Dataset`` (anything exposing ``.data_vars`` whose
            values carry a ``.dims`` tuple).

    Returns:
        dict mapping each dims tuple to the list of variable names that
        share exactly those dims, in dataset order.
    """
    groups = {}
    for var_name, var in ds.data_vars.items():
        # setdefault replaces the original "if dims not in groups" check.
        groups.setdefault(var.dims, []).append(var_name)
    return groups


def dims_to_table_name(dims):
    """Join a dims tuple into a SQL-friendly table name.

    ("time", "lat", "lon") -> "time_lat_lon"
    """
    return "_".join(dims)


class XarraySchemaProvider(dfn.catalog.SchemaProvider):
    """DataFusion schema provider exposing one table per dims-group.

    Each group of variables sharing the same dims becomes one table,
    named via ``dims_to_table_name`` and backed by ``read_xarray_table``.
    """

    def __init__(self, ds, groups, chunks):
        """Build one table per dims-group from dataset ``ds``.

        Args:
            ds: source ``xarray.Dataset``.
            groups: output of ``group_vars_by_dims(ds)``.
            chunks: chunking spec forwarded to ``read_xarray_table``.
        """
        # Table name -> DataFusion table; one entry per group of vars.
        self.tables = {}
        for dims, var_names in groups.items():
            table_name = dims_to_table_name(dims)
            subset = ds[var_names]  # dataset restricted to this group
            self.tables[table_name] = read_xarray_table(subset, chunks)

    def table_names(self):
        """Return the set of table names held by this schema."""
        return set(self.tables.keys())

    def table(self, name):
        """Return the table registered under ``name``, or None."""
        return self.tables.get(name)

    def table_exist(self, name):
        """Return True if a table named ``name`` is registered."""
        return name in self.tables

    def register_table(self, name, table):
        """Register (or replace) ``table`` under ``name``."""
        self.tables[name] = table

    def deregister_table(self, name, cascade=True):
        """Remove the table ``name``.

        Raises ``KeyError`` if absent. ``cascade`` is accepted for
        interface compatibility but currently unused.
        """
        del self.tables[name]


class XarrayCatalogProvider(dfn.catalog.CatalogProvider):
    """DataFusion catalog provider holding ``XarraySchemaProvider`` schemas."""

    def __init__(self, ds, schema_name, chunks):
        """Create a catalog with a single schema built from ``ds``.

        Resulting layout, e.g.::

            "data": {
                "time_lat_lon": [temperature_2m, wind_speed],
                "time_lat_lon_level": [pressure, humidity],
            }
        """
        groups = group_vars_by_dims(ds)
        self.schemas = {schema_name: XarraySchemaProvider(ds, groups, chunks)}

    def schema_names(self):
        """Return the set of schema names in this catalog."""
        return set(self.schemas.keys())

    def schema(self, name):
        """Return the schema registered under ``name``, or None."""
        return self.schemas.get(name)

    def register_schema(self, name, schema):
        """Register (or replace) ``schema`` under ``name``."""
        self.schemas[name] = schema

    def deregister_schema(self, name, cascade=True):
        """Remove the schema ``name``.

        Raises ``KeyError`` if absent. ``cascade`` is accepted for
        interface compatibility but currently unused.
        """
        del self.schemas[name]


def register_catalog_from_dataset(
    ctx, ds, catalog_name="xarray", schema_name="data", chunks=None
):
    """Register ``ds`` with DataFusion so it can be queried via SQL.

    Variables are grouped by shared dims; each group becomes a table
    addressable as ``<catalog_name>.<schema_name>.<dims_joined_by_underscores>``.

    Args:
        ctx: a DataFusion ``SessionContext``.
        ds: the ``xarray.Dataset`` to expose.
        catalog_name: name of the registered catalog (default ``"xarray"``).
        schema_name: name of the single schema (default ``"data"``).
        chunks: chunking spec forwarded to ``read_xarray_table``.
    """
    catalog = XarrayCatalogProvider(ds, schema_name, chunks)
    ctx.register_catalog_provider(catalog_name, catalog)


# --- net change to xarray_sql/sql.py from the same patch ------------------


class XarrayContext(SessionContext):
    """
    A regular DataFusion SessionContext but with an extra method
    for registering xarray datasets.
    """

    # NOTE(review): the pre-existing ``from_dataset`` method is only
    # partially visible in the mangled diff (its parameter list is cut
    # between hunks) and is deliberately not reconstructed here.

    def register_catalog_from_dataset(
        self, ds, catalog_name="xarray", schema_name="data", chunks=None
    ):
        """Register ``ds`` as a catalog on this context.

        Thin wrapper around the module-level
        ``register_catalog_from_dataset``; see it for parameter details.
        """
        register_catalog_from_dataset(self, ds, catalog_name, schema_name, chunks)