Skip to content

API: transformer_instant

ExactStep Transformer Toolkit (CPU)

Hook-based, model-agnostic utilities to integrate closed-form solves and one-shot updates with Transformer models on CPU. Includes:

  • Frozen features → closed-form head (ridge / KRR / ELM)
  • One-shot block collapse via linear solves + balanced SVD factorization
  • One-shot LoRA via SVD (rank-r update)
  • Explicit attention path wiring Q/K/V/O and solving W_O from Z = (A V)

This package reuses SPD Cholesky-backed solves with stable whitening by default.

BlockCollapseTrainer dataclass

Source code in transformer_instant/pipeline.py
61
62
63
64
65
66
67
@dataclass
class BlockCollapseTrainer:
    """Solve linear replacements for transformer blocks via closed-form ridge."""

    # Ridge regularization strength passed to every linear solve.
    lambda_reg: float = 1e-3

    def solve_linear_block(self, X: Array, Y: Array) -> Array:
        """Solve W from X W ≈ Y via ridge."""
        return ridge_fit_closed_form(X, Y, lambda_reg=self.lambda_reg)

solve_linear_block(X, Y)

Solve W from X W ≈ Y via ridge.

Source code in transformer_instant/pipeline.py
65
66
67
def solve_linear_block(self, X: Array, Y: Array) -> Array:
    """Solve W from X W ≈ Y via ridge.

    Delegates to the closed-form ridge solver, using this trainer's
    ``lambda_reg`` as the regularization strength.
    """
    return ridge_fit_closed_form(X, Y, lambda_reg=self.lambda_reg)

apply_lora_update_inplace(W, A, B)

Apply W ← W + A @ B in-place on W if shapes are compatible.

Source code in transformer_instant/lora_svd.py
29
30
31
32
33
def apply_lora_update_inplace(W: np.ndarray, A: np.ndarray, B: np.ndarray) -> None:
    """Apply W ← W + A @ B in-place on W.

    A no-op when either factor is empty (a rank-0 LoRA update). Raises a
    ValueError with an explicit message when the factor shapes do not compose
    to W's shape, instead of surfacing an opaque numpy broadcasting error.

    Shapes: W (d_out, d_in), A (d_out, r), B (r, d_in).
    """
    if A.size == 0 or B.size == 0:
        # Rank-0 update: nothing to add.
        return
    if A.shape[1] != B.shape[0] or (A.shape[0], B.shape[1]) != W.shape:
        raise ValueError(
            f"Incompatible LoRA shapes: W{W.shape}, A{A.shape}, B{B.shape}"
        )
    W += A @ B

compute_qkv_attn(X, W_Q, W_K, W_V, *, mask=None)

Compute Q, K, V, A=softmax(QK^T/√d), Z=A V for a single-head attention.

Shapes
  • X: (n, d_model)
  • W_Q: (d_model, d_k), W_K: (d_model, d_k), W_V: (d_model, d_v)
  • Returns: Q(n,d_k), K(n,d_k), V(n,d_v), A(n,n), Z(n,d_v)
