diff --git a/tests/unit/factored_matrix/test_multiply_by_tensor_like.py b/tests/unit/factored_matrix/test_multiply_by_tensor_like.py
new file mode 100644
index 000000000..e9482b3e3
--- /dev/null
+++ b/tests/unit/factored_matrix/test_multiply_by_tensor_like.py
@@ -0,0 +1,109 @@
+"""Tests that FactoredMatrix matmul works with tensor-like objects.
+
+A "tensor-like" object is one that quacks like a torch.Tensor (supports the
+operations FactoredMatrix needs — .ndim, .size, .unsqueeze, __matmul__) but
+isn't a torch.Tensor subclass. This is useful for things like jaxtyping wrappers
+or custom array types.
+"""
+
+from torch import randn
+from torch.testing import assert_close
+
+from transformer_lens import FactoredMatrix
+
+
+class TensorLike:
+    """A wrapper that exposes the tensor protocol without subclassing torch.Tensor.
+
+    Implements just enough of the protocol that FactoredMatrix can multiply
+    with it: matmul, ndim, size, shape, unsqueeze, squeeze, broadcast_to.
+    """
+
+    def __init__(self, tensor):
+        self._tensor = tensor
+
+    @property
+    def ndim(self):
+        return self._tensor.ndim
+
+    @property
+    def shape(self):
+        return self._tensor.shape
+
+    def size(self, dim=None):
+        return self._tensor.size() if dim is None else self._tensor.size(dim)
+
+    def unsqueeze(self, dim):
+        return TensorLike(self._tensor.unsqueeze(dim))
+
+    def squeeze(self, dim):
+        return TensorLike(self._tensor.squeeze(dim))
+
+    def broadcast_to(self, shape):
+        return TensorLike(self._tensor.broadcast_to(shape))
+
+    def __matmul__(self, other):
+        if isinstance(other, FactoredMatrix):
+            # Defer to FactoredMatrix.__rmatmul__ so the result is a FactoredMatrix
+            return NotImplemented
+        if isinstance(other, TensorLike):
+            return TensorLike(self._tensor @ other._tensor)
+        return TensorLike(self._tensor @ other)
+
+    def __rmatmul__(self, other):
+        if isinstance(other, TensorLike):
+            return TensorLike(other._tensor @ self._tensor)
+        return TensorLike(other @ self._tensor)
+
+
+def test_left_multiply_factored_matrix_by_tensor_like_matrix():
+    """factored_matrix @ tensor_like_matrix should not return None."""
+    a = randn(2, 3)
+    b = randn(3, 4)
+    matrix = randn(4, 5)
+    factored_matrix = FactoredMatrix(a, b)
+
+    result = factored_matrix @ TensorLike(matrix)
+
+    assert result is not None, "matmul with tensor-like silently returned None"
+    assert isinstance(result, FactoredMatrix)
+    expected = (a @ b) @ matrix
+    assert isinstance(result.AB, TensorLike)
+    assert_close(result.AB._tensor, expected)
+
+
+def test_right_multiply_factored_matrix_by_tensor_like_matrix():
+    """tensor_like_matrix @ factored_matrix should not return None."""
+    a = randn(3, 4)
+    b = randn(4, 6)
+    matrix = randn(5, 3)
+    factored_matrix = FactoredMatrix(a, b)
+
+    result = TensorLike(matrix) @ factored_matrix
+
+    assert result is not None, "rmatmul with tensor-like silently returned None"
+    assert isinstance(result, FactoredMatrix)
+    expected = matrix @ (a @ b)
+    assert isinstance(result.AB, TensorLike)
+    assert_close(result.AB._tensor, expected)
+
+
+def test_left_multiply_factored_matrix_by_tensor_like_vector():
+    """factored_matrix @ tensor_like_vector should dispatch through the vector path.
+
+    The vector branch of FactoredMatrix.__matmul__ collapses to a single tensor
+    via unsqueeze/squeeze rather than wrapping in a new FactoredMatrix. This test
+    exercises that path and verifies the TensorLike protocol methods (unsqueeze,
+    squeeze, __rmatmul__) are correctly invoked.
+    """
+    a = randn(2, 3)
+    b = randn(3, 4)
+    vector = randn(4)
+    factored_matrix = FactoredMatrix(a, b)
+
+    result = factored_matrix @ TensorLike(vector)
+
+    # The fix's core guarantee: the dispatch produces a result instead of None
+    assert isinstance(result, TensorLike)
+    expected = (a @ b) @ vector
+    assert_close(result._tensor, expected)
diff --git a/transformer_lens/FactoredMatrix.py b/transformer_lens/FactoredMatrix.py
index 2f69220df..0c7ce3610 100644
--- a/transformer_lens/FactoredMatrix.py
+++ b/transformer_lens/FactoredMatrix.py
@@ -7,7 +7,7 @@
 from __future__ import annotations
 
 from functools import lru_cache
-from typing import List, Tuple, Union, overload
+from typing import Any, List, Protocol, Tuple, Union, cast, overload, runtime_checkable
 
 import torch
 from jaxtyping import Complex, Float
@@ -15,6 +15,39 @@
 import transformer_lens.utilities.tensors as tensor_utils
 
 
