diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index a4a127e3f7af9..eaa25d0a30c35 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -46,3 +46,14 @@ def _check_uri(self, attr, uri: str):
 
     def __fspath__(self):
         return self.uri
+
+    def __eq__(self, other):
+        """Return whether ``other`` is an instance of this class with the same URI."""
+        if isinstance(other, self.__class__):
+            return self.uri == other.uri
+        # NotImplemented (rather than False) lets Python try other's reflected __eq__.
+        return NotImplemented
+
+    def __hash__(self):
+        """Hash on the URI alone, so datasets that compare equal hash equally."""
+        return hash(self.uri)
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index dfc8b82ba1596..e10264b0e2490 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -70,6 +70,27 @@ def test_fspath():
     assert os.fspath(dataset) == uri
 
 
+def test_equal_when_same_uri():
+    uri = "s3://example_dataset"
+    dataset1 = Dataset(uri=uri)
+    dataset2 = Dataset(uri=uri)
+    assert dataset1 == dataset2
+
+
+def test_not_equal_when_different_uri():
+    dataset1 = Dataset(uri="s3://example_dataset")
+    dataset2 = Dataset(uri="s3://other_dataset")
+    assert dataset1 != dataset2
+
+
+def test_hash():
+    uri = "s3://example_dataset"
+    dataset1 = Dataset(uri=uri)
+    dataset2 = Dataset(uri=uri)
+    # Objects that compare equal must have equal hashes (set/dict key contract).
+    assert hash(dataset1) == hash(dataset2)
+
+
 @pytest.mark.parametrize(
     "inputs, scenario, expected",
     [