
Conversation

@ITHwang (Owner) commented Apr 22, 2025

  1. Higher-Order Differentiation Enablement:

    • Tensor-based Backward Pass: The core change involves modifying the backward methods of fundamental Function subclasses (Add, Mul, Neg, Sub, Div, Pow) in torch/core.py.
      • Previously, backward methods accepted np.ndarray gradient inputs (gy) and computed np.ndarray output gradients (gx).
      • They now accept Tensor objects as input gradients (gy) and return Tensor objects (gx).
      • Inputs within backward are now accessed directly as Tensor objects (e.g., x = self.inputs[0]) instead of accessing their underlying ._data.
    • Why this enables higher-order derivatives: Because the gradients (gy) are treated as Tensor objects within the backward pass, they become part of the computational graph themselves. The backward pass is therefore itself differentiable, enabling second-order (and potentially higher) derivatives; see the sketch below.
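A minimal sketch of the pattern, assuming the Function/Tensor API in torch/core.py (import paths and exact signatures here are illustrative, not verbatim from the repo):

```python
import numpy as np
from torch.core import Function, Tensor  # assumed import path


class Mul(Function):
    def forward(self, x0: np.ndarray, x1: np.ndarray) -> np.ndarray:
        # forward still operates on raw ndarrays
        return x0 * x1

    def backward(self, gy: Tensor) -> tuple[Tensor, Tensor]:
        # inputs are used as Tensors, and gy * x1 / gy * x0 create new
        # graph nodes, so the backward computation is itself differentiable
        x0, x1 = self.inputs
        return gy * x1, gy * x0
```

With this in place, a second derivative can be taken by back-propagating through the first gradient. The create_graph flag and the gradient reset below are hypothetical, following the PyTorch-style API this PR mimics:

```python
import torch

x = torch.tensor(2.0)
y = x ** 4
y.backward(create_graph=True)  # assumes a create_graph-style option
gx = x.grad                    # dy/dx = 4 * x**3 = 32, itself a Tensor
x.grad = None                  # hypothetical gradient reset
gx.backward()
print(x.grad)                  # d2y/dx2 = 12 * x**2 = 48
```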
  2. Module Reorganization:

    • Dedicated functions.py: Mathematical functions (Sin, Cos, Square, Exp, Tanh) and their user-facing functional wrappers (sin, cos, etc.) were moved from torch/core.py to a new torch/functions.py file. This improves modularity and separates core autograd logic from specific mathematical operations.
    • The backward methods within these moved functions were also updated to use the Tensor-based approach described above (e.g., Sin.backward now uses cos(x), which returns a Tensor); see the sketch below.
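For example, the relocated Sin and Cos might look like this (a sketch; the two gy expressions match the snippets in the review comments below, while the import path and wrapper signatures are assumptions):

```python
import numpy as np
from torch.core import Function, Tensor  # assumed import path


class Sin(Function):
    def forward(self, x: np.ndarray) -> np.ndarray:
        return np.sin(x)

    def backward(self, gy: Tensor) -> Tensor:
        x = self.inputs[0]            # a Tensor, not its ._data
        gx: Tensor = gy * cos(x)      # cos returns a Tensor, extending the graph
        return gx


def sin(x) -> Tensor:
    return Sin()(x)


class Cos(Function):
    def forward(self, x: np.ndarray) -> np.ndarray:
        return np.cos(x)

    def backward(self, gy: Tensor) -> Tensor:
        x = self.inputs[0]
        gx: Tensor = -gy * sin(x)     # resolved at call time, so the
        return gx                     # mutual sin/cos references are fine


def cos(x) -> Tensor:
    return Cos()(x)
```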
  3. API Enhancements and PyTorch Alignment:

    • Type Hint Consistency: Function signatures in torch/core.py (like add, mul, pow, etc.) now consistently use the torch.INPUT_TYPE alias from torch/types.py instead of the previous types.INPUT_TYPE.
    • Docstrings: Added docstrings to add, mul, neg, sub, div, pow, sin, cos, square, exp, tanh linking to the corresponding function in the official PyTorch documentation for clarity.
    • Factory Functions: Introduced torch.tensor, torch.ones, and torch.ones_like functions in torch/core.py to mimic the standard PyTorch tensor creation API.
    • Utility Functions:
      • as_tensor: Ensures an input is a Tensor (converting it if necessary), simplifying how functions handle their inputs.
      • set_logging_level: Provides control over log verbosity using loguru.
    • Configuration Context Managers:
      • using_config: A general context manager to temporarily modify global configuration settings (like Config.enable_backprop).
      • no_grad: A specific context manager built on using_config that disables gradient calculation within its scope, mirroring torch.no_grad() (see the sketch after this list).
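A sketch of how using_config and no_grad can be implemented and combined with the new factory functions, following the description above (the location of Config and the exact import paths are assumptions):

```python
import contextlib

import torch
from torch.core import Config  # assumed location of the global config


@contextlib.contextmanager
def using_config(name: str, value):
    # temporarily override a Config attribute, restoring it on exit
    old_value = getattr(Config, name)
    setattr(Config, name, value)
    try:
        yield
    finally:
        setattr(Config, name, old_value)


def no_grad():
    return using_config("enable_backprop", False)


# factory functions in use
x = torch.tensor([1.0, 2.0, 3.0])
with no_grad():
    y = x * torch.ones_like(x)  # no backward graph is recorded here
```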
  4. Type System (torch/types.py):

    • Explicit Torch Dtypes: Defined torch.int32, torch.int64, torch.float32, torch.float64 using typing.NewType wrapping the corresponding np types. This creates distinct types for the library while leveraging NumPy's underlying representations.
    • NumPy/Torch Dtype Conversion: Added dictionaries (np2torch, torch2np) and functions (type_np2torch, type_torch2np) to explicitly map and convert between numpy.dtype objects and the new torch dtype objects, ensuring correct type handling during Tensor creation and operations.
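A sketch of what torch/types.py might contain, following the names given above (exact spellings and signatures are assumptions):

```python
from typing import NewType

import numpy as np

# distinct library-level dtypes backed by NumPy scalar types
int32 = NewType("int32", np.int32)
int64 = NewType("int64", np.int64)
float32 = NewType("float32", np.float32)
float64 = NewType("float64", np.float64)

np2torch = {
    np.dtype("int32"): int32,
    np.dtype("int64"): int64,
    np.dtype("float32"): float32,
    np.dtype("float64"): float64,
}
torch2np = {v: k for k, v in np2torch.items()}


def type_np2torch(dtype) -> object:
    # normalize so both np.float32 and np.dtype("float32") are accepted
    return np2torch[np.dtype(dtype)]


def type_torch2np(torch_dtype) -> np.dtype:
    return torch2np[torch_dtype]
```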

@ITHwang requested a review from Copilot April 22, 2025 10:19

Copilot AI left a comment

Pull Request Overview

This PR introduces Tensor-based implementations in backward passes to enable higher-order differentiation, reorganizes mathematical functions into a dedicated module, and updates API type hints and documentation for closer alignment with PyTorch.

  • Changes include refactoring core functions (e.g., backward methods now returning Tensors), moving math operations to torch/functions.py, and enhancing tests and documentation.
  • Updated files: torch/types.py, torch/functions.py, torch/__init__.py, tests (adjustments and removals), pyproject.toml, docs/roadmap.md, and README.md.

Reviewed Changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated no comments.

Show a summary per file

| File | Description |
| --- | --- |
| torch/types.py | Added explicit dtype NewType definitions, np2torch/torch2np mappings, and conversion functions. |
| torch/functions.py | Introduced new mathematical function classes (Sin, Cos, etc.) using Tensor-based implementations. |
| torch/__init__.py | Reorganized imports to reflect the move of math functions and added new factory/utility functions. |
| tests/torch/utils.py | Added tests for numerical differentiation and utility functions using Tensor internal data. |
| tests/torch/test_functions.py | Extended tests for verifying backward passes for new math functions. |
| tests/torch/complex_funcs.py | Removed duplicate implementations now replaced by torch/functions.py. |
| pyproject.toml | Added matplotlib dependency for plotting derivative graphs. |
| docs/roadmap.md | Introduced a placeholder roadmap file. |
| README.md | Updated documentation and examples to include new operations and API features. |
Comments suppressed due to low confidence (3)

torch/functions.py:13

  • [nitpick] Consider explicitly importing and referencing the cosine function (e.g., 'from torch.functions import cos') to make the dependency clear and avoid potential issues with function resolution.
gx: Tensor = gy * cos(x)  # type: ignore

torch/functions.py:24

  • [nitpick] Similarly, explicitly import and reference the sine function (e.g., 'from torch.functions import sin') to clarify its use in the backward computation.
gx: Tensor = -gy * sin(x)  # type: ignore

tests/torch/utils.py:46

  • [nitpick] Consider using a public API for accessing tensor data rather than directly accessing the '_data' attribute, to prevent potential issues if the internal representation changes.
x_data: np.ndarray = x._data

@ITHwang merged commit 91fda26 into main Apr 22, 2025
1 check passed
@ITHwang deleted the feat/higher-order-diff branch April 22, 2025 10:26