Source code in transformer_instant/attention_solver.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def compute_qkv_attn(
    X: np.ndarray,
    W_Q: np.ndarray,
    W_K: np.ndarray,
    W_V: np.ndarray,
    *,
    mask: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Single-head attention forward pass.

    Projects X into queries/keys/values, forms the scaled dot-product
    attention matrix A = softmax(Q K^T / √d_k), and returns the attended
    values Z = A V.

    Shapes:
      - X: (n, d_model)
      - W_Q: (d_model, d_k), W_K: (d_model, d_k), W_V: (d_model, d_v)
      - Returns: Q(n,d_k), K(n,d_k), V(n,d_v), A(n,n), Z(n,d_v)
    """
    inputs = np.asarray(X, dtype=np.float64)
    Q = inputs @ W_Q
    K = inputs @ W_K
    V = inputs @ W_V
    # Guard against a zero-width head when computing the 1/√d_k scale.
    inv_scale = 1.0 / np.sqrt(max(Q.shape[1], 1))
    logits = (Q @ K.T) * inv_scale
    if mask is not None:
        # Positions where mask is False get a huge negative logit so their
        # softmax weight is effectively zero.
        logits = np.where(mask, logits, -1e30)
    A = _softmax(logits, axis=-1)
    return Q, K, V, A, A @ V

generate_long_bracket_depth_dataset(num_sequences=512, min_len=64, max_len=512, max_depth=16, seed=1)

Bracket-depth task: stream of brackets and distractors; target depth.

Vocabulary
  • 0: EOS / padding
  • 1: '('
  • 2: ')'
  • 3..K: distractors (letters)

We generate sequences with balanced parentheses possibly interleaved with distractors. The target Y[t] is the nesting depth after processing token t.

Source code in transformer_instant/datasets/long_datasets.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
def generate_long_bracket_depth_dataset(
    num_sequences: int = 512,
    min_len: int = 64,
    max_len: int = 512,
    max_depth: int = 16,
    seed: int = 1,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Bracket-depth task: a stream of brackets and distractors; target is depth.

    Vocabulary:
      - 0: EOS / padding
      - 1: '('
      - 2: ')'
      - 3..K: distractors (letters)

    Sequences interleave (possibly unbalanced) parentheses with distractor
    tokens; the target Y[t] is the nesting depth after processing token t.
    """
    rng = np.random.default_rng(seed)
    EOS, OPEN, CLOSE = 0, 1, 2
    distractors = list(range(3, 26))  # 23 distinct distractor token ids

    lens = rng.integers(low=min_len, high=max_len + 1, size=num_sequences)
    width = int(lens.max())
    X = np.full((num_sequences, width), fill_value=EOS, dtype=np.int64)
    Y = np.zeros((num_sequences, width), dtype=np.int64)

    for i, seq_len in enumerate(lens):
        depth = 0
        tokens: list[int] = []
        targets: list[int] = []
        for _ in range(int(seq_len)):
            # Open with p≈0.4 (depth permitting), close with p≈0.3 (if open),
            # otherwise emit a distractor.
            draw = rng.random()
            if draw < 0.4 and depth < max_depth:
                depth += 1
                tokens.append(OPEN)
            elif draw < 0.7 and depth > 0:
                depth -= 1
                tokens.append(CLOSE)
            else:
                tokens.append(int(rng.choice(distractors)))
            targets.append(depth)
        X[i, :seq_len] = _to_numpy(tokens)
        Y[i, :seq_len] = _to_numpy(targets)

    meta = {
        "task": "bracket_depth",
        "tokens": {"eos": EOS, "left": OPEN, "right": CLOSE, "distractors": distractors},
        "vocab_size": max(distractors) + 1,
        "lengths": lens,
        "max_depth": max_depth,
    }
    return X, Y, meta

generate_long_copy_dataset(num_sequences=512, min_len=64, max_len=512, vocab_size=32, seed=0)

Copy task: targets are identical to the inputs, i.e. Y[t] = X[t].

Returns X (N, L) int64 tokens and Y (N, L) int64 targets where Y[t] = X[t]. Positions beyond each sequence length are padded with EOS=0, and drawn tokens are >=1 when vocab_size>1 so they never collide with the padding value.

Source code in transformer_instant/datasets/long_datasets.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def generate_long_copy_dataset(
    num_sequences: int = 512,
    min_len: int = 64,
    max_len: int = 512,
    vocab_size: int = 32,
    seed: int = 0,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Copy task: targets are identical to the inputs, Y[t] = X[t].

    Returns X (N, L) int64 tokens and Y (N, L) int64 targets. Positions past
    each sequence length are padded with EOS=0; drawn tokens are >= 1 when
    vocab_size > 1 so they never collide with the padding value.
    """
    rng = np.random.default_rng(seed)
    eos = 0
    lens = rng.integers(low=min_len, high=max_len + 1, size=num_sequences)
    width = int(lens.max())
    X = np.full((num_sequences, width), fill_value=eos, dtype=np.int64)
    Y = np.full((num_sequences, width), fill_value=eos, dtype=np.int64)
    for i, seq_len in enumerate(lens):
        lo = 1 if vocab_size > 1 else 0
        draw = rng.integers(low=lo, high=vocab_size, size=int(seq_len), dtype=np.int64)
        X[i, :seq_len] = draw
        Y[i, :seq_len] = draw
    meta = {"task": "copy", "eos": eos, "vocab_size": vocab_size, "lengths": lens}
    return X, Y, meta

generate_long_reverse_dataset(num_sequences=512, min_len=64, max_len=512, vocab_size=32, seed=2)

Reverse task: targets are the input sequence reversed along time.

Returns X (N, L) int64 tokens and Y (N, L) int64 targets where Y[i, :L] = reverse(X[i, :L]). Positions beyond length are EOS=0.

Source code in transformer_instant/datasets/long_datasets.py
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
def generate_long_reverse_dataset(
    num_sequences: int = 512,
    min_len: int = 64,
    max_len: int = 512,
    vocab_size: int = 32,
    seed: int = 2,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Reverse task: each target row is its input row reversed in time.

    Returns X (N, L) int64 tokens and Y (N, L) int64 targets with
    Y[i, :L] = X[i, :L][::-1]; positions beyond each length are EOS=0.
    """
    rng = np.random.default_rng(seed)
    eos = 0
    lens = rng.integers(low=min_len, high=max_len + 1, size=num_sequences)
    width = int(lens.max())
    X = np.full((num_sequences, width), fill_value=eos, dtype=np.int64)
    Y = np.full((num_sequences, width), fill_value=eos, dtype=np.int64)
    for i, seq_len in enumerate(lens):
        lo = 1 if vocab_size > 1 else 0
        draw = rng.integers(low=lo, high=vocab_size, size=int(seq_len), dtype=np.int64)
        X[i, :seq_len] = draw
        Y[i, :seq_len] = draw[::-1]
    meta = {"task": "reverse", "eos": eos, "vocab_size": vocab_size, "lengths": lens}
    return X, Y, meta

generate_running_sum_dataset(num_sequences=512, min_len=64, max_len=512, vocab_size=10, modulo=None, seed=3)

Running-sum task: tokens are digits in [0, vocab_size-1], targets are running sum.

If modulo is provided, targets are running sum modulo modulo.

Source code in transformer_instant/datasets/long_datasets.py
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def generate_running_sum_dataset(
    num_sequences: int = 512,
    min_len: int = 64,
    max_len: int = 512,
    vocab_size: int = 10,
    modulo: int | None = None,
    seed: int = 3,
) -> tuple[np.ndarray, np.ndarray, dict]:
    """Running-sum task: tokens are digits in [0, vocab_size-1], targets are running sum.

    If modulo is provided, targets are running sum modulo `modulo`.
    """
    rng = np.random.default_rng(seed)
    eos = 0
    vx = max(2, vocab_size)
    lens = rng.integers(low=min_len, high=max_len + 1, size=num_sequences)
    L_max = int(np.max(lens))
    X = np.full((num_sequences, L_max), fill_value=eos, dtype=np.int64)
    Y = np.zeros((num_sequences, L_max), dtype=np.int64)
    for i, L in enumerate(lens):
        seq = rng.integers(low=0, high=vx, size=int(L), dtype=np.int64)
        X[i, :L] = seq
        csum = np.cumsum(seq.astype(np.int64))
        if modulo is not None:
            csum = np.remainder(csum, int(modulo))
        Y[i, :L] = csum
    meta = {"task": "running_sum", "eos": eos, "vocab_size": vx, "lengths": lens, "modulo": modulo}
    return X, Y, meta

krr_fit(X, Y, lambda_reg=0.001, kernel='rbf', length_scale=1.0, variance=1.0)

Fit Kernel Ridge on X,Y. Returns dict with alpha and config.

For multi-output Y (n,k), alpha has shape (n,k).

Source code in transformer_instant/utils.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
def krr_fit(
    X: Array,
    Y: Array,
    lambda_reg: float = 1e-3,
    kernel: str = "rbf",
    length_scale: float = 1.0,
    variance: float = 1.0,
) -> dict[str, Array | float | str]:
    """Fit kernel ridge regression on (X, Y) and return the fitted state.

    The returned dict carries the dual coefficients ``alpha`` together with
    the configuration needed to evaluate the predictor later. For
    multi-output Y of shape (n, k), ``alpha`` has shape (n, k).
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    lambda_reg = float(lambda_reg)
    if lambda_reg < 0.0:
        raise ValueError("lambda_reg must be non-negative")
    if kernel != "rbf":
        raise ValueError("Only 'rbf' kernel supported in krr_fit")
    # Symmetrize the Gram matrix to wash out floating-point asymmetry
    # before the SPD solve.
    gram = rbf_cross_kernel(X, X, length_scale=length_scale, variance=variance)
    gram = 0.5 * (gram + gram.T)
    # Regularizers at or below 1e-8 are treated as exactly zero.
    lambda_eff = lambda_reg if lambda_reg > 1e-8 else 0.0
    system = gram + lambda_eff * np.eye(gram.shape[0], dtype=gram.dtype)
    alpha = solve_spd(system, Y)
    return {
        "X_train": X,
        "alpha": alpha,
        "kernel": kernel,
        "length_scale": float(length_scale),
        "variance": float(variance),
        "lambda_reg": lambda_reg,
    }

lora_from_residual_svd(residual, rank, *, scaling=1.0)

Return LoRA factors (A, B) s.t. A @ B ≈ residual (rank-r SVD approx).

Uses residual ≈ U_r Σ_r V_r^T, set A = U_r Σ_r^{1/2}, B = Σ_r^{1/2} V_r^T so that A @ B = U_r Σ_r V_r^T. Optional scalar scaling applied to both.

Source code in transformer_instant/lora_svd.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
def lora_from_residual_svd(
    residual: np.ndarray,
    rank: int,
    *,
    scaling: float = 1.0,
) -> tuple[np.ndarray, np.ndarray]:
    """Factor ``residual`` into LoRA matrices (A, B) with A @ B ≈ residual.

    The rank-r SVD residual ≈ U_r Σ_r V_r^T is split symmetrically:
    A = U_r Σ_r^{1/2} and B = Σ_r^{1/2} V_r^T, each multiplied by the
    optional scalar ``scaling``.
    """
    U, s, Vt = svd_truncate(residual, rank)
    if s.size == 0:
        # Degenerate rank: empty factors whose outer dims still match residual.
        rows, cols = residual.shape
        return np.zeros((rows, 0), dtype=np.float64), np.zeros((0, cols), dtype=np.float64)
    root = np.sqrt(s)
    A = (U * root[None, :]) * float(scaling)
    B = (root[:, None] * Vt) * float(scaling)
    return A, B

ridge_fit_closed_form(X, Y, lambda_reg=0.001)

Closed-form ridge solution W = (X^T X + λI)^{-1} X^T Y.

X: (n, d), Y: (n, k) or (n,) Returns W: (d, k) or (d,)

Source code in transformer_instant/linalg.py
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def ridge_fit_closed_form(X: Array, Y: Array, lambda_reg: float = 1e-3) -> Array:
    """Closed-form ridge solution W = (X^T X + λI)^{-1} X^T Y.

    Args:
        X: (n, d) design matrix.
        Y: (n, k) targets, or (n,) for a single output.
        lambda_reg: non-negative ridge strength; 0 gives ordinary least squares.

    Returns:
        W of shape (d, k), or (d,) when Y was 1-D.

    Raises:
        ValueError: if lambda_reg is negative.
    """
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    lambda_reg = float(lambda_reg)
    if lambda_reg < 0.0:
        raise ValueError("lambda_reg must be non-negative")
    # Promote 1-D targets to a column so the solve is uniformly 2-D;
    # remember to squeeze back on return.
    squeeze = Y.ndim == 1
    if squeeze:
        Y = Y.reshape(-1, 1)
    d = X.shape[1]
    xtx = X.T @ X
    if lambda_reg > 0.0:
        xtx = xtx + lambda_reg * np.eye(d, dtype=xtx.dtype)
    xty = X.T @ Y
    W = solve_spd(xtx, xty)
    if squeeze and W.ndim == 2 and W.shape[1] == 1:
        return W.ravel()
    return W

solve_attention_output_projection(X, W_Q, W_K, W_V, H_attn_target, *, lambda_reg=0.001, mask=None)

Return (Q,K,V,A,W_O) with W_O solving Z W_O ≈ H_attn_target.

This keeps the scaffold hook-friendly and model-agnostic.

Source code in transformer_instant/pipeline.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
def solve_attention_output_projection(
    X: Array,
    W_Q: Array,
    W_K: Array,
    W_V: Array,
    H_attn_target: Array,
    *,
    lambda_reg: float = 1e-3,
    mask: Array | None = None,
) -> tuple[Array, Array, Array, Array, Array]:
    """Run the attention forward pass, then fit W_O so that Z W_O ≈ H_attn_target.

    Returns (Q, K, V, A, W_O); the intermediate Z = A V is consumed by the
    ridge solve and not returned. Keeps the scaffold hook-friendly and
    model-agnostic.
    """
    Q, K, V, A, attended = compute_qkv_attn(X, W_Q, W_K, W_V, mask=mask)
    projection = solve_output_projection_from_attn(
        attended, H_attn_target, lambda_reg=lambda_reg
    )
    return Q, K, V, A, projection

solve_output_projection_from_attn(Z, H_attn, *, lambda_reg=0.001)

Solve W_O from Z W_O ≈ H_attn via ridge in closed form.

Z: (n, d_v), H_attn: (n, d_model), returns W_O: (d_v, d_model)

Source code in transformer_instant/attention_solver.py
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def solve_output_projection_from_attn(
    Z: np.ndarray,
    H_attn: np.ndarray,
    *,
    lambda_reg: float = 1e-3,
) -> np.ndarray:
    """Recover the output projection W_O such that Z @ W_O ≈ H_attn.

    Z: (n, d_v) attended values; H_attn: (n, d_model) target activations.
    Returns W_O of shape (d_v, d_model), fit by closed-form ridge.
    """
    attended = np.asarray(Z, dtype=np.float64)
    target = np.asarray(H_attn, dtype=np.float64)
    return ridge_fit_closed_form(attended, target, lambda_reg=lambda_reg)
.. automodule:: transformer_instant
    :members:
    :undoc-members:
    :show-inheritance:

CLI additions

The exactstep CLI includes a deep-elm-bench subcommand for sweeping depth/hidden/lambda and saving plots.