From aded7b666c8784337198199b4f455c48ed8c544a Mon Sep 17 00:00:00 2001
From: Michael Peteuil
Date: Thu, 15 Feb 2024 17:00:14 -0500
Subject: [PATCH] Make Datasets hashable

Currently, DAGs accept a [`Collection["Dataset"]`](https://github.com/apache/airflow/blob/0c02ead4d8a527cbf0a916b6344f255c520e637f/airflow/models/dag.py#L171) as an option for `schedule`, but that collection cannot be a `set` because `Dataset` is not a hashable type.

The interesting thing is that [the `DatasetModel` is actually already hashable](https://github.com/apache/airflow/blob/dec78ab3f140f35e507de825327652ec24d03522/airflow/models/dataset.py#L93-L100), so this introduces a bit of duplication since it's the same implementation. However, Airflow users primarily interface with `Dataset`, not `DatasetModel`, so I think it makes sense for `Dataset` to be hashable. I'm not sure how to square the duplication, or what `__eq__` and `__hash__` provide for `DatasetModel`, though.

There was discussion on the original PR that created `Dataset` (https://github.com/apache/airflow/pull/24613) about whether to create two classes or one. In that discussion @kaxil mentioned:

> I would slightly favour a separate `DatasetModel` and `Dataset` so `Dataset` becomes an extensible class, and `DatasetModel` just stores the info about the class. So users don't need to care about SQLAlchemy stuff when extending it.

That first PR created `Dataset` as both the SQLAlchemy model and the user-space class, though. It wasn't until later (https://github.com/apache/airflow/pull/25727) that `DatasetModel` was broken out from `Dataset` and the one class became two. That provides a bit of background on why both classes exist, for anyone reading this who is curious.
---
 airflow/datasets/__init__.py   |  9 +++++++++
 tests/datasets/test_dataset.py | 19 +++++++++++++++++++
 2 files changed, 28 insertions(+)

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index a4a127e3f7af9..eaa25d0a30c35 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -46,3 +46,12 @@ def _check_uri(self, attr, uri: str):
 
     def __fspath__(self):
         return self.uri
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            return self.uri == other.uri
+        else:
+            return NotImplemented
+
+    def __hash__(self):
+        return hash(self.uri)
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index dfc8b82ba1596..e10264b0e2490 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -70,6 +70,25 @@ def test_fspath():
     assert os.fspath(dataset) == uri
 
 
+def test_equal_when_same_uri():
+    uri = "s3://example_dataset"
+    dataset1 = Dataset(uri=uri)
+    dataset2 = Dataset(uri=uri)
+    assert dataset1 == dataset2
+
+
+def test_not_equal_when_different_uri():
+    dataset1 = Dataset(uri="s3://example_dataset")
+    dataset2 = Dataset(uri="s3://other_dataset")
+    assert dataset1 != dataset2
+
+
+def test_hash():
+    uri = "s3://example_dataset"
+    dataset = Dataset(uri=uri)
+    hash(dataset)
+
+
 @pytest.mark.parametrize(
     "inputs, scenario, expected",
     [