Skip to content

API: linalg

ridge_fit_closed_form(X, Y, lambda_reg=0.001)

Closed-form ridge solution W = (X^T X + λI)^{-1} X^T Y.

Shapes: X is (n, d) and Y is (n, k) or (n,). Returns W with shape (d, k), or (d,) when Y is 1-D.

Source code in transformer_instant/linalg.py (lines 40–64)
def ridge_fit_closed_form(X: Array, Y: Array, lambda_reg: float = 1e-3) -> Array:
    """Closed-form ridge solution W = (X^T X + λI)^{-1} X^T Y.

    X: (n, d), Y: (n, k) or (n,)
    Returns W: (d, k) or (d,)

    ``lambda_reg`` must be a finite, non-negative float. ``0.0`` solves the
    plain normal equations X^T X W = X^T Y with no regularization.

    Raises ValueError if ``lambda_reg`` is negative, NaN or infinite; if ``X``
    is not 2-D; or if ``X`` and a 1-D/2-D ``Y`` disagree on the number of rows.
    Errors raised by ``solve_spd`` (e.g. ``numpy.linalg.LinAlgError`` for a
    singular system when ``lambda_reg == 0``) propagate unchanged.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    lambda_reg = float(lambda_reg)
    # ``nan < 0.0`` is False, so NaN would slip past a plain sign check and
    # silently disable regularization; reject non-finite values explicitly.
    if not np.isfinite(lambda_reg) or lambda_reg < 0.0:
        raise ValueError("lambda_reg must be non-negative and finite")
    if X.ndim != 2:
        raise ValueError(f"X must be a 2-D array of shape (n, d), got shape {X.shape}")
    if Y.ndim == 1:
        # Promote a single target vector to a one-column matrix; undone below.
        Y = Y.reshape(-1, 1)
        squeeze = True
    else:
        squeeze = False
    n, d = X.shape
    if Y.ndim == 2 and Y.shape[0] != n:
        raise ValueError(f"X has {n} rows but Y has {Y.shape[0]}")
    xtx = X.T @ X
    if lambda_reg > 0.0:
        # xtx is a freshly allocated array, so add λ to its diagonal in place
        # rather than materializing a dense d x d identity matrix.
        xtx[np.diag_indices(d)] += lambda_reg
    xty = X.T @ Y
    W = solve_spd(xtx, xty)
    if squeeze and W.ndim == 2 and W.shape[1] == 1:
        return W.ravel()
    return W

solve_spd(A, B)

Solve A X = B for a symmetric positive-definite (SPD) matrix A using a Cholesky factorization, falling back to more general solvers if the factorization fails.

Source code in transformer_instant/linalg.py (lines 19–37)
def solve_spd(A: Array, B: Array) -> Array:
    """Solve A X = B for SPD A using Cholesky with robust fallbacks.

    Strategies are tried in order:

    1. SciPy ``_cho_factor`` / ``_cho_solve`` (only when ``_HAVE_SCIPY``);
    2. NumPy ``np.linalg.cholesky`` followed by two linear solves;
    3. a general ``np.linalg.solve`` when NumPy's Cholesky factorization
       fails (e.g. ``A`` is not positive definite).

    A and B are cast to float64; the result is a float64 array.

    Only linear-algebra failures (``LinAlgError``) and SciPy argument/shape
    errors (``ValueError``) trigger a fallback; any other exception
    propagates. If ``A`` is singular, the final ``np.linalg.solve`` raises
    ``numpy.linalg.LinAlgError``.
    """
    A = np.asarray(A, dtype=np.float64)
    B = np.asarray(B, dtype=np.float64)
    if _HAVE_SCIPY:
        try:
            c, lower = _cho_factor(A, overwrite_a=False, check_finite=False)
            return cast(
                Array, np.asarray(_cho_solve((c, lower), B, check_finite=False), dtype=np.float64)
            )
        except (np.linalg.LinAlgError, ValueError):
            # Not positive definite or rejected by SciPy's checks: try NumPy.
            pass
    try:
        L = np.linalg.cholesky(A)
    except np.linalg.LinAlgError:
        # Cholesky failed (A not positive definite, or not a square matrix):
        # fall back to a general LU-based solve.
        return cast(Array, np.asarray(np.linalg.solve(A, B), dtype=np.float64))
    # A = L L^T, so solve L Y = B, then L^T X = Y.
    Y = np.linalg.solve(L, B)
    X = np.linalg.solve(L.T, Y)
    return cast(Array, np.asarray(X, dtype=np.float64))
.. automodule:: transformer_instant.linalg
    :members:
    :undoc-members:
    :show-inheritance: