Skip to content

ENH add ADMM solver #76

@mathurinm

Description

@mathurinm

@josephsalmon here's a basic ADMM solver for the Lasso:

import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import norm
from scipy import io
from scipy.linalg import cho_factor, cho_solve
from sklearn.preprocessing import MinMaxScaler

from skglm.utils import ST_vec


def primal_lasso(X, y, alpha, w):
    """Evaluate the Lasso primal objective at ``w``.

    The value returned is ``0.5 * ||X w - y||_2^2 + alpha * ||w||_1``.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    alpha : float
        Regularization parameter weighting the L1 penalty.

    w : array, shape (n_features,)
        Coefficient vector at which the objective is evaluated.

    Returns
    -------
    p_obj : float
        The primal objective.
    """
    residual = X @ w - y
    datafit = 0.5 * (residual @ residual)
    penalty = alpha * norm(w, ord=1)
    return datafit + penalty


def extrapolate(last_K_u, K, extrap_type="LPinf", s=100):
    """Construct an extrapolation direction from past iterates.

    The successive differences of the stored iterates are fitted with a
    linear recurrence (the newest difference is regressed on the ``K - 2``
    previous ones). The fitted coefficients define a companion matrix ``C``
    and the returned direction is ``V_k @ S[:, -1]`` where ``S`` is a
    geometric series in ``C``.

    Parameters
    ----------
    last_K_u : array, shape (n_features, K)
        Array of past iterates, oldest first along the last axis.

    K : int
        Number of past iterates to use for extrapolation. Must match the
        number of columns of ``last_K_u``.

    extrap_type : str, (`LP` | `LPinf`), optional
        Extrapolation type. ``LPinf`` uses the full series
        ``C (I - C)^{-1}``; ``LP`` uses the truncated series
        ``(C - C^s) (I - C)^{-1}``.

    s : int
        Power of ``C`` at which the series is truncated (used if
        extrap_type=`LP`).

    Returns
    -------
    e : array, shape (n_features,)
        Extrapolation direction. It is the zero vector when the spectral
        radius of ``C`` is not below 1, and the latest difference of iterates
        when the regression system is singular.

    Raises
    ------
    ValueError
        If ``extrap_type`` is neither ``"LP"`` nor ``"LPinf"``.
    """
    if extrap_type not in ("LP", "LPinf"):
        # Previously any unknown string silently fell through to the "LP" branch.
        raise ValueError(
            f"Unknown extrap_type {extrap_type!r}, expected 'LP' or 'LPinf'."
        )

    q = K - 2

    # Differences between consecutive iterates, shape (n_features, K - 1).
    V = np.diff(last_K_u, 1, axis=-1)

    V_k = V[:, 1:]      # the q most recent differences
    v_k = V[:, -1]      # the latest difference
    V_prev = V[:, :-1]  # the q differences preceding v_k

    # Normal equations of the regression v_k ~ V_prev @ c.
    VtV = V_prev.T @ V_prev
    try:
        c = np.linalg.solve(VtV, V_prev.T @ v_k)
    except np.linalg.LinAlgError:
        # Singular system (e.g. stalled iterates): fall back to the last step.
        return v_k

    # Companion (iteration) matrix: ones on the sub-diagonal, c as last column,
    # so that V_k = V_prev @ C.
    C = np.diag(np.ones(q - 1), -1)
    C[:, -1] = c

    # Largest eigenvalue modulus of C, i.e. its spectral radius.
    rho = norm(np.linalg.eigvals(C), ord=np.inf)
    if not rho < 1:
        # The geometric series does not converge: do not extrapolate.
        return V_k @ np.zeros(q)

    numerator = C if extrap_type == "LPinf" else C - np.linalg.matrix_power(C, s)
    # Right-division numerator @ inv(I - C), done through a transposed solve.
    S = np.linalg.solve((np.eye(q) - C).T, numerator.T).T

    return V_k @ S[:, -1]


