AISTATS 2026 Batch • Paper 3
13 minute read
Published: June 2026

A Proof of Learning Rate Transfer Under μP

Original paper by Soufiane Hayou • summary & experiment by Hamidreza Hashempoor, Institute for AI, University of Stuttgart

Tuning the learning rate is the most expensive ritual in deep learning: a value that is perfect for a small model is often unstable or far too timid for a large one, so practitioners re-sweep the learning rate every time they scale up. The maximal-update parametrization (μP) promises to end that ritual — under μP the optimal learning rate found on a narrow network is supposed to transfer to a much wider one. This post follows A Proof of Learning Rate Transfer Under μP (Soufiane Hayou, AISTATS 2026), which turns that empirical folklore into a theorem for deep linear networks: the optimal one-step learning rate has a deterministic infinite-width limit, and finite-width networks approach it at a quantified rate. We restate the result, then verify it on a deep linear MLP whose numbers are reproduced by the companion notebook.

The idea: the optimal learning rate stops moving

Fix a network architecture and a dataset, and let only the width $n$ vary. For each width there is some learning rate $\eta_n$ that makes the very first gradient step reduce the loss as much as possible. The question the paper answers is: what happens to $\eta_n$ as $n\to\infty$?

Under the standard parametrization (SP) the answer is "it keeps moving" — the right step size depends on width, which is exactly why a learning rate tuned at one scale is wrong at another. Under μP, the layers are initialized and the updates are scaled so that every hidden unit receives an $O(1)$ update regardless of width. The consequence proved here is that the optimal learning rate converges to a fixed number:

$$ \eta_n^{(1)}\;\longrightarrow\;\eta_\infty^{(1)} \qquad \text{as width } n\to\infty. $$
The whole promise of learning-rate transfer rests on this limit existing. If $\eta_n^{(1)}$ settles onto a width-independent value $\eta_\infty^{(1)}$, you can compute (or tune) the learning rate once — even at infinite width, in closed form — and reuse it at any finite width. The paper makes this precise for deep linear networks, gives $\eta_\infty^{(1)}$ explicitly in terms of the data, and bounds how fast finite widths get there: $\eta_n^{(1)}-\eta_\infty^{(1)}=O_P(n^{-1/2})$.

The model and the μP parametrization

The analysis uses a deep linear MLP — a stack of matrix multiplications with no nonlinearity. Linearity is what makes everything computable in closed form while still exhibiting the depth- and width-dependent behavior that matters for the proof. The network maps an input $x\in\mathbb R^d$ to a scalar through $L$ hidden layers:

$$ f(x)=V^\top W_L W_{L-1}\cdots W_1 W_0\, x. $$

Only the hidden matrices $W_1,\dots,W_L$ are trained; the input layer $W_0$ and the readout $V$ stay frozen. Freezing them isolates the part of the network whose scaling drives the result. The μP initialization sets the variance of each layer so that signals and updates stay $O(1)$ as the width grows:

$$ W_0\sim\mathcal N\!\big(0,\tfrac1d\big),\qquad W_\ell\sim\mathcal N\!\big(0,\tfrac1n\big)\ \ (\ell=1,\dots,L),\qquad V\sim\mathcal N\!\big(0,\tfrac1{n^{2}}\big). $$

In code: W0 = randn(n,d)/√d, W_l = randn(n,n)/√n, and the μP readout V = randn(n)/n.
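A minimal NumPy sketch of that initialization and the forward pass follows. It assumes the rows of `X` are the inputs $x_i$. The helper names `init_mup_linear_mlp` and `forward` are hypothetical and are not taken from the companion notebook, which uses torch.

```python
import numpy as np

def init_mup_linear_mlp(n, d, L, rng):
    """Sample W_0 (n x d), hidden W_1..W_L (n x n) and readout V (n,) with muP variances."""
    W0 = rng.standard_normal((n, d)) / np.sqrt(d)                      # Var = 1/d
    Ws = [rng.standard_normal((n, n)) / np.sqrt(n) for _ in range(L)]  # Var = 1/n
    V = rng.standard_normal(n) / n                                     # Var = 1/n^2 (muP readout)
    return W0, Ws, V

def forward(W0, Ws, V, X):
    """Compute f(x_i) = V^T W_L ... W_1 W_0 x_i for every row x_i of X (shape m x d)."""
    H = X @ W0.T          # (m, n): rows are W_0 x_i
    for W in Ws:          # apply W_1, ..., W_L in order
        H = H @ W.T
    return H @ V          # (m,): one scalar output per sample
```

The entries of $V$ are of size $1/n$, so the initial outputs shrink roughly like $1/\sqrt n$ as the width grows.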

The one line that makes it μP. The readout scales as $V\sim 1/n$ (variance $1/n^{2}$), not the standard $1/\sqrt n$. That single change is what keeps the per-step update to each preactivation $\Theta(1)$ at every width, and it is precisely the ingredient SP lacks. Swap $V\sim 1/\sqrt n$ back in and the learning-rate transfer studied here disappears.
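Here is a back-of-the-envelope version of that scaling argument. It is a heuristic sketch, not the paper's proof. Write $a_{\ell-1}(x)=W_{\ell-1}\cdots W_0x$ for the input to $W_\ell$, $b_\ell=W_{\ell+1}^\top\cdots W_L^\top V$ for the readout direction propagated back to layer $\ell$, and $r_j=f(x_j)-y_j$ for the residuals. One gradient step then changes the preactivation $W_\ell\,a_{\ell-1}(x_i)$ by

$$
\Delta W_\ell\,a_{\ell-1}(x_i)=-\frac{\eta}{m}\,b_\ell\sum_{j=1}^{m}r_j\,\big\langle a_{\ell-1}(x_j),\,a_{\ell-1}(x_i)\big\rangle,\qquad \big\langle a_{\ell-1}(x_j),\,a_{\ell-1}(x_i)\big\rangle=O(n).
$$

Under μP the entries of $b_\ell$ keep the $1/n$ scale of $V$, because multiplying by $W^\top$ with variance-$1/n$ entries preserves it. Each entry of the update is therefore of order $\eta\cdot\tfrac1n\cdot n=\eta$ at every width. With $V\sim1/\sqrt n$, the entries of $b_\ell$ are of order $1/\sqrt n$ and the same update is of order $\eta\sqrt n$. A fixed $\eta$ then produces updates that grow with width, so the usable learning rate has to shrink as $n$ grows.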

One gradient step and the optimal learning rate

Training is full-batch gradient descent on the mean-squared error over $m$ samples $\{(x_i,y_i)\}_{i=1}^m$:

$$ \mathcal L_n(\eta)=\frac{1}{2m}\sum_{i=1}^{m}\big(f(x_i)-y_i\big)^2. $$

Starting from the μP initialization, a single gradient-descent step updates only the trained hidden matrices, with the μP rule that the effective learning rate is not width-scaled:

$$ W_\ell^{(1)}=W_\ell^{(0)}-\eta\,\nabla_{W_\ell}\mathcal L_n^{(0)},\qquad \ell=1,\dots,L. $$

Let $\mathcal L_n^{(1)}(\eta)$ be the loss after that one step. The empirical optimal one-step learning rate at width $n$ is whatever value of $\eta$ minimizes it over a compact interval $I$:

$$ \eta_n^{(1)}\in\arg\min_{\eta\in I}\ \mathcal L_n^{(1)}(\eta). $$

Because all candidate learning rates start from the same initialization, the gradient $\nabla_{W_\ell}\mathcal L_n^{(0)}$ is identical for every $\eta$ — it is computed once and reused while sweeping $\eta$, which is exactly what the notebook does.
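The compute-once, sweep-many pattern can be sketched in NumPy as follows. The names `hidden_grads` and `one_step_loss` are hypothetical. The backward pass uses $\partial f/\partial W_\ell=b_\ell\,a_{\ell-1}(x)^\top$, with $b_\ell=W_{\ell+1}^\top\cdots W_L^\top V$ and $a_{\ell-1}(x)=W_{\ell-1}\cdots W_0x$. As in the setup above, $W_0$ and $V$ are never updated.

```python
import numpy as np

def forward(W0, Ws, V, X):
    """f(x_i) = V^T W_L ... W_1 W_0 x_i for each row x_i of X."""
    H = X @ W0.T
    for W in Ws:
        H = H @ W.T
    return H @ V

def hidden_grads(W0, Ws, V, X, y):
    """Gradients of (1/2m) sum_i (f(x_i) - y_i)^2 w.r.t. W_1..W_L (W_0 and V frozen)."""
    m = X.shape[0]
    A = [X @ W0.T]                      # A[l] has rows a_l(x_i), the input to W_{l+1}
    for W in Ws:
        A.append(A[-1] @ W.T)
    r = A[-1] @ V - y                   # residuals f(x_i) - y_i at initialization
    grads = [None] * len(Ws)
    b = V.copy()                        # b_L = V
    for l in reversed(range(len(Ws))):  # Ws[l] holds W_{l+1}
        grads[l] = np.outer(b, A[l].T @ r) / m
        b = Ws[l].T @ b                 # b_l = W_{l+1}^T b_{l+1}
    return grads

def one_step_loss(W0, Ws, V, X, y, grads, eta):
    """Loss L_n^(1)(eta) after one GD step on the hidden matrices only."""
    stepped = [W - eta * G for W, G in zip(Ws, grads)]
    r = forward(W0, stepped, V, X) - y
    return 0.5 * np.mean(r ** 2)
```

Call `hidden_grads` once per initialization, then evaluate `one_step_loss(..., grads, eta)` for every candidate `eta`. The sweep costs only forward passes.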

The infinite-width learning rate, in closed form

The central object on the theory side is the normalized Gram matrix of the inputs — a quantity that depends only on the data, not on the network width:

$$ K=\frac{1}{d}\,X X^\top,\qquad K_{ij}=\frac{\langle x_i,x_j\rangle}{d}. $$

In the infinite-width limit the random network behaves like a deterministic linear map governed by $K$, and the one-step loss becomes an exact quadratic in $\eta$. Minimizing that quadratic gives the optimal learning rate in closed form:

$$ \boxed{\;\eta_\infty^{(1)}=\frac{m}{L}\,\frac{y^\top K y}{\lVert K y\rVert_2^{2}}\;} $$

The formula is well defined whenever $Ky\neq 0$, which the theorem assumes. Depth $L$ enters as a simple $1/L$ factor — deeper networks want proportionally smaller steps — and the data enters only through the Rayleigh-type quotient $y^\top K y/\lVert Ky\rVert^2$. Notably, the width $n$ does not appear: this is the fixed target every finite-width network is converging to.
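The formula can be recovered heuristically. What follows is a reconstruction under simplifying assumptions, not the paper's proof. Write $a_{\ell-1}(x)=W_{\ell-1}\cdots W_0x$ and $b_\ell=W_{\ell+1}^\top\cdots W_L^\top V$, so that $\partial f/\partial W_\ell=b_\ell\,a_{\ell-1}(x)^\top$. Assume that at infinite width:

- (i) the initial output vanishes, $f^{(0)}(x_i)\to0$;
- (ii) $n\lVert b_\ell\rVert_2^2\to1$;
- (iii) $\langle a_{\ell-1}(x_i),a_{\ell-1}(x_j)\rangle/n\to K_{ij}$;
- (iv) terms involving products of two or more updates drop out.

Keeping only the first-order change in $f$:

$$
\begin{aligned}
f^{(1)}(x_i)
&\approx f^{(0)}(x_i)-\frac{\eta}{m}\sum_{\ell=1}^{L}\lVert b_\ell\rVert_2^{2}\sum_{j=1}^{m}\big(f^{(0)}(x_j)-y_j\big)\big\langle a_{\ell-1}(x_j),a_{\ell-1}(x_i)\big\rangle
\;\longrightarrow\;\frac{\eta L}{m}\,(Ky)_i,\\[4pt]
\mathcal L_\infty^{(1)}(\eta)
&=\frac{1}{2m}\Big\lVert\frac{\eta L}{m}\,Ky-y\Big\rVert_2^{2}
=\frac{1}{2m}\Big(\frac{\eta^{2}L^{2}}{m^{2}}\lVert Ky\rVert_2^{2}-\frac{2\eta L}{m}\,y^\top Ky+\lVert y\rVert_2^{2}\Big),\\[4pt]
\frac{d\mathcal L_\infty^{(1)}}{d\eta}=0
&\;\Longrightarrow\;\eta_\infty^{(1)}=\frac{m}{L}\,\frac{y^\top Ky}{\lVert Ky\rVert_2^{2}}.
\end{aligned}
$$

This also shows why $Ky\neq0$ is needed: without it the quadratic has no $\eta^2$ term and no finite minimizer.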

The convergence claim

The main theorem is not just that the limit exists, but that finite-width networks reach it at a controlled rate. Treating the random initialization as the source of randomness,

$$ \eta_n^{(1)}-\eta_\infty^{(1)}=O_P\!\big(n^{-1/2}\big). $$

In words: the gap between the best finite-width step size and the infinite-width prediction shrinks like $1/\sqrt n$, with fluctuations across random seeds of the same order. Doubling the width should roughly cut both the typical error and its seed-to-seed spread by a factor $\sqrt 2$. That is the quantitative content behind "learning rates transfer": not only is there a fixed answer, but moderate widths are already close to it.

The minimal experiment

To check the claim we instantiate exactly the setting above: a deep linear μP MLP of depth $L=3$, trained for a single full-batch gradient step on a fixed synthetic linear-regression dataset, sweeping width across $\{64,128,256,512,1024\}$. The data is generated once and reused at every width — the theorem studies width asymptotics with the data held fixed — from

$$ x_i\sim\mathcal N(0,I_d),\quad w_\star\sim\mathcal N(0,d^{-1}I_d),\quad y_i=w_\star^\top x_i+\varepsilon_i,\ \ \varepsilon_i\sim\mathcal N(0,\sigma^2). $$
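Below is a NumPy sketch of this data pipeline together with the closed-form $\eta_\infty^{(1)}$. It uses the values from the configuration table ($m=500$, $d=1$, $\sigma=0.10$, data seed 123) as defaults. The function names are hypothetical. The random draws almost certainly differ from the companion notebook's, so this will not reproduce $0.37176$ exactly.

```python
import numpy as np

def make_dataset(m=500, d=1, sigma=0.10, seed=123):
    """x_i ~ N(0, I_d), w* ~ N(0, I_d / d), y_i = w*^T x_i + eps_i with eps_i ~ N(0, sigma^2)."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((m, d))                 # rows are the inputs x_i
    w_star = rng.standard_normal(d) / np.sqrt(d)    # teacher weights, Var = 1/d
    y = X @ w_star + sigma * rng.standard_normal(m)
    return X, y

def eta_infinity(X, y, L):
    """Closed-form (m / L) * y^T K y / ||K y||^2 with K = X X^T / d."""
    m, d = X.shape
    K = X @ X.T / d                                 # normalized Gram matrix (m x m)
    Ky = K @ y
    if not np.any(Ky):
        raise ValueError("formula needs K y != 0")
    return (m / L) * (y @ Ky) / (Ky @ Ky)
```

With $d=1$ the Gram matrix is rank one, $K=xx^\top$, and the formula collapses to $\eta_\infty^{(1)}=m/(L\lVert x\rVert_2^2)$. This is close to $1/L$ when the inputs have unit variance, which is why the reported value is of the same order as $1/3$ for $L=3$.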

For each width and each of three initialization seeds we compute the empirical $\eta_n^{(1)}$ by a grid search over $\eta\in[0,\,4\,\eta_\infty^{(1)}]$ (the search interval is built to contain the theoretical value), refined locally around the best grid point, and compare it against the closed-form $\eta_\infty^{(1)}$. The exact configuration:

| Setting | Value | Setting | Value |
|---|---|---|---|
| depth $L$ | 3 | parametrization | μP ($V\sim1/n$) |
| widths $n$ | 64, 128, 256, 512, 1024 | init seeds | 1, 2, 3 |
| samples $m$ | 500 | input dim $d$ | 1 |
| noise std $\sigma$ | 0.10 | data seed | 123 |
| optimizer | full-batch GD, 1 step | precision | float64 |
| $\eta$ grid | 120 points on $[0,4\eta_\infty]$ | local refine | 60 points |
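The search over $\eta$ can be sketched as a coarse grid followed by a finer grid between the neighbours of the best coarse point. The defaults are the 120 and 60 points listed in the configuration. The bracketing rule is an assumption, since the text only says the search is refined locally around the best grid point, and `argmin_with_refine` is a hypothetical helper.

```python
import numpy as np

def argmin_with_refine(loss_fn, eta_max, n_grid=120, n_refine=60):
    """Minimize loss_fn over [0, eta_max]: coarse grid, then a fine grid around the best point."""
    grid = np.linspace(0.0, eta_max, n_grid)
    coarse = np.array([loss_fn(e) for e in grid])
    k = int(np.argmin(coarse))
    # Bracket the coarse minimizer by its neighbours (clipped at the interval ends).
    lo, hi = grid[max(k - 1, 0)], grid[min(k + 1, n_grid - 1)]
    fine_grid = np.linspace(lo, hi, n_refine)
    fine = np.array([loss_fn(e) for e in fine_grid])
    j = int(np.argmin(fine))
    return fine_grid[j], fine[j]
```

In the experiment, `loss_fn` would wrap the one-step loss $\eta\mapsto\mathcal L_n^{(1)}(\eta)$ with the gradient computed once, and `eta_max` would be $4\eta_\infty^{(1)}$.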

Results

The closed-form target for this dataset is $\eta_\infty^{(1)}=0.37176$. The empirical optimum is noisy at small widths and tightens onto that value as the network grows, and the seed-to-seed spread shrinks along with the error.

Figure 1. Empirical optimal one-step learning rate $\eta_n^{(1)}$ (mean ± std over 3 seeds) vs. width $n$ on a $\log_2$ axis. The dashed line is the closed-form $\eta_\infty^{(1)}=0.3718$. The estimate is volatile at $n=64,128$, then locks onto the theoretical value with a visibly tightening error bar by $n=512$ and $1024$.
Figure 2. Absolute error $\lvert\eta_n^{(1)}-\eta_\infty^{(1)}\rvert$ vs. width on log–log axes, with the $n^{-1/2}$ reference (grey dashed) predicted by the theorem. The $n=128$ point is an outlier that reflects finite-width landscape noise (see below). Apart from it, the error falls steeply up to $n=512$ and rises slightly at $n=1024$. Over this small range of widths the measured decay is even faster than $1/\sqrt n$.

Per-width numbers

Averaged over the three initialization seeds (data seed 123), with $\eta_\infty^{(1)}=0.371763$:

| width $n$ | $\eta_n^{(1)}$ mean | std (3 seeds) | $\lvert\eta_n^{(1)}-\eta_\infty^{(1)}\rvert$ | rel. error |
|---|---|---|---|---|
| 64 | 0.39797 | 0.0899 | 0.02621 | 7.0% |
| 128 | 0.51340 | 0.1885 | 0.14164 | 38.1% |
| 256 | 0.41358 | 0.0898 | 0.04181 | 11.2% |
| 512 | 0.37051 | 0.0370 | 0.00125 | 0.3% |
| 1024 | 0.37722 | 0.0187 | 0.00545 | 1.5% |

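Errors are computed on the seed-averaged mean $\eta_n^{(1)}$. The optimal post-step loss is essentially identical across all widths and seeds ($\approx 5.076\times10^{-3}$). That is close to $\sigma^2/2=5\times10^{-3}$, the expected loss of the true predictor $w_\star$, so every run reached the same minimal loss. Widths and seeds differ in where the minimum of $\mathcal L_n^{(1)}(\eta)$ sits in $\eta$, not in how low it goes.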
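The error columns follow from the seed-averaged means by simple arithmetic. The short check below types the table's means in by hand. It reproduces both error columns up to last-digit rounding of the means.

```python
import numpy as np

eta_inf = 0.371763                                                  # closed-form value for this dataset
widths = np.array([64, 128, 256, 512, 1024])
mean_eta = np.array([0.39797, 0.51340, 0.41358, 0.37051, 0.37722])  # seed-averaged eta_n from the table

abs_err = np.abs(mean_eta - eta_inf)       # |eta_n - eta_inf|
rel_err_pct = 100 * abs_err / eta_inf      # relative error in percent

for n, a, r in zip(widths, abs_err, rel_err_pct):
    print(f"n={n:5d}  abs={a:.5f}  rel={r:.1f}%")
```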

Reading the convergence. Two trends match the theory.

(1) The seed-to-seed standard deviation shrinks with width, from $0.090$ at $n=64$ to $0.019$ at $n=1024$. It falls monotonically from $n=128$ onward, consistent with the $O_P(n^{-1/2})$ bound on the fluctuations.

(2) The mean locks onto $\eta_\infty^{(1)}$. At $n=512$ the relative error is $0.3\%$ ($1.25\times10^{-3}$ in absolute terms), and at $n=1024$ it is $1.5\%$ ($5.5\times10^{-3}$).

The $n=128$ spike has an error of $38\%$ and the largest std in the table. It comes from a single-seed outlier where one initialization landed in a flat region of the one-step loss. This is the kind of finite-width noise the high-probability $O_P$ statement explicitly allows, and it averages out as width grows.

A caveat in the spirit of honest reporting: a least-squares fit of $\log\lvert\eta_n-\eta_\infty\rvert$ against $\log n$ over these five widths gives a slope of $\approx-1.14$, steeper than the theoretical $-1/2$. This does not contradict the theorem. $O_P(n^{-1/2})$ is an asymptotic, high-probability upper bound: it guarantees the error decays at least as fast as $n^{-1/2}$, not exactly that fast. With only five widths, two points pull the fitted slope steeper: the $n=128$ outlier sits far above the trend, and the unusually small $n=512$ error sits below it. The fit therefore does not measure a clean asymptotic rate. The robust takeaway is the one the theory cares about. The error and its spread both vanish with width, and a moderate width already pins the learning rate to within a couple of percent of the closed-form infinite-width value.
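The slope quoted above can be recomputed from the table's absolute errors. The sketch fits an ordinary least-squares line in log–log space with `numpy.polyfit` at degree 1.

```python
import numpy as np

widths = np.array([64, 128, 256, 512, 1024])
abs_err = np.array([0.02621, 0.14164, 0.04181, 0.00125, 0.00545])  # |eta_n - eta_inf| from the table

# Fit log|error| = slope * log(n) + intercept; the theorem's reference slope is -1/2.
slope, intercept = np.polyfit(np.log(widths), np.log(abs_err), 1)
print(f"fitted slope: {slope:.2f}")  # prints: fitted slope: -1.14
```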

Reproduce it

The companion notebook re-runs the whole experiment from scratch with only torch, numpy, and matplotlib. It builds the μP linear MLP, computes the closed-form $\eta_\infty^{(1)}=\tfrac{m}{L}\,y^\top Ky/\lVert Ky\rVert^2$, performs the grad-once one-step learning-rate search across widths and seeds, and regenerates both figures — reproducing the per-width table above ($\eta_\infty=0.37176$, error at $n=1024$ of $5.5\times10^{-3}$) to floating-point precision.