Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 29 additions & 20 deletions algorithms/linfa-linear/src/ols.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
//! Ordinary Least Squares
#![allow(non_snake_case)]
use crate::error::{LinearError, Result};
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2};
use ndarray_linalg::{Lapack, Scalar, Solve};
use ndarray::{Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2};
use ndarray_linalg::{Lapack, LeastSquaresSvdInto, Scalar};
use ndarray_stats::SummaryStatisticsExt;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -154,13 +154,17 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets<Elem = F>> Fit<ArrayBase<D, Ix2>,
let y_offset: F = y.mean().ok_or_else(|| LinearError::NotEnoughTargets)?;
let y_centered: Array1<F> = &y - y_offset;
let params: Array1<F> =
compute_params(&X_centered, &y_centered, self.options.should_normalize())?;
compute_params(X_centered, y_centered, self.options.should_normalize())?;
let intercept: F = y_offset - X_offset.dot(&params);
Ok(FittedLinearRegression { intercept, params })
} else {
// `LeastSquaresSvdInto` needs a mutable reference to the data and `dataset` is taken
// by reference. Therefore copy the problem matrix and target vector.
let (X, y) = (X.to_owned(), y.to_owned());

Ok(FittedLinearRegression {
intercept: F::cast(0),
params: solve_normal_equation(X, &y)?,
params: solve_least_squares(X, y)?,
})
}
}
Expand All @@ -169,38 +173,43 @@ impl<F: Float, D: Data<Elem = F>, T: AsTargets<Elem = F>> Fit<ArrayBase<D, Ix2>,
/// Compute the parameters for the linear regression model, with
/// or without normalization.
///
/// With `normalize` set, every column of `X` is divided by its second
/// central moment before solving, and the solved parameters are divided by
/// the same per-column scale afterwards, so the returned parameters refer
/// to the unscaled columns. Without it, `X` and `y` go to the solver as-is.
///
/// Errors returned by the solver are propagated to the caller.
///
/// NOTE(review): this span reads as a rendered diff. Where two adjacent
/// lines are near-duplicates, the first appears to be the removed version
/// (borrowed inputs, `solve_normal_equation`) and the second the added
/// version (owned inputs, `solve_least_squares`) -- confirm against the
/// actual file before relying on either.
fn compute_params<F, B, C>(
X: &ArrayBase<B, Ix2>,
y: &ArrayBase<C, Ix1>,
X: ArrayBase<B, Ix2>,
y: ArrayBase<C, Ix1>,
normalize: bool,
) -> Result<Array1<F>>
where
F: Float,
B: Data<Elem = F>,
C: Data<Elem = F>,
B: DataMut<Elem = F>,
C: DataMut<Elem = F>,
{
if normalize {
// Per-column scale: the second central moment of each lane along Axis(0).
// The `unwrap` panics if `central_moment` returns an error.
let scale: Array1<F> = X.map_axis(Axis(0), |column| column.central_moment(2).unwrap());
let X: Array2<F> = X / &scale;
let mut params: Array1<F> = solve_normal_equation(&X, y)?;
let X: Array2<F> = &X / &scale;
let mut params: Array1<F> = solve_least_squares(X, y)?;
// Undo the column scaling so the parameters apply to the original `X`.
params /= &scale;
Ok(params)
} else {
solve_normal_equation(X, y)
solve_least_squares(X, y)
}
}

/// Solve the overconstrained model X b = y by solving X^T X b = X^T y.
/// This is (mathematically, not numerically) equivalent to computing
/// the solution with the Moore-Penrose pseudo-inverse.
fn solve_normal_equation<F, B, C>(X: &ArrayBase<B, Ix2>, y: &ArrayBase<C, Ix1>) -> Result<Array1<F>>
/// Find the `b` that minimizes the 2-norm of `X b - y`
/// by using the least-squares solver from ndarray-linalg.
///
/// Returns the `solution` field of the solver result; a solver error is
/// converted with `into()` and returned as `Err`.
///
/// NOTE(review): `X` and `y` are handed to the solver as mutable views, so
/// their contents are presumably overwritten -- confirm in the
/// ndarray-linalg documentation.
///
/// NOTE(review): this span reads as a rendered diff. The
/// `solve_normal_equation` header, the `Data` bounds and the three
/// `rhs` / `linear_operator` / `solve_into` lines appear to be the removed
/// version; the rest is the added version -- confirm against the actual file.
fn solve_least_squares<F, B, C>(
mut X: ArrayBase<B, Ix2>,
mut y: ArrayBase<C, Ix1>,
) -> Result<Array1<F>>
where
F: Float,
B: Data<Elem = F>,
C: Data<Elem = F>,
B: DataMut<Elem = F>,
C: DataMut<Elem = F>,
{
// Normal-equation form: right-hand side X^T y and operator X^T X.
let rhs = X.t().dot(y);
let linear_operator = X.t().dot(X);
linear_operator.solve_into(rhs).map_err(|err| err.into())
// Ensure that B = C: `B` and `C` may be different storage types, but
// mutable views of both share one storage type, which the solver call
// below presumably requires of its two operands.
let (X, y) = (X.view_mut(), y.view_mut());

X.least_squares_into(y)
.map(|x| x.solution)
.map_err(|err| err.into())
}

/// View the fitted parameters and make predictions with a fitted
Expand Down