def admm(X, y, alpha, gamma=2., max_iter=1000, tol=1e-5, check_gap_freq=50, a=0, K=6,
         use_accel=True, verbose=True):
    """Run Alternating Direction Method of Multipliers scheme for the Lasso.

    Parameters
    ----------
    X : array, shape (n_samples, n_features)
        Design matrix.

    y : array, shape (n_samples,)
        Target vector.

    alpha : float
        Regularization parameter.

    gamma : float
        Augmented Lagrangian parameter.

    max_iter : int
        Maximum number of iterations.

    tol : float
        Tolerance on the residual ``||u_k - u_{k-1}||``.

    check_gap_freq : int
        Frequency (in iterations) for checking convergence and, if
        ``verbose``, printing the residual and the primal objective.

    a : float
        Inertia parameter.

    K : int
        Number of past iterates to compute extrapolated point. An
        extrapolation is attempted every ``K + 1`` iterations.

    use_accel : bool
        Use extrapolation.

    verbose : bool
        Verbosity.

    Returns
    -------
    w : array, shape (n_features,)
        Coefficient vector.

    residuals : list of float
        Residual ``||u_k - u_{k-1}||`` at each iteration, measured before any
        extrapolation is applied.

    iterates : list of array, each of shape (n_features,)
        Coefficient vector ``w`` produced at each iteration.
    """
    n_features = X.shape[1]
    residuals = []
    iterates = []

    # Acceleration variables: sliding window of the K latest u, oldest first.
    last_K_u = np.zeros((n_features, K))

    # Optimization variables
    w = np.ones(n_features)  # Primal iterates
    psi = np.ones(n_features)  # Dual iterates
    u = psi + gamma * w
    u_bar = u

    # Pre-compute useful quantities. The system matrix is factored once so
    # that each iteration only costs two triangular solves.
    XtX_scaled = X.T @ X / gamma
    Xty_scaled = X.T @ y / gamma
    chol = cho_factor(XtX_scaled + np.eye(n_features), lower=True)

    for n_iter in range(1, max_iter + 1):
        u_prev = u.copy()

        # Proximal step for datafit
        z = cho_solve(chol, Xty_scaled + u_bar / gamma)
        psi = u_bar - gamma * z  # Dual update

        # Proximal step for pen
        w = ST_vec((u_bar - 2 * psi) / gamma, alpha / gamma)
        u = psi + gamma * w
        iterates.append(w)

        # Inertial step
        v = u - u_prev
        u_bar = u + a * v

        # Drop the oldest column, append the newest iterate.
        last_K_u = np.column_stack((last_K_u[:, 1:], u))

        if use_accel and n_iter % (K + 1) == 0:
            e = extrapolate(last_K_u, K)
            with np.errstate(divide="ignore"):
                # Removes warning for zero division when norm(e) is zero
                # Parameter safeguard - avoid numerical errors
                coeff = np.minimum(1., 1e5 / (n_iter**1.1 * norm(e)))
            u = u + coeff * e
            # The extrapolated point replaces the inertial one.
            u_bar = u

        res = norm(v)
        residuals.append(res)

        if n_iter % check_gap_freq == 0:
            p_obj = primal_lasso(X, y, alpha, w)
            if verbose:
                print(f"iter {n_iter} :: residual {res:.5f} :: obj {p_obj:.4f}")

            if res < tol:
                break
    return w, residuals, iterates


if __name__ == "__main__":
    # Demo: compare accelerated and plain ADMM on the covtype sample.
    # Matrices can be downloaded at:
    # https://github.com/jliang993/A3DMM/tree/master/codes/data
    X = io.loadmat('covtype_sample.mat')["h"]
    y = io.loadmat('covtype_label.mat')["l"]

    y = np.ravel(y)

    # Densify, then rescale every feature to [-1, 1].
    scaler = MinMaxScaler(feature_range=(-1, 1))
    X = X.toarray()
    X = scaler.fit_transform(X)

    alpha = 1
    w, residuals, iterates = admm(X, y, alpha, tol=1e-6, use_accel=True,
                                  max_iter=50_000, check_gap_freq=10)
    print("#" * 25)
    w_no_accel, residuals_no_acc, iterates_no_acc = admm(X, y, alpha, tol=1e-6,
                                                         use_accel=False,
                                                         max_iter=50_000,
                                                         check_gap_freq=100)

    # Both variants must reach the same solution.
    np.testing.assert_allclose(w, w_no_accel, rtol=1e-4)

    # Plotting
    # The last stored iterate is the returned solution itself, so its distance
    # is exactly 0 and log(0) = -inf; silence the expected divide warning.
    with np.errstate(divide="ignore"):
        norms_accel = [np.log(norm(wc - w)) for wc in iterates]
        norms_no_accel = [np.log(norm(wc - w_no_accel)) for wc in iterates_no_acc]
    plt.plot(norms_accel, label="Accelerated")
    plt.plot(norms_no_accel, label="No accel")
    plt.legend()
    # The plotted values are logarithms of the distances.
    plt.title("ADMM - log ||x - x^*||")
    plt.show()

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions