Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: ci

on:
push:
branches: [main]
pull_request:
branches: [main]
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

env:
CARGO_TERM_COLOR: always
CARGO_INCREMENTAL: 0

jobs:
fmt:
name: cargo fmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- run: cargo fmt --all -- --check

clippy:
name: cargo clippy
runs-on: ubuntu-latest
needs: fmt
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
with:
components: clippy
- uses: Swatinem/rust-cache@v2
- run: cargo clippy --all-targets -- -D warnings

test:
name: cargo test
runs-on: ubuntu-latest
needs: fmt
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@stable
- uses: Swatinem/rust-cache@v2
- run: cargo test --all-targets
30 changes: 20 additions & 10 deletions src/dynamics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,11 @@ mod tests {
let dt = 0.1;
let (np, nv) = euler_step(&p, &v, &a, dt);
// Explicit Euler: position uses old velocity (0), so position unchanged.
approx(np.value.data[1], 0.0, "explicit Euler position(y) at t=0+dt");
approx(
np.value.data[1],
0.0,
"explicit Euler position(y) at t=0+dt",
);
// Velocity gains -9.81 * dt = -0.981.
approx(nv.value.data[1], -0.981, "vy");
}
Expand Down Expand Up @@ -538,9 +542,8 @@ mod tests {
let force = t.var(Tensor::from_data(vec![0.0, -g], vec![2]));
let torque = t.var(Tensor::from_data(vec![0.0], vec![]));
let dt = 0.1;
let (np, nv, _na, _no) = rigid_body_step_2d(
&pos, &vel, &angle, &omega, 1.0, 1.0, &force, &torque, dt,
);
let (np, nv, _na, _no) =
rigid_body_step_2d(&pos, &vel, &angle, &omega, 1.0, 1.0, &force, &torque, dt);
approx(nv.value.data[1], -g * dt, "vy_new");
approx(np.value.data[1], -g * dt * dt, "y_new");
}
Expand All @@ -557,7 +560,15 @@ mod tests {
let inv_inertia = 0.5; // I = 2.0
let dt = 0.1;
let (_np, _nv, na, no) = rigid_body_step_2d(
&pos, &vel, &angle, &omega, 1.0, inv_inertia, &force, &torque, dt,
&pos,
&vel,
&angle,
&omega,
1.0,
inv_inertia,
&force,
&torque,
dt,
);
// omega_new = 0 + tau * (1/I) * dt = 5 * 0.5 * 0.1 = 0.25
approx(no.value.data[0], 0.25, "omega_new");
Expand Down Expand Up @@ -591,9 +602,9 @@ mod tests {
for &x in &grads[vel.idx()].data {
approx(x, dt, "dL/dvel");
}
let expected_dF = dt * dt * inv_mass;
let expected_d_f = dt * dt * inv_mass;
for &x in &grads[force.idx()].data {
approx(x, expected_dF, "dL/dF");
approx(x, expected_d_f, "dL/dF");
}
}

Expand All @@ -610,9 +621,8 @@ mod tests {
let torque = t.var(Tensor::from_data(vec![0.0], vec![]));
let dt = 0.1;
for _ in 0..5 {
let (np, nv, na, no) = rigid_body_step_2d(
&pos, &vel, &angle, &omega, 1.0, 1.0, &force, &torque, dt,
);
let (np, nv, na, no) =
rigid_body_step_2d(&pos, &vel, &angle, &omega, 1.0, 1.0, &force, &torque, dt);
pos = np;
vel = nv;
angle = na;
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ pub use dynamics::{
semi_implicit_euler_step,
};
pub use tape::{Tape, Var};
pub use tensor::{Tensor, TensorTape, TVar};
pub use tensor::{TVar, Tensor, TensorTape};
14 changes: 2 additions & 12 deletions src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,10 +427,7 @@ impl TensorTape {
grads[lhs].add_in_place(&dlhs);
grads[rhs].add_in_place(&drhs);
}
Op::Sum {
input,
input_shape,
} => {
Op::Sum { input, input_shape } => {
// g is scalar; broadcast to input_shape.
let gs = g.data[0];
let dinput = Tensor::full(input_shape, gs);
Expand Down Expand Up @@ -638,14 +635,7 @@ mod tests {
fn approx_tensor(a: &Tensor, b: &Tensor, eps: f64, ctx: &str) {
assert_eq!(a.shape, b.shape, "{}: shape mismatch", ctx);
for (i, (x, y)) in a.data.iter().zip(b.data.iter()).enumerate() {
assert!(
(x - y).abs() < eps,
"{}: data[{}] = {} vs {}",
ctx,
i,
x,
y
);
assert!((x - y).abs() < eps, "{}: data[{}] = {} vs {}", ctx, i, x, y);
}
}

Expand Down
Loading