Feat: Higher-order derivatives and aligning with PyTorch #6
Pull Request Overview
This PR introduces Tensor-based implementations in backward passes to enable higher-order differentiation, reorganizes mathematical functions into a dedicated module, and updates API type hints and documentation for closer alignment with PyTorch.
- Changes include refactoring core functions (e.g., backward methods now returning Tensors), moving math operations to torch/functions.py, and enhancing tests and documentation.
- Updated files: torch/types.py, torch/functions.py, torch/__init__.py, tests (adjustments and removals), pyproject.toml, docs/roadmap.md, and README.md.
Reviewed Changes
Copilot reviewed 14 out of 14 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
| torch/types.py | Added explicit dtype NewType definitions, np2torch/torch2np mappings, and conversion functions. |
| torch/functions.py | Introduced new mathematical function classes (Sin, Cos, etc.) using Tensor-based implementations. |
| torch/__init__.py | Reorganized imports to reflect the move of math functions and added new factory/utility functions. |
| tests/torch/utils.py | Added tests for numerical differentiation and utility functions using Tensor internal data. |
| tests/torch/test_functions.py | Extended tests for verifying backward passes for new math functions. |
| tests/torch/complex_funcs.py | Removed duplicate implementations now replaced by torch/functions.py. |
| pyproject.toml | Added matplotlib dependency for plotting derivative graphs. |
| docs/roadmap.md | Introduced a placeholder roadmap file. |
| README.md | Updated documentation and examples to include new operations and API features. |
Comments suppressed due to low confidence (3)
torch/functions.py:13
- [nitpick] Consider explicitly importing and referencing the cosine function (e.g., 'from torch.functions import cos') to make the dependency clear and avoid potential issues with function resolution.
gx: Tensor = gy * cos(x) # type: ignore
torch/functions.py:24
- [nitpick] Similarly, explicitly import and reference the sine function (e.g., 'from torch.functions import sin') to clarify its use in the backward computation.
gx: Tensor = -gy * sin(x) # type: ignore
tests/torch/utils.py:46
- [nitpick] Consider using a public API for accessing tensor data rather than directly accessing the '_data' attribute, to prevent potential issues if the internal representation changes.
x_data: np.ndarray = x._data
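For reference, a central-difference gradient check like the one these tests presumably implement; this helper is a hypothetical sketch operating on raw arrays, not the repository's actual utility:

```python
import numpy as np


def numerical_grad(f, x: np.ndarray, eps: float = 1e-4) -> np.ndarray:
    """Approximate df/dx elementwise via (f(x + eps) - f(x - eps)) / (2 * eps)."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)


# Example: the derivative of sin is cos.
x = np.array([0.0, 1.0, 2.0])
assert np.allclose(numerical_grad(np.sin, x), np.cos(x), atol=1e-6)
```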
Higher-Order Differentiation Enablement:
- Refactored the `backward` methods of the fundamental `Function` subclasses (`Add`, `Mul`, `Neg`, `Sub`, `Div`, `Pow`) in `torch/core.py`.
- Previously, `backward` methods accepted `np.ndarray` gradient inputs (`gy`) and computed `np.ndarray` output gradients (`gx`).
- Now, they accept `Tensor` objects as input gradients (`gy`) and return `Tensor` objects (`gx`).
- Inputs within `backward` are now accessed directly as `Tensor` objects (e.g., `x = self.inputs[0]`) instead of accessing their underlying `._data`.
- By treating gradients (`gy`) as `Tensor` objects within the `backward` pass, these gradients become part of the computational graph themselves. This allows us to differentiate the backward pass computation, enabling the calculation of second-order (and potentially higher) derivatives; see the sketch after this list.
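To make the mechanism concrete, here is a minimal, self-contained toy sketch (illustrative only, not this repository's code): because `backward` consumes and produces `Tensor` objects, the gradient itself carries a computational graph, and calling `backward` on that gradient yields the second derivative.

```python
class Tensor:
    def __init__(self, data, creator=None):
        self.data, self.creator, self.grad = data, creator, None

    def __mul__(self, other):
        return Mul()(self, other)

    def __add__(self, other):
        return Add()(self, other)

    def backward(self):
        self.grad = Tensor(1.0)
        # Topologically order the graph so each node is processed only after
        # all of its gradient contributions have been accumulated.
        order, seen = [], set()

        def visit(t):
            if id(t) not in seen and t.creator is not None:
                seen.add(id(t))
                for inp in t.creator.inputs:
                    visit(inp)
                order.append(t)

        visit(self)
        for t in reversed(order):
            gxs = t.creator.backward(t.grad)  # Tensor in, Tensors out
            for inp, gx in zip(t.creator.inputs, gxs):
                inp.grad = gx if inp.grad is None else inp.grad + gx
            t.grad = None  # keep grads only on leaves (avoids stale grads)


class Mul:
    def __call__(self, a, b):
        self.inputs = (a, b)
        return Tensor(a.data * b.data, creator=self)

    def backward(self, gy):
        a, b = self.inputs     # inputs accessed directly as Tensors
        return gy * b, gy * a  # Tensor ops: recorded on the graph


class Add:
    def __call__(self, a, b):
        self.inputs = (a, b)
        return Tensor(a.data + b.data, creator=self)

    def backward(self, gy):
        return gy, gy


x = Tensor(3.0)
y = x * x * x                # y = x^3
y.backward()
gx = x.grad                  # dy/dx = 3x^2 = 27, itself a Tensor with a graph
x.grad = None
gx.backward()                # differentiate the gradient itself
print(gx.data, x.grad.data)  # 27.0 18.0  (d2y/dx2 = 6x)
```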
Module Reorganization:
- Created `functions.py`: Mathematical functions (`Sin`, `Cos`, `Square`, `Exp`, `Tanh`) and their user-facing functional wrappers (`sin`, `cos`, etc.) were moved from `torch/core.py` to a new `torch/functions.py` file. This improves modularity and separates core autograd logic from specific mathematical operations.
- The `backward` methods within these moved functions were also updated to use the Tensor-based approach described above (e.g., `Sin.backward` now uses `cos(x)`, which returns a Tensor); a sketch follows this list.
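As an illustration, a `Sin`/`Cos` pair under this scheme might look like the sketch below, assembled from the snippets quoted in the review comments above; the `Function` base class and the import paths are assumptions, not the repository's verbatim code.

```python
import numpy as np

from torch.core import Function, Tensor  # assumed import locations


class Sin(Function):
    def forward(self, x: np.ndarray) -> np.ndarray:
        # Forward still computes on raw array data.
        return np.sin(x)

    def backward(self, gy: Tensor) -> Tensor:
        x = self.inputs[0]        # read the input as a Tensor, not ._data
        gx: Tensor = gy * cos(x)  # cos() returns a Tensor, so this backward
        return gx                 # computation is itself differentiable


class Cos(Function):
    def forward(self, x: np.ndarray) -> np.ndarray:
        return np.cos(x)

    def backward(self, gy: Tensor) -> Tensor:
        x = self.inputs[0]
        gx: Tensor = -gy * sin(x)
        return gx


def sin(x) -> Tensor:
    return Sin()(x)


def cos(x) -> Tensor:
    return Cos()(x)
```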
API Enhancements and PyTorch Alignment:
- Functions in `torch/core.py` (like `add`, `mul`, `pow`, etc.) now consistently use the `torch.INPUT_TYPE` alias from `torch/types.py` instead of the previous `types.INPUT_TYPE`.
- Added docstrings to `add`, `mul`, `neg`, `sub`, `div`, `pow`, `sin`, `cos`, `square`, `exp`, `tanh`, linking to the corresponding function in the official PyTorch documentation for clarity.
- Added `torch.tensor`, `torch.ones`, and `torch.ones_like` functions in `torch/core.py` to mimic the standard PyTorch tensor creation API.
- `as_tensor`: Ensures an input is a `Tensor`, simplifying function inputs.
- `set_logging_level`: Provides control over log verbosity using `loguru`.
- `using_config`: A general context manager to temporarily modify global configuration settings (like `Config.enable_backprop`).
- `no_grad`: A specific context manager using `using_config` to disable gradient calculations within its scope, mirroring `torch.no_grad()`; a sketch follows this list.
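The two context managers compose as described. A minimal sketch, assuming a `Config` class with a class-level `enable_backprop` flag (names follow the description above):

```python
import contextlib


class Config:
    enable_backprop = True  # global switch consulted when building the graph


@contextlib.contextmanager
def using_config(name: str, value):
    old_value = getattr(Config, name)
    setattr(Config, name, value)
    try:
        yield
    finally:
        setattr(Config, name, old_value)  # restored even if the body raises


def no_grad():
    return using_config("enable_backprop", False)


# Inside the block, Config.enable_backprop is False, so functions can skip
# recording the computational graph, mirroring torch.no_grad().
with no_grad():
    assert Config.enable_backprop is False
assert Config.enable_backprop is True
```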
Type System (`torch/types.py`):
- Defined `torch.int32`, `torch.int64`, `torch.float32`, `torch.float64` using `typing.NewType` wrapping the corresponding `np` types. This creates distinct types for the library while leveraging NumPy's underlying representations.
- Added mappings (`np2torch`, `torch2np`) and functions (`type_np2torch`, `type_torch2np`) to explicitly map and convert between `numpy.dtype` objects and the new `torch` dtype objects, ensuring correct type handling during Tensor creation and operations; see the sketch below.
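A sketch of how such a type table can be built; the exact contents of `torch/types.py` are inferred from the description, not copied from the diff.

```python
from typing import NewType

import numpy as np

# Distinct torch-side dtype names wrapping NumPy's representations.
int32 = NewType("int32", np.int32)
int64 = NewType("int64", np.int64)
float32 = NewType("float32", np.float32)
float64 = NewType("float64", np.float64)

# Explicit two-way mappings between numpy dtypes and torch dtype objects.
np2torch = {
    np.dtype(np.int32): int32,
    np.dtype(np.int64): int64,
    np.dtype(np.float32): float32,
    np.dtype(np.float64): float64,
}
torch2np = {torch_t: np_t for np_t, torch_t in np2torch.items()}


def type_np2torch(dtype):
    """Convert a numpy dtype to the library's dtype object."""
    return np2torch[np.dtype(dtype)]


def type_torch2np(dtype) -> np.dtype:
    """Convert a library dtype object back to a numpy dtype."""
    return torch2np[dtype]
```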