diff --git a/pyproject.toml b/pyproject.toml index ae1b3ff..19f83bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ dev = [ "py-spy>=0.4.0", "pyink>=24.10.1", "maturin>=1.9.1", + "pre-commit>=4.5.1", ] [tool.uv] diff --git a/uv.lock b/uv.lock index 14d56c2..00f1d9d 100644 --- a/uv.lock +++ b/uv.lock @@ -219,6 +219,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/ae/320161bd181fc06471eed047ecce67b693fd7515b16d495d8932db763426/certifi-2025.6.15-py3-none-any.whl", hash = "sha256:2e0c7ce7cb5d8f8634ca55d2ba7e6ec2689a2fd6537d8dec1296a477a4910057", size = 157650, upload-time = "2025-06-15T02:45:49.977Z" }, ] +[[package]] +name = "cfgv" +version = "3.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" }, +] + [[package]] name = "cftime" version = "1.6.4.post1" @@ -459,6 +468,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, ] +[[package]] +name = "distlib" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = 
"sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, +] + [[package]] name = "donfig" version = "0.8.1.post1" @@ -492,6 +510,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/bf/fd60001b3abc5222d8eaa4a204cd8c0ae78e75adc688f33ce4bf25b7fafa/fasteners-0.19-py3-none-any.whl", hash = "sha256:758819cb5d94cdedf4e836988b74de396ceacb8e2794d21f82d131fd9ee77237", size = 18679, upload-time = "2023-09-19T17:11:18.725Z" }, ] +[[package]] +name = "filelock" +version = "3.25.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/77/18/a1fd2231c679dcb9726204645721b12498aeac28e1ad0601038f94b42556/filelock-3.25.0.tar.gz", hash = "sha256:8f00faf3abf9dc730a1ffe9c354ae5c04e079ab7d3a683b7c32da5dd05f26af3", size = 40158, upload-time = "2026-03-01T15:08:45.916Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/0b/de6f54d4a8bedfe8645c41497f3c18d749f0bd3218170c667bf4b81d0cdd/filelock-3.25.0-py3-none-any.whl", hash = "sha256:5ccf8069f7948f494968fc0713c10e5c182a9c9d9eef3a636307a20c2490f047", size = 26427, upload-time = "2026-03-01T15:08:44.593Z" }, +] + [[package]] name = "frozenlist" version = "1.7.0" @@ -790,6 +817,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/6d/0084ed0b78d4fd3e7530c32491f2884140d9b06365dac8a08de726421d4a/h5py-3.14.0-cp313-cp313-win_amd64.whl", hash = "sha256:ae18e3de237a7a830adb76aaa68ad438d85fe6e19e0d99944a3ce46b772c69b3", size = 2852929, upload-time = "2025-06-06T14:05:47.659Z" }, ] +[[package]] +name = "identify" +version = "2.6.17" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/84/376a3b96e5a8d33a7aa2c5b3b31a4b3c364117184bf0b17418055f6ace66/identify-2.6.17.tar.gz", hash = "sha256:f816b0b596b204c9fdf076ded172322f2723cf958d02f9c3587504834c8ff04d", size = 99579, upload-time = "2026-03-01T20:04:12.702Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/66/71c1227dff78aaeb942fed29dd5651f2aec166cc7c9aeea3e8b26a539b7d/identify-2.6.17-py2.py3-none-any.whl", hash = "sha256:be5f8412d5ed4b20f2bd41a65f920990bdccaa6a4a18a08f1eefdcd0bdd885f0", size = 99382, upload-time = "2026-03-01T20:04:11.439Z" }, +] + [[package]] name = "idna" version = "3.10" @@ -1091,6 +1127,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cd/10/c52f12297965938d9b9be666ea1f9d8340c2aea31d6909d90aa650847248/netcdf4-1.7.2-cp311-abi3-win_amd64.whl", hash = "sha256:999bfc4acebf400ed724d5e7329e2e768accc7ee1fa1d82d505da782f730301b", size = 7148514, upload-time = "2025-10-13T18:32:33.121Z" }, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, +] + [[package]] name = "numcodecs" version = "0.13.1" @@ -1409,6 +1454,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/87/77cc11c7a9ea9fd05503def69e3d18605852cd0d4b0d3b8f15bbeb3ef1d1/pooch-1.8.2-py3-none-any.whl", hash = 
"sha256:3529a57096f7198778a5ceefd5ac3ef0e4d06a6ddaf9fc2d609b806f25302c47", size = 64574, upload-time = "2024-06-06T16:53:44.343Z" }, ] +[[package]] +name = "pre-commit" +version = "4.5.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cfgv" }, + { name = "identify" }, + { name = "nodeenv" }, + { name = "pyyaml" }, + { name = "virtualenv" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/f1/6d86a29246dfd2e9b6237f0b5823717f60cad94d47ddc26afa916d21f525/pre_commit-4.5.1.tar.gz", hash = "sha256:eb545fcff725875197837263e977ea257a402056661f09dae08e4b149b030a61", size = 198232, upload-time = "2025-12-16T21:14:33.552Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" }, +] + [[package]] name = "propcache" version = "0.3.2" @@ -1693,6 +1754,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892, upload-time = "2024-03-01T18:36:18.57Z" }, ] +[[package]] +name = "python-discovery" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "filelock" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/bb/93a3e83bdf9322c7e21cafd092e56a4a17c4d8ef4277b6eb01af1a540a6f/python_discovery-1.1.0.tar.gz", hash = "sha256:447941ba1aed8cc2ab7ee3cb91be5fc137c5bdbb05b7e6ea62fbdcb66e50b268", size = 55674, upload-time = "2026-02-26T09:42:49.668Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/06/54/82a6e2ef37f0f23dccac604b9585bdcbd0698604feb64807dcb72853693e/python_discovery-1.1.0-py3-none-any.whl", hash = "sha256:a162893b8809727f54594a99ad2179d2ede4bf953e12d4c7abc3cc9cdbd1437b", size = 30687, upload-time = "2026-02-26T09:42:48.548Z" }, +] + [[package]] name = "pytz" version = "2025.2" @@ -2020,6 +2094,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "virtualenv" +version = "21.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "distlib" }, + { name = "filelock" }, + { name = "platformdirs" }, + { name = "python-discovery" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2f/c9/18d4b36606d6091844daa3bd93cf7dc78e6f5da21d9f21d06c221104b684/virtualenv-21.1.0.tar.gz", hash = "sha256:1990a0188c8f16b6b9cf65c9183049007375b26aad415514d377ccacf1e4fb44", size = 5840471, upload-time = "2026-02-27T08:49:29.702Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/55/896b06bf93a49bec0f4ae2a6f1ed12bd05c8860744ac3a70eda041064e4d/virtualenv-21.1.0-py3-none-any.whl", hash = "sha256:164f5e14c5587d170cf98e60378eb91ea35bf037be313811905d3a24ea33cc07", size = 5825072, upload-time = "2026-02-27T08:49:27.516Z" }, +] + [[package]] name = "webob" version = "1.8.9" @@ -2113,6 +2203,7 @@ test = [ [package.dev-dependencies] dev = [ { name = "maturin" }, + { name = "pre-commit" }, { name = "py-spy" }, { name = "pyink" }, { name = "xarray-sql", extra = ["test"] }, @@ -2132,6 +2223,7 @@ provides-extras = ["test"] [package.metadata.requires-dev] dev = [ { name = "maturin", specifier = ">=1.9.1" }, + { name = "pre-commit", specifier = 
import datafusion as dfn

# NOTE: `read_xarray_table` is used below; it is defined earlier in
# xarray_sql/reader.py (sql.py imports it from `.reader`).


# Defined at module level on purpose: `XarrayCatalogProvider.__init__` looks
# this name up as a global, so it must not be nested inside another function.
def group_vars_by_dims(ds):
  """Group the data variables of `ds` by their dimension tuple.

  Variables whose ``dims`` compare equal end up in the same group, e.g.::

    {
        ("time", "lat", "lon"): ["temperature_2m", "wind_speed"],
        ("time", "lat", "lon", "level"): ["pressure", "humidity"],
    }

  Args:
    ds: The xarray dataset whose ``data_vars`` are grouped.

  Returns:
    A dict mapping each distinct ``dims`` value to the list of variable names
    that use it, in ``data_vars`` iteration order.
  """
  groups = {}
  for var_name, var in ds.data_vars.items():
    groups.setdefault(var.dims, []).append(var_name)
  return groups


def dims_to_table_name(dims):
  """Build a table name by joining dimension names with underscores.

  ``("time", "lat", "lon") -> "time_lat_lon"``

  Every entry is passed through ``str`` first so that a non-string dimension
  name does not make ``str.join`` raise ``TypeError``. An empty ``dims``
  yields the empty string.

  Args:
    dims: Iterable of dimension names.

  Returns:
    The underscore-joined table name.
  """
  return "_".join(str(dim) for dim in dims)


class XarraySchemaProvider(dfn.catalog.SchemaProvider):
  """Custom DataFusion schema provider that holds one table per dims group."""

  def __init__(self, ds, groups, chunks):
    """Create one table for every group of variables.

    Args:
      ds: The dataset the variables are selected from.
      groups: Mapping of dims tuple -> variable names, as produced by
        `group_vars_by_dims`.
      chunks: Forwarded unchanged to `read_xarray_table`.
    """
    # Table name -> table object.
    self.tables = {}

    # Create a table for each group of variables. Two different dims tuples
    # that join to the same name would overwrite each other here.
    for dims, var_names in groups.items():
      table_name = dims_to_table_name(dims)
      subset = ds[var_names]
      self.tables[table_name] = read_xarray_table(subset, chunks)

  def table_names(self):
    """Return the set of registered table names."""
    return set(self.tables)

  def table(self, name):
    """Return the table registered under `name`, or None if there is none."""
    return self.tables.get(name)

  def table_exist(self, name):
    """Return True if a table is registered under `name`."""
    return name in self.tables

  def register_table(self, name, table):
    """Register (or replace) `table` under `name`."""
    self.tables[name] = table

  def deregister_table(self, name, cascade=True):
    """Remove the table registered under `name`, if any.

    Removing an unknown name is a no-op instead of raising ``KeyError``.
    `cascade` is accepted for interface compatibility and is not used.
    """
    self.tables.pop(name, None)


class XarrayCatalogProvider(dfn.catalog.CatalogProvider):
  """Custom DataFusion catalog provider that holds the schemas."""

  def __init__(self, ds, schema_name, chunks):
    """Build a catalog with a single schema derived from `ds`.

    Args:
      ds: The dataset whose variables are grouped into tables.
      schema_name: Name under which the single schema is stored.
      chunks: Forwarded unchanged to `XarraySchemaProvider`.
    """
    groups = group_vars_by_dims(ds)

    # Schema name -> schema provider. Illustrative layout:
    #
    #   "data": {
    #     "time_lat_lon": [temperature_2m, wind_speed],
    #     "time_lat_lon_level": [pressure, humidity]
    #   }
    self.schemas = {schema_name: XarraySchemaProvider(ds, groups, chunks)}

  # The remaining methods follow the provider example in test_catalog.py.

  def schema_names(self):
    """Return the set of registered schema names."""
    return set(self.schemas)

  def schema(self, name):
    """Return the schema registered under `name`, or None if there is none."""
    return self.schemas.get(name)

  def register_schema(self, name, schema):
    """Register (or replace) `schema` under `name`."""
    self.schemas[name] = schema

  def deregister_schema(self, name, cascade=True):
    """Remove the schema registered under `name`, if any.

    Removing an unknown name is a no-op instead of raising ``KeyError``.
    `cascade` is accepted for interface compatibility and is not used.
    """
    self.schemas.pop(name, None)


def register_catalog_from_dataset(
    ctx, ds, catalog_name="xarray", schema_name="data", chunks=None
):
  """Register an xarray dataset with DataFusion so it can be queried via SQL.

  Builds an `XarrayCatalogProvider` for `ds` and registers it on `ctx` by
  calling ``ctx.register_catalog_provider(catalog_name, catalog)``.

  Args:
    ctx: Object providing ``register_catalog_provider`` (the session context).
    ds: The xarray dataset to expose.
    catalog_name: Name the catalog is registered under.
    schema_name: Name of the single schema inside the catalog.
    chunks: Forwarded unchanged to the table reader.
  """
  catalog = XarrayCatalogProvider(ds, schema_name, chunks)
  ctx.register_catalog_provider(catalog_name, catalog)