Make the Leaf dynamics differentiable#12
Conversation
|
I've installed the package and ran the tests. All 10 of them pass. So that's a good start. Will look at the actual code next week. I also see ruff is complaining about ANN101 and ANN102 that seem to be configured somewhere to be ignored, while these rules no longer exist and are from older versions of ruff. I made a PR for this: #14 |
[ANN101](https://docs.astral.sh/ruff/rules/missing-type-self/) and [ANN102](https://docs.astral.sh/ruff/rules/missing-type-cls/) have been removed and referencing them in the config causes warnings to be printed when ruff is run. See also WUR-AI/diffWOFOST#12 (comment)
There was a problem hiding this comment.
I think this is a great PR! 🎖Very neat!
I made some comments. Some are suggestions for possible improvements and others are more points for you to think about (or discuss together) before you decide to merge.
I also created a PR with a tiny refactoring of a single file in your code. Was easier to do than to explain. See #15 . I think that should be merged into this branch before merging your whole PR into main.
I hope it is helpful for you!
refactor: extract model creation method
thanks! 👍 we can merge #14 |
Thanks for the comments and helpful suggestions. I addressed your comments:
|
| TWLV = Any(default_value=torch.tensor(-99.0, dtype=DTYPE)) | ||
|
|
||
| class RateVariables(RatesTemplate): | ||
| GRLV = Any(default_value=torch.tensor(0.0, dtype=DTYPE)) |
There was a problem hiding this comment.
in original these were initialized at -99.0
There was a problem hiding this comment.
You're right. In PCSE, class attributes default to -99.0, but rate variables are initialized to zero, see here. So in practice, those -99.0 values never actually get used.
| TBASE = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) | ||
| PERDL = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) | ||
| TDWI = Any(default_value=[torch.tensor(-99.0, dtype=DTYPE)]) | ||
| SLATB = AfgenTrait() # FIXEME |
There was a problem hiding this comment.
This PR was about two parameters of leaf_dynamics; TDWI and SPAN. In the next iteration, the rest will be fixed.
| key/value pairs | ||
| """ | ||
| self.kiosk = kiosk | ||
| # TODO check if parvalues are already torch.nn.Parameters |
There was a problem hiding this comment.
Is this TODO still active?
There was a problem hiding this comment.
yes, currently, there are no checks on the parameters since we are not using Traits anymore. This should be fixed in the next iterations.
| # CALCULATE INITIAL STATE VARIABLES | ||
| # check for required external variables | ||
| _exist_required_external_variables(self.kiosk) | ||
| # TODO check if external variables are already torch tensors |
There was a problem hiding this comment.
Is this TODO still active?
There was a problem hiding this comment.
yes, the same as above.
| self.states.LVAGE = tLVAGE | ||
|
|
||
| @prepare_states | ||
| def _set_variable_LAI(self, nLAI): # FIXEME |
There was a problem hiding this comment.
What is this FIXEME?
There was a problem hiding this comment.
This function, as a BMI function, might not be used in any application that requires differentiability. For now, I kept it, but will most probably be removed.
michielkallenberg
left a comment
There was a problem hiding this comment.
Very nice work! Great. I left some minor comments; feel very much free to ignore.
| from pcse.base.parameter_providers import ParameterProvider | ||
| from pcse.engine import Engine | ||
| from pcse.models import Wofost72_PP | ||
| from diffwofost.physical_models.crop.leaf_dynamics import WOFOST_Leaf_Dynamics |
There was a problem hiding this comment.
, WeatherDataProviderTestHelper
| span = torch.nn.Parameter(torch.tensor(30, dtype=torch.float64)) | ||
| numerical_grad = calculate_numerical_grad("SPAN", span, "TWLV") # this is Δloss/Δspan | ||
|
|
||
| assert numerical_grad == 0.0 |
There was a problem hiding this comment.
this test seems not entirely complete
There was a problem hiding this comment.
right, it is fixed now.
| assert_almost_equal(numerical_grad, grads.item(), decimal=3) | ||
|
|
||
|
|
||
| class TestDiffLeafDynamicsSPAN: |
There was a problem hiding this comment.
Conssider making this more generic; there is a bit of code repetion.
There was a problem hiding this comment.
I see your point. While there’s some duplication, it helps keep each test clear and maintainable on its own. I’d keep them as-is unless in a future iteration, we see a need for reuse.
| self.TERMINAL_OUTPUT_VARS = [] | ||
|
|
||
|
|
||
| class EngineTestHelper(Engine): |
There was a problem hiding this comment.
There's a bit of code repetition from the original class.
There was a problem hiding this comment.
That's correct. This file will be improved later.
Thanks for the review and comments/suggestions. |
relates #8