diff --git a/package/MDAnalysis/analysis/base.py b/package/MDAnalysis/analysis/base.py index c7619203581..f38a24890eb 100644 --- a/package/MDAnalysis/analysis/base.py +++ b/package/MDAnalysis/analysis/base.py @@ -84,38 +84,37 @@ class in `scikit-learn`_. .. versionadded:: 2.0.0 """ + def _validate_key(self, key): - if key in dir(UserDict) or (key == "data" and self._dict_frozen): + if key in dir(self): raise AttributeError(f"'{key}' is a protected dictionary " "attribute") elif isinstance(key, str) and not key.isidentifier(): raise ValueError(f"'{key}' is not a valid attribute") - def __init__(self, **kwargs): + def __init__(self, *args, **kwargs): + kwargs = dict(*args, **kwargs) if "data" in kwargs.keys(): raise AttributeError(f"'data' is a protected dictionary attribute") - - self._dict_frozen = False - for key in kwargs: - self._validate_key(key) - super().__init__(**kwargs) - self._dict_frozen = True + self.__dict__["data"] = {} + self.update(kwargs) def __setitem__(self, key, item): self._validate_key(key) super().__setitem__(key, item) - def __setattr__(self, attr, value): - self._validate_key(attr) - super().__setattr__(attr, value) - - # attribute available as key - if self._dict_frozen and attr != "_dict_frozen": - super().__setitem__(attr, value) + __setattr__ = __setitem__ def __getattr__(self, attr): try: - return self.data[attr] + return self[attr] + except KeyError as err: + raise AttributeError("'Results' object has no " + f"attribute '{attr}'") from err + + def __delattr__(self, attr): + try: + del self[attr] except KeyError as err: raise AttributeError("'Results' object has no " f"attribute '{attr}'") from err diff --git a/testsuite/MDAnalysisTests/analysis/test_base.py b/testsuite/MDAnalysisTests/analysis/test_base.py index 71ea1aea3e3..9c4995e4c07 100644 --- a/testsuite/MDAnalysisTests/analysis/test_base.py +++ b/testsuite/MDAnalysisTests/analysis/test_base.py @@ -45,23 +45,24 @@ def test_get(self, results): assert results.a == results["a"] == 1 def test_no_attr(self, results): - with pytest.raises(AttributeError): + msg = "'Results' object has no attribute 'c'" + with pytest.raises(AttributeError, match=msg): results.c def test_set_attr(self, results): value = [1, 2, 3, 4] results.c = value - assert results.c == results["c"] == value + assert results.c is results["c"] is value def test_set_key(self, results): value = [1, 2, 3, 4] results["c"] = value - assert results.c == results["c"] == value + assert results.c is results["c"] is value @pytest.mark.parametrize('key', dir(UserDict) + ["data"]) def test_existing_dict_attr(self, results, key): msg = f"'{key}' is a protected dictionary attribute" - with pytest.raises(AttributeError, match=key): + with pytest.raises(AttributeError, match=msg): results[key] = None @pytest.mark.parametrize('key', dir(UserDict) + ["data"]) @@ -76,6 +77,70 @@ def test_weird_key(self, results, key): with pytest.raises(ValueError, match=msg): results[key] = None + def test_setattr_modify_item(self, results): + mylist = [1, 2] + mylist2 = [3, 4] + results.myattr = mylist + assert results.myattr is mylist + results["myattr"] = mylist2 + assert results.myattr is mylist2 + mylist2.pop(0) + assert len(results.myattr) == 1 + assert results.myattr is mylist2 + + def test_setitem_modify_item(self, results): + mylist = [1, 2] + mylist2 = [3, 4] + results["myattr"] = mylist + assert results.myattr is mylist + results.myattr = mylist2 + assert results.myattr is mylist2 + mylist2.pop(0) + assert len(results["myattr"]) == 1 + assert results["myattr"] is mylist2 + + def test_delattr(self, results): + assert hasattr(results, "a") + delattr(results, "a") + assert not hasattr(results, "a") + + def test_missing_delattr(self, results): + assert not hasattr(results, "d") + msg = "'Results' object has no attribute 'd'" + with pytest.raises(AttributeError, match=msg): + delattr(results, "d") + + def test_pop(self, results): + assert hasattr(results, "a") + results.pop("a") + assert not hasattr(results, "a") + + def test_update(self, results): + assert not hasattr(results, "spudda") + results.update({"spudda": "fett"}) + assert results.spudda == "fett" + + def test_update_data_fail(self, results): + msg = f"'data' is a protected dictionary attribute" + with pytest.raises(AttributeError, match=msg): + results.update({"data": 0}) + + @pytest.mark.parametrize("args, kwargs, length", [ + (({"darth": "tater"},), {}, 1), + ([], {"darth": "tater"}, 1), + (({"darth": "tater"},), {"yam": "solo"}, 2), + (({"darth": "tater"},), {"darth": "vader"}, 1), + ]) + def test_initialize_arguments(self, args, kwargs, length): + results = base.Results(*args, **kwargs) + ref = dict(*args, **kwargs) + assert ref == results + assert len(results) == length + + def test_different_instances(self, results): + new_results = base.Results(darth="tater") + assert new_results.data is not results.data + class FrameAnalysis(base.AnalysisBase): """Just grabs frame numbers of frames it goes over"""