Cache model loading in model card #299
Conversation
Thank you for taking this issue. I'm unsure about the implementation here, which I admit is not a trivial matter. If I understand correctly, you're using an `lru_cache` on the loading function. To me, this seems a bit "hacky", and I would like to suggest a different approach. In my suggestion, we would leave `_load_model` as is. I did a quick and dirty implementation of how it could look (in ...):
```python
from hashlib import sha256
from functools import cached_property
...

class Card:
    def __init__(...):
        ...
        self._model_hash = ""
        self._populate_template()

    def get_model(self) -> Any:
        """..."""
        if isinstance(self.model, (str, Path)) and hasattr(self, "_model"):
            hash_obj = sha256()
            buf_size = 2 ** 20  # load in chunks to save memory
            with open(self.model, "rb") as f:
                for chunk in iter(lambda: f.read(buf_size), b""):
                    hash_obj.update(chunk)
            model_hash = hash_obj.hexdigest()
            # if hash changed, invalidate cache by deleting attribute
            if model_hash != self._model_hash:
                del self._model
                self._model_hash = model_hash
        return self._model

    @cached_property
    def _model(self):
        model = _load_model(self.model, self.trusted)
        return model
```

What do you think about that? Of course, it would require some comments and tests, but I hope you get the general idea.
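For readers less familiar with the trick used here: `functools.cached_property` stores the computed value in the instance's `__dict__`, so deleting the attribute (as `get_model` does with `del self._model`) drops the cached value and the next access recomputes it. A minimal standalone illustration (toy class, not skops code):

```python
from functools import cached_property


class Expensive:
    def __init__(self):
        self.n_loads = 0  # counts how often the "model" is actually loaded

    @cached_property
    def model(self):
        self.n_loads += 1
        return {"weights": [1, 2, 3]}


obj = Expensive()
obj.model  # first access computes the value and caches it in obj.__dict__
obj.model  # second access is served from the cache
assert obj.n_loads == 1

del obj.model  # invalidate the cache, mirroring `del self._model` above
obj.model      # recomputed on the next access
assert obj.n_loads == 2
```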
It's okay, the diff is still shown correctly, right? However, please make sure to correct this for the next PR.
I understand your approach and I agree that it fits better conceptually. I took your implementation because it works out of the box, no errors. I wrote a test for it for which I would appreciate some feedback. Thanks a lot for the help on this PR!
Thanks for the updates. I haven't done a proper review yet, but I saw that some changes were unrelated to the additions of this PR. Could you please clean those up? Maybe they were changed by your IDE automatically? Also, it seems that there are ... Finally, the docs are not building. I think it's the same issue as in #207, so whatever fixes that should work here too.
Thank you for your comments. I'm in the process of fixing the docs error; I have asked a question about that in #207.
There's a merge conflict here, and somehow the CI hasn't run completely; could you please merge with `main`?
BenjaminBossan
left a comment
Thanks for updating the PR. It still shows a lot of unrelated changes in the diff, could you please remove them?
Regarding the test, I have to admit I don't quite understand it. For example, what does this test?
```python
assert str(card._model_hash) == card.__dict__["_model_hash"]
```
I think what I would like to see is a more high-level test, i.e. nothing that involves any hashes, since those are implementation details. One way would be to mock `_load_model` and assert that, when `card.get_model()` is called, `_load_model` is only called once, the first time, and after that it's not called anymore. Then, only when the underlying model is overwritten should it be called again. WDYT?
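The pattern proposed here — wrap the loader so its calls are counted, call `get_model()` twice, and assert on the call count — can be sketched independently of skops with a toy class (names below are illustrative, not the actual skops API):

```python
from unittest import mock


def load_model(path):
    # stand-in for the real loader (skops' _load_model)
    return {"path": path}


class ToyCard:
    """Toy card that lazily loads and caches its model."""

    def __init__(self, path, loader=load_model):
        self.model = path
        self._loader = loader
        self._cached = None

    def get_model(self):
        if self._cached is None:
            self._cached = self._loader(self.model)
        return self._cached


# wrap the loader in a Mock so calls are counted but behavior is unchanged
mock_loader = mock.Mock(side_effect=load_model)
card = ToyCard("model.skops", loader=mock_loader)

first = card.get_model()
second = card.get_model()
assert first is second              # the same cached object is returned
assert mock_loader.call_count == 1  # loaded exactly once, then cached
```

In the real test, `mock.patch("skops.card._model_card._load_model")` plays the role of the injected `loader` argument used here.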
Hey! Sorry it's been a while. Now that I look back at the test I wrote, I think the line you highlighted is not testing anything; I'm not sure what I was thinking at the time. I agree with your proposal to check that everything works rather than checking the details as I was trying to do. I'm happy to update the test. I'm not sure how to get rid of all of the unrelated changes in the diff: I have modified 3 of the 7 files, and the other changes were applied after I ran pre-commit manually on all files. I will merge with `main`.
Okay, thanks for clearing that up.
Let's see if that works. Otherwise, in the worst case, you could try opening a new PR based on the latest `main`.
If we go about mocking `_load_model` ...
Are you sure that we need ...
…nto cache-model-loading
Ah okay! I have implemented it using ... I have also got rid of the files in the diff that weren't supposed to be there by reverting the changes to those files.
Well done with reverting the changes, this is now much easier to review, thanks.
There isn't much work left. Regarding the test, I think it can be improved a bit; please take a look at my suggestion. Other than that, please add an entry to `docs/changes.rst`. Then this should be good to go.
The failing CI job is unrelated to this PR, so please ignore it.
```python
# _load_model gets called
card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
with mock.patch("skops.card._model_card._load_model") as mock_load_model:
    model1 = card.get_model()
    model2 = card.get_model()
    assert model1 is model2
    # model is cached, hence _load_model is not called
    mock_load_model.assert_not_called()
    # update card with new model
    new_model = LogisticRegression()
    _, save_file = save_model_to_file(new_model, ".skops")
    del card.model
    card.model = save_file
    model3 = card.get_model()  # model gets cached
    model4 = card.get_model()
    assert model3 is model4
    assert mock_load_model.call_count == 1
```
I see the intent with this test, but I think it's problematic that `del card.model` and `card.model = save_file` are being used. As a skops user, I wouldn't do that, and I would still expect the cached model loading to work correctly. Therefore, I made some changes to the test so that these lines are not needed:
```diff
- # _load_model gets called
- card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
- with mock.patch("skops.card._model_card._load_model") as mock_load_model:
-     model1 = card.get_model()
-     model2 = card.get_model()
-     assert model1 is model2
-     # model is cached, hence _load_model is not called
-     mock_load_model.assert_not_called()
-     # update card with new model
-     new_model = LogisticRegression()
-     _, save_file = save_model_to_file(new_model, ".skops")
-     del card.model
-     card.model = save_file
-     model3 = card.get_model()  # model gets cached
-     model4 = card.get_model()
-     assert model3 is model4
-     assert mock_load_model.call_count == 1
+ new_model = LogisticRegression(random_state=4321)
+ # mock _load_model, it still loads the model but we can track call count
+ mock_load_model = mock.Mock(side_effect=load)
+ card = Card(iris_skops_file, metadata=metadata_from_config(destination_path))
+ with mock.patch("skops.card._model_card._load_model", mock_load_model):
+     model1 = card.get_model()
+     model2 = card.get_model()
+     assert model1 is model2
+     # model is cached, hence _load_model is not called
+     mock_load_model.assert_not_called()
+     # override model with new model
+     dump(new_model, card.model)
+     model3 = card.get_model()
+     assert mock_load_model.call_count == 1
+     assert model3.random_state == 4321
+     model4 = card.get_model()
+     assert model3 is model4
+     assert mock_load_model.call_count == 1  # cached call
```
(line 3: `load` needs to be imported from `skops.io`)
This test is similar to yours but is closer to how a user would actually use the model card. Please take a look and see if you agree. It would also be good to have a comment at the start of the test explaining what is being tested.
I can see how a user would go with your approach rather than mine. I definitely agree with your suggestion.
Tested your suggestion and it passes as expected. I also added a short comment at the beginning of the function describing what it tests.
BenjaminBossan
left a comment
Thx. This LGTM. @adrinjalali not sure if you want to review too; if not, feel free to merge.
adrinjalali
left a comment
Just a nit, otherwise LGTM.
```python
def get_model(self) -> Any:
    """Returns sklearn estimator object.

    If the ``model`` is already loaded, return it as is. If the ``model``
    attribute is a ``Path``/``str``, load the model and return it.
    """
```
Reverted them! Sorry about that, I will pay attention to it next time.
Implementation of cache model loading discussed in issue #243
This PR includes the following changes:
- `_load_model` function to implement model caching
- `hash_model` function that is used as a decorator on the `_load_model` function
- `lru_cache` decorator to implement caching on top of the `hash_model` and `_load_model` functions
- `test_load_model` to test for cache model loading
- `test_hash_model` test implemented to test the `hash_model` function