+@runtime_checkable
+class TensorLike(Protocol):
+    """Minimal tensor protocol that FactoredMatrix accepts in place of torch.Tensor.
+
+    Allows duck-typed inputs (e.g. jaxtyping wrappers, custom array types) that
+    aren't torch.Tensor subclasses but support the operations FactoredMatrix uses
+    when constructing, multiplying, and broadcasting its A and B factors.
+    """
+
+    @property
+    def ndim(self) -> int:
+        ...
+
+    @property
+    def shape(self) -> Any:
+        ...
+
+    def size(self, dim: int) -> int:
+        ...
+
+    def unsqueeze(self, dim: int) -> Any:
+        ...
+
+    def squeeze(self, dim: int) -> Any:
+        ...
+
+    def broadcast_to(self, shape: Any) -> Any:
+        ...
+
+    def __matmul__(self, other: Any) -> Any:
+        ...
+
+
 class FactoredMatrix:
     """
     Class to represent low rank factored matrices, where the matrix is represented as a product of two matrices. Has utilities for efficient calculation of eigenvalues, norm and SVD.
@@ -22,11 +55,21 @@ class FactoredMatrix:
 
     def __init__(
         self,
-        A: Float[torch.Tensor, "... ldim mdim"],
-        B: Float[torch.Tensor, "... mdim rdim"],
+        A: Union[Float[torch.Tensor, "... ldim mdim"], TensorLike],
+        B: Union[Float[torch.Tensor, "... mdim rdim"], TensorLike],
     ):
-        self.A = A
-        self.B = B
+        """Construct a FactoredMatrix from factors A and B.
+
+        A and B may be torch.Tensor or TensorLike duck types. TensorLike inputs
+        are only fully supported by matmul-family operations (``@``, ``AB``,
+        ``BA``); operations like ``svd()``, ``norm()``, ``transpose()``,
+        ``__getitem__``, and eigenvalue methods require both factors to be
+        actual torch.Tensor and will raise AttributeError on TensorLike inputs.
+        """
+        # Cast to Tensor for type-checker purposes. At runtime A and B may be
+        # TensorLike duck types; the class methods trust the protocol.
+        self.A: torch.Tensor = cast(torch.Tensor, A)
+        self.B: torch.Tensor = cast(torch.Tensor, B)
         assert self.A.size(-1) == self.B.size(
             -2
         ), f"Factored matrix must match on inner dimension, shapes were a: {self.A.shape}, b:{self.B.shape}"
@@ -74,9 +117,12 @@ def __matmul__(
             Float[torch.Tensor, "... rdim new_rdim"],
             Float[torch.Tensor, "rdim"],
             "FactoredMatrix",
+            TensorLike,
         ],
-    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"]]:
-        if isinstance(other, torch.Tensor):
+    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... ldim"], TensorLike]:
+        if isinstance(other, FactoredMatrix):
+            return (self @ other.A) @ other.B
+        else:
             if other.ndim < 2:
                 # It's a vector, so we collapse the factorisation and just return a vector
                 # Squeezing/Unsqueezing is to preserve broadcasting working nicely
@@ -86,11 +132,11 @@ def __matmul__(
                     other.size(-2) == self.rdim
                 ), f"Right matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
                 if self.rdim > self.mdim:
-                    return FactoredMatrix(self.A, self.B @ other)
+                    # other is Tensor or TensorLike; runtime delegates to
+                    # the appropriate __matmul__/__rmatmul__ overload.
+                    return FactoredMatrix(self.A, self.B @ cast(torch.Tensor, other))
                 else:
                     return FactoredMatrix(self.AB, other)
-        elif isinstance(other, FactoredMatrix):
-            return (self @ other.A) @ other.B
 
     @overload
     def __rmatmul__(  # type: ignore
@@ -115,9 +161,12 @@ def __rmatmul__(  # type: ignore
             Float[torch.Tensor, "... new_rdim ldim"],
             Float[torch.Tensor, "ldim"],
             "FactoredMatrix",
+            TensorLike,
         ],
-    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"]]:
-        if isinstance(other, torch.Tensor):
+    ) -> Union["FactoredMatrix", Float[torch.Tensor, "... rdim"], TensorLike]:
+        if isinstance(other, FactoredMatrix):
+            return other.A @ (other.B @ self)
+        else:
             assert (
                 other.size(-1) == self.ldim
             ), f"Left matrix must match on inner dimension, shapes were self: {self.shape}, other:{other.shape}"
@@ -128,8 +177,6 @@ def __rmatmul__(  # type: ignore
                 return FactoredMatrix(other @ self.A, self.B)
             else:
                 return FactoredMatrix(other, self.AB)
-        elif isinstance(other, FactoredMatrix):
-            return other.A @ (other.B @ self)
 
     def __mul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
         """
@@ -148,8 +195,11 @@ def __rmul__(self, scalar: Union[int, float, torch.Tensor]) -> FactoredMatrix:
         return self * scalar
 
     @property
-    def AB(self) -> Float[torch.Tensor, "*leading_dims ldim rdim"]:
-        """The product matrix - expensive to compute, and can consume a lot of GPU memory"""
+    def AB(self) -> Union[Float[torch.Tensor, "*leading_dims ldim rdim"], TensorLike]:
+        """The product matrix - expensive to compute, and can consume a lot of GPU memory.
+
+        Returns a TensorLike when A or B is a non-Tensor TensorLike duck type.
+        """
         return self.A @ self.B
 
     @property