diff --git a/torch_sim/math.py b/torch_sim/math.py index 850f22b6a..6e5e94c82 100644 --- a/torch_sim/math.py +++ b/torch_sim/math.py @@ -25,7 +25,7 @@ def torch_divmod(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch. return d, m -def expm_frechet( # noqa: PLR0915, C901 +def expm_frechet( # noqa: C901 A: torch.Tensor, E: torch.Tensor, method: str | None = None, @@ -36,6 +36,20 @@ def expm_frechet( # noqa: PLR0915, C901 Optimized for batched 3x3 matrices. Also handles single 3x3 matrices by auto-adding a batch dimension. + Method notes: + - ``SPS`` uses scaling-Pade-squaring for the matrix exponential and its + Frechet derivative. + - ``blockEnlarge`` uses the block matrix identity + exp([[A, E], [0, A]]) = [[exp(A), L_exp(A, E)], [0, exp(A)]]. + + References: + - Awad H. Al-Mohy and Nicholas J. Higham (2009), "Computing the Frechet + Derivative of the Matrix Exponential, with an Application to Condition + Number Estimation", SIAM J. Matrix Anal. Appl. 30(4):1639-1657. + https://doi.org/10.1137/080716426 + - Nicholas J. Higham (2008), "Functions of Matrices: Theory and + Computation", SIAM. (See the Frechet derivative block-matrix identity.) + Args: A: (B, 3, 3) or (3, 3) tensor. Matrix of which to take the matrix exponential. E: (B, 3, 3) or (3, 3) tensor. Matrix direction in which to take the Frechet @@ -77,9 +91,32 @@ def expm_frechet( # noqa: PLR0915, C901 raise ValueError("expected A to be (B, N, N)") return expm_frechet_block_enlarge(A, E) - if method != "SPS": - raise ValueError(f"Unknown {method=}") + if method == "SPS": + return expm_frechet_sps(A, E) + raise ValueError(f"Unknown {method=}") + + +def matrix_exp(A: torch.Tensor) -> torch.Tensor: + """Compute the matrix exponential of A using PyTorch's matrix_exp. + + Args: + A: Input matrix + + Returns: + torch.Tensor: Matrix exponential of A + """ + return torch.matrix_exp(A) + + +def expm_frechet_sps( + A: torch.Tensor, E: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """SPS helper for Frechet derivative of exp(A) on 3x3 matrices. + SPS = scaling-Pade-squaring. This implementation follows the approach in: + Awad H. Al-Mohy and Nicholas J. Higham (2009), SIAM J. Matrix Anal. + Appl. 30(4):1639-1657. https://doi.org/10.1137/080716426 + """ # Handle unbatched 3x3 input by adding batch dimension unbatched = A.dim() == 2 if unbatched: @@ -154,22 +191,18 @@ def expm_frechet( # noqa: PLR0915, C901 return R, L -def matrix_exp(A: torch.Tensor) -> torch.Tensor: - """Compute the matrix exponential of A using PyTorch's matrix_exp. - - Args: - A: Input matrix - - Returns: - torch.Tensor: Matrix exponential of A - """ - return torch.matrix_exp(A) - - def expm_frechet_block_enlarge( A: torch.Tensor, E: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """Helper function for testing and profiling. + """Block-enlarge helper for Frechet derivative via matrix exponential. + + Builds M = [[A, E], [0, A]], computes exp(M), and extracts: + - exp(A) from the top-left block + - L_exp(A, E) from the top-right block + + Reference: + Nicholas J. Higham (2008), "Functions of Matrices: Theory and + Computation", SIAM. (Frechet derivative block-matrix identity.) Args: A: (B, N, N) Batch of input matrices.