Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 49 additions & 16 deletions torch_sim/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def torch_divmod(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.
return d, m


def expm_frechet( # noqa: PLR0915, C901
def expm_frechet( # noqa: C901
A: torch.Tensor,
E: torch.Tensor,
method: str | None = None,
Expand All @@ -36,6 +36,20 @@ def expm_frechet( # noqa: PLR0915, C901
Optimized for batched 3x3 matrices. Also handles single 3x3 matrices by
auto-adding a batch dimension.

Method notes:
- ``SPS`` uses scaling-Pade-squaring for the matrix exponential and its
Frechet derivative.
- ``blockEnlarge`` uses the block matrix identity
exp([[A, E], [0, A]]) = [[exp(A), L_exp(A, E)], [0, exp(A)]].

References:
- Awad H. Al-Mohy and Nicholas J. Higham (2009), "Computing the Frechet
Derivative of the Matrix Exponential, with an Application to Condition
Number Estimation", SIAM J. Matrix Anal. Appl. 30(4):1639-1657.
https://doi.org/10.1137/080716426
- Nicholas J. Higham (2008), "Functions of Matrices: Theory and
Computation", SIAM. (See the Frechet derivative block-matrix identity.)

Args:
A: (B, 3, 3) or (3, 3) tensor. Matrix of which to take the matrix exponential.
E: (B, 3, 3) or (3, 3) tensor. Matrix direction in which to take the Frechet
Expand Down Expand Up @@ -77,9 +91,32 @@ def expm_frechet( # noqa: PLR0915, C901
raise ValueError("expected A to be (B, N, N)")
return expm_frechet_block_enlarge(A, E)

if method != "SPS":
raise ValueError(f"Unknown {method=}")
if method == "SPS":
return expm_frechet_sps(A, E)
raise ValueError(f"Unknown {method=}")


def matrix_exp(A: torch.Tensor) -> torch.Tensor:
"""Compute the matrix exponential of A using PyTorch's matrix_exp.

Args:
A: Input matrix

Returns:
torch.Tensor: Matrix exponential of A
"""
return torch.matrix_exp(A)


def expm_frechet_sps(
A: torch.Tensor, E: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""SPS helper for Frechet derivative of exp(A) on 3x3 matrices.

SPS = scaling-Pade-squaring. This implementation follows the approach in:
Awad H. Al-Mohy and Nicholas J. Higham (2009), SIAM J. Matrix Anal.
Appl. 30(4):1639-1657. https://doi.org/10.1137/080716426
"""
# Handle unbatched 3x3 input by adding batch dimension
unbatched = A.dim() == 2
if unbatched:
Expand Down Expand Up @@ -154,22 +191,18 @@ def expm_frechet( # noqa: PLR0915, C901
return R, L


def matrix_exp(A: torch.Tensor) -> torch.Tensor:
"""Compute the matrix exponential of A using PyTorch's matrix_exp.

Args:
A: Input matrix

Returns:
torch.Tensor: Matrix exponential of A
"""
return torch.matrix_exp(A)


def expm_frechet_block_enlarge(
A: torch.Tensor, E: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Helper function for testing and profiling.
"""Block-enlarge helper for Frechet derivative via matrix exponential.

Builds M = [[A, E], [0, A]], computes exp(M), and extracts:
- exp(A) from the top-left block
- L_exp(A, E) from the top-right block

Reference:
Nicholas J. Higham (2008), "Functions of Matrices: Theory and
Computation", SIAM. (Frechet derivative block-matrix identity.)

Args:
A: (B, N, N) Batch of input matrices.
Expand Down