
Conversation

@Advaitgaur004
Contributor

This PR fixes a bug in ctensor so that backpropagation through sum and mean operations works correctly.

What Was Happening?

Previously, calling Tensor_backward() on a tensor that was the result of a Tensor_sum or Tensor_mean (with a dim argument) would crash. This was because the upstream gradient had a "squashed" shape (e.g., {2}) while the local gradient had the original, larger shape (e.g., {2, 3}), making them incompatible for the chain rule's multiplication step.
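
To make the mismatch concrete, here is a small standalone illustration in plain C (not ctensor code) for the {2, 3} → {2} example above:

```c
/* Illustration only: the element counts below show why a naive
 * element-wise chain-rule multiply fails after a dim-1 sum. */
#include <stdio.h>

int main(void) {
    int in_shape[]  = {2, 3};  /* input to the sum over dim 1             */
    int out_shape[] = {2};     /* output shape == upstream gradient shape */

    int local_grad_elems    = in_shape[0] * in_shape[1];  /* 6 */
    int upstream_grad_elems = out_shape[0];               /* 2 */

    /* 6 vs 2: the gradients cannot be combined element-wise until the
     * upstream gradient is reshaped to {2, 1} and broadcast. */
    printf("local grad: %d elements, upstream grad: %d elements\n",
           local_grad_elems, upstream_grad_elems);
    return 0;
}
```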

Fix:

  • The Tensor_backward function now dynamically compares the input and output shapes of a reduction operation to deduce which dimension was removed.
  • The "unsqueeze" trick: once the reduced dimension is found, the upstream gradient is unsqueezed (e.g., from {2} to {2, 1}) so it broadcasts against the input-shaped local gradient. Only the shape metadata changes, so this is a clean, zero-copy operation (see the sketch after this list).
  • Fixing the mean derivative: GradFn_mean was computing the gradient from the output tensor's size instead of the input's; it now correctly scales by 1/N, where N is the size of the reduced dimension.
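
A minimal sketch of the shape-comparison and unsqueeze steps in plain C. The helper names, the int-array shape representation, and the assumption that exactly one dimension was reduced are illustrative, not ctensor's actual internals:

```c
/* Hedged sketch, not the actual ctensor code. */

/* The output has exactly one fewer dimension than the input; walk both
 * shapes in lockstep and return the first position where they diverge. */
static int find_reduced_dim(const int *in_shape, int in_ndim,
                            const int *out_shape, int out_ndim) {
    int j = 0;
    for (int i = 0; i < in_ndim; i++) {
        if (j >= out_ndim || in_shape[i] != out_shape[j]) return i;
        j++;
    }
    return in_ndim - 1;  /* shapes matched all the way: trailing dim was reduced */
}

/* "Unsqueeze": re-insert a length-1 axis at the reduced position so the
 * upstream gradient broadcasts against the input-shaped local gradient.
 * Only the shape metadata changes; the data buffer is untouched. */
static void unsqueeze_shape(const int *shape, int ndim, int dim,
                            int *new_shape /* length ndim + 1 */) {
    for (int i = 0, k = 0; i < ndim + 1; i++) {
        new_shape[i] = (i == dim) ? 1 : shape[k++];
    }
}
```

Because only the shape array is rewritten, the gradient's data buffer is never copied, which is what makes the unsqueeze step zero-copy.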

As a result, the sum and mean tests, including chained operations with nn_linear and nn_relu, are now all passing.

A detail of the GradFn_mean fix: the else block (the fallback) runs when no dimension was removed, i.e. when input_ndim == output_ndim, which logically can only happen for a scalar input.

Example: scalar input
  • Input tensor: shape {1} (1 element)
  • Output tensor: shape {1} (1 element)
  • input_ndim (1) == output_ndim (1), so the fallback is taken and divisor = 1
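
A minimal sketch of this divisor logic in plain C; the function name and parameters are illustrative assumptions rather than GradFn_mean's real signature:

```c
/* Hedged sketch of the 1/N scaling described above. */
static float mean_grad_scale(const int *in_shape, int in_ndim,
                             int out_ndim, int reduced_dim) {
    int divisor;
    if (in_ndim > out_ndim) {
        /* A dimension was reduced away: N is the input's extent along it. */
        divisor = in_shape[reduced_dim];
    } else {
        /* Fallback: input_ndim == output_ndim, which in practice only
         * happens for a scalar input, so the mean divides by 1. */
        divisor = 1;
    }
    return 1.0f / (float)divisor;  /* each input element receives grad * 1/N */
}
```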
@PrimedErwin PrimedErwin merged commit cd2a310 into pocketpy:test Jul 5, 2025
5 checks passed
@Advaitgaur004 Advaitgaur004 deleted the backward-3 branch July 5, 2025 15:11
@Advaitgaur004 Advaitgaur004 changed the title Fix: Stabilize Backpropagation for Sum and Mean [Fix] : Stabilize Backpropagation for Sum and Mean Aug 20, 2025