
An Effective Gram Matrix Characterizes Generalization
in Deep Networks

Rubing Yang Pratik Chaudhari
University of Pennsylvania
Email: [email protected]
[email protected]
Abstract

We derive a differential equation that governs the evolution of the generalization gap when a deep network is trained by gradient descent. This differential equation is controlled by two quantities: a contraction factor that brings together trajectories corresponding to slightly different datasets, and a perturbation factor that accounts for them training on different datasets. We analyze this differential equation to compute an “effective Gram matrix” that characterizes the generalization gap after training in terms of the alignment between this Gram matrix and a certain initial “residual”. Empirical evaluations (code at https://github.com/grasp-lyrl/effective-gram-matrix.git) on image classification datasets indicate that this analysis can predict the test loss accurately. Further, at any point during training, the residual predominantly lies in the subspace of the effective Gram matrix with the smallest eigenvalues. This indicates that the training process is benign, i.e., it does not lead to significant deterioration of the generalization gap (which is zero at initialization). The alignment between the effective Gram matrix and the residual is different for different datasets and architectures. The match/mismatch of the data and the architecture is primarily responsible for good/bad generalization.

1 Introduction

Generalization is the ability of a model to apply patterns learned from training data to new, unseen data. Deep neural networks are interesting in this regard because, despite having many parameters and a complex loss landscape, they can still generalize well. This challenges traditional statistical wisdom, e.g., the bias-variance trade-off, which suggests that highly flexible models should overfit and perform poorly on test data. Deep networks, however, consistently perform well on unseen data, raising fundamental questions about the principles that govern their generalization. A large body of work has sought to tackle this question, and there are numerous perspectives in the literature today on the relationship between training data, test data, and the model class for deep networks. While this work provides valuable insights, each of the existing lines of attack has its limitations.

Probably-approximately correct (PAC) frameworks

The PAC learning framework (Valiant, 1984) provides generalization bounds for models trained on independently and identically distributed (i.i.d.) data, using measures such as Vapnik-Chervonenkis (VC) dimension or Rademacher complexity to characterize the hypothesis space. The PAC-Bayes framework (McAllester, 1999) extends these ideas by deriving generalization error bounds for randomized estimators. However, both frameworks are limited in their ability to explain the generalization behavior of modern deep neural networks. Despite their rich hypothesis class and extremely large VC dimensions, deep neural networks consistently achieve remarkable generalization, a phenomenon that defies the worst-case assumptions inherent in these classical frameworks.

Simplified models of deep networks

To move beyond worst-case analyses, a direct examination of the exact solutions of deep neural networks is an appealing approach. However, due to the inherent complexity of these networks, such analyses are often intractable. Instead, various solvable models from statistics and physics have been employed to partially characterize deep neural network behavior, offering valuable theoretical insights. Linear regression, for instance, has been widely used to explore phenomena like benign over-fitting (Bartlett et al., 2020) and double descent (Hastie et al., 2022; Belkin et al., 2020). To investigate the effects of depth, deep linear networks have served as a useful abstraction for studying multi-layer dynamics (Laurent and von Brecht, 2018). One of the most prominent frameworks in this area is the Neural Tangent Kernel (NTK) approach (Jacot et al., 2018), which models the training dynamics of deep neural networks in the infinite-width regime under a kernel-based approximation. Mallinar et al. (2022) and Belkin et al. (2018) characterized different regimes of kernel regression and analyzed their resemblance to deep learning. The NTK method has enabled significant results on convergence (Du et al., 2019; Li and Liang, 2018) and generalization (Arora et al., 2019; Jacot et al., 2020). Bowman and Montúfar (2022) analyzed how finite-width neural networks diverge from the NTK regime in different eigenspaces. Similarly, mean-field analysis (Chizat and Bach, 2018; Mei et al., 2019) has been used to study the evolution of the neuron distribution in the infinite-width limit. These models often rely on assumptions such as convexity of the loss landscape and dynamics constrained to a region around initialization. This precludes feature learning, a critical aspect of modern deep neural networks.

Non-worst-case generalization bounds

Infinite-width assumptions are impractical for real-world scenarios, prompting research into deriving bounds for general neural networks by making mild assumptions about the training process and data. For example, Bartlett et al. (2017) and Neyshabur et al. (2018) analyzed the complexity of the reachable hypothesis class and proposed weight-dependent generalization bounds that restrict the hypothesis space based on the weights’ distance to a reference point. Dziugaite and Roy (2017) and Yang et al. (2022) established generalization bounds for stochastic algorithms using properties of the trained minima; however, these bounds are derived in a post-hoc manner, based on the trained solution. Algorithm-specific approaches, such as sensitivity analyses on the effects of perturbations of the dataset, provide insights into algorithmic stability and its impact on generalization (Bousquet and Elisseeff, 2002; Hardt et al., 2016; Xu and Mannor, 2012; Chu and Raginsky, 2023). Kawaguchi et al. (2022) give a generalization bound from a data-dependent robustness analysis. Xu and Raginsky (2017) and Mou et al. (2018) explored the stability of stochastic algorithms using information-theoretic approaches, which has led to further work (Negrea et al., 2019; Neu et al., 2021; Lugosi and Neu, 2022) that offers generalization bounds that depend on the training trajectory. By exploring the conditional mutual information (Hafez-Kolahi et al., 2020; Steinke and Zakynthinou, 2020), one can get tighter generalization bounds. Additionally, by assuming specific properties of the training loss landscape, studies such as (Kozachkov et al., 2023) and (Lugosi and Neu, 2022) provided generalization guarantees. These methods often rely on assumptions that are uniformly applied across the entire hypothesis space, which can be problematic. Task-specific analyses of neural network training are provided by (Ramesh et al., 2024; Mao et al., 2024). Chuang et al. (2021) give a margin-based generalization bound normalized by optimal transport cost, which exploits properties of the data while ignoring the training process.

1.1 Contributions

We analyze how the generalization gap accumulates along the training trajectory. We derive a differential equation describing the evolution of the averaged loss difference, controlled by the contraction factor and the perturbation factor. This equation tells us how perturbing the dataset affects the output of the predictor during training. We define an “effective Gram matrix” for neural network training that characterizes the accumulation of the generalization gap in different subspaces. Using this effective Gram matrix, we derive a complexity measure that faithfully characterizes the generalization gap in general networks. This analysis allows us to get a data-dependent estimate of the generalization gap. Time-varying contraction and perturbation factors along the training trajectory allow us to avoid making uniform assumptions about the loss function. We next describe the contributions of the paper.

  • In Section 3.2, we derive the differential equation for the evolution of the averaged loss difference $\bar{\Delta}_n(t)$. This equation depends upon a certain “contraction factor” $\bar{c}_n$ and a “perturbation factor” $\bar{\epsilon}_n$.

  • In Section 3.3 and Section 3.4, we derive the effective Gram matrix $K_n$ and give a complexity measure in terms of the quadratic form $\vec{r}_n(0)^{\top} K_n \vec{r}_n$ that faithfully characterizes the generalization ability of neural networks. This analysis holds for arbitrary networks and loss functions.

  • In Section 4, we calculate numerical approximations of $K_n$ and projections of the residuals onto different eigenspaces of $K_n$. We show that simpler tasks and better model architectures benefit the training process due to an effective Gram matrix with smaller eigenvalues and better alignment with the initial residual, resulting in better generalization at the end of training.
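To make the role of the quadratic form and of the eigenspace alignment concrete, the following is a minimal sketch. It assumes, purely for illustration, a hypothetical $2\times 2$ symmetric positive definite matrix `K` standing in for $K_n$ and two hypothetical residual vectors; neither comes from the paper's experiments. For a symmetric matrix with eigenpairs $(\lambda_j, u_j)$, the quadratic form decomposes as $r^\top K r = \sum_j \lambda_j (u_j \cdot r)^2$, so a residual concentrated on small-eigenvalue directions yields a small value.

```python
import math

# Hypothetical 2x2 symmetric positive definite matrix standing in for K_n.
K = [[2.0, 1.0],
     [1.0, 2.0]]

# Eigenpairs of this particular matrix, known in closed form.
s = 1.0 / math.sqrt(2.0)
eigenpairs = [
    (1.0, [s, -s]),  # smallest eigenvalue and its unit eigenvector
    (3.0, [s, s]),   # largest eigenvalue and its unit eigenvector
]

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def matvec(A, v):
    return [dot(row, v) for row in A]

def quadratic_form(A, r):
    """Direct evaluation of r^T A r."""
    return dot(r, matvec(A, r))

def spectral_form(pairs, r):
    """Same quantity via sum_j lambda_j (u_j . r)^2."""
    return sum(lam * dot(u, r) ** 2 for lam, u in pairs)

# Two hypothetical residuals with the same norm but different alignment.
r_small = [1.0, -1.0]  # lies in the small-eigenvalue eigenspace
r_large = [1.0, 1.0]   # lies in the large-eigenvalue eigenspace

q_small = quadratic_form(K, r_small)
q_large = quadratic_form(K, r_large)
```

Both residuals have the same norm, yet the one aligned with the small-eigenvalue direction gives a quadratic form three times smaller, which is the sense in which alignment matters in the discussion above.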

2 Preliminaries

Let $[n]$ denote the set of integers $\{1,\dots,n\}$. We use the notation $a\cdot b$ to denote the inner product of vectors $a$, $b$. For a function $h$, we write $h(w)|^{b}_{a}\equiv h(b)-h(a)$ and, sometimes, $h(w)|_{a}\equiv h(a)$. We use $|\cdot|$, $\|\cdot\|_{2}$, $\|\cdot\|_{F}$ for the absolute value of a scalar, the $\ell_{2}$-norm of a vector or a matrix, and the Frobenius norm of a matrix, respectively. We use the notation $g(t)=\Theta(h(t))$ when there exist constants $c_{0},c_{1},t_{0}>0$ such that $c_{0}\leq g(t)/h(t)\leq c_{1}$ for $t\geq t_{0}$. We omit the subscript $n$ indicating the size of the dataset, and $t$ indicating the time, for all quantities defined in this paper when no ambiguity arises.

Dataset

Let $\mathcal{Z}=\mathcal{X}\times\mathcal{Y}$ be the sample space where $\mathcal{X}$ and $\mathcal{Y}$ are input and output spaces, respectively. Consider a dataset $S_{n}=\{z_{i}=(x_{i},y_{i})\}_{i\in[n]}$ of size $n$ (each $z_{i}\in\mathcal{Z}$) drawn i.i.d. from a distribution $D$. Let $D^{n}$ denote the distribution of the dataset, i.e., $S_{n}\sim D^{n}$. Let $S_{n}^{-i}$ denote a modified dataset obtained by removing the $i$-th datum, i.e., $S_{n}^{-i}=\{z_{1},\dots,z_{i-1},z_{i+1},\dots,z_{n}\}$.

Predictor and the loss function

We consider the predictor $f:\mathcal{W}\times\mathcal{X}\to\mathcal{Y}$ where $\mathcal{W}$ is the weight space. Consider a loss function $\ell:\mathcal{Y}\times\mathcal{Y}\to\mathbb{R}^{+}$. As an example, for the cross-entropy loss on a $C$-class classification problem, $\mathcal{Y}=\mathbb{R}^{C}$, and the loss is $\ell(f(w,x),y)=-\sum_{j=1}^{C}y^{j}\log p^{j}$, where $p^{j}=\exp(f^{j}(w,x))/\left(\sum_{k=1}^{C}\exp(f^{k}(w,x))\right)$ with $y^{j}$ and $f^{j}$ denoting the $j$-th element of $y$ and $f(w,x)$, respectively. We use the notation $\ell(w,z)\equiv\ell(f(w,x),y)$ as a shorthand. Let $\bar{\ell}(w,S_{n})=(1/n)\sum_{i=1}^{n}\ell(w,z_{i})$ be the average loss over the dataset $S_{n}$.
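The cross-entropy loss and the average loss defined above can be sketched in a few lines. This is only an illustration using the standard library: the function names and the example logits are assumptions, and the predictor outputs $f(w,x)$ are passed in directly as lists of logits rather than computed by a network.

```python
import math

def softmax(logits):
    """p^j = exp(f^j) / sum_k exp(f^k), shifted by the max for numerical stability."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, y):
    """l(f(w, x), y) = -sum_j y^j log p^j; entries with y^j = 0 contribute nothing."""
    p = softmax(logits)
    return -sum(yj * math.log(pj) for yj, pj in zip(y, p) if yj > 0.0)

def average_loss(all_logits, labels):
    """Average loss over a dataset: (1/n) sum_i l(w, z_i)."""
    n = len(all_logits)
    return sum(cross_entropy(f, y) for f, y in zip(all_logits, labels)) / n

# Hypothetical predictor outputs for n = 2 samples of a C = 3 class problem,
# with one-hot labels.
example_logits = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
example_labels = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
example_avg = average_loss(example_logits, example_labels)
```

With all-zero logits the predicted distribution is uniform, so the first sample's loss is $\log C$; this gives a quick sanity check on the implementation.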

Gradient flow

Let $w_{n}(t)$ and $w_{n}^{-i}(t)$ denote solutions corresponding to the gradient flows

$$\frac{\mathrm{d}w}{\mathrm{d}t}=-\nabla\ell(w,S_{n}),\quad\frac{\mathrm{d}w}{\mathrm{d}t}=-\nabla\ell(w,S_{n}^{-i}),\tag{1}$$

respectively. Unless otherwise specified, we assume that $w_{n}(t)$ and $w_{n}^{-i}(t)$ are initialized at the same point for all $i\in[n]$. As a precursor, in Section 3, we will choose $w_{n}(0)$ and $w_{n}^{-i}(0)$ to be initializations of neural networks. In Section 4, we will also sometimes initialize $w_{n}(0)$ and $w_{n}^{-i}(0)$ to be the weights of neural networks that are not fully trained; in this case, $w_{n}(0)$ and $w_{n}^{-i}(0)$ are not necessarily the same.
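As a rough illustration of Eq. 1, the sketch below integrates the two gradient flows with forward Euler steps from a shared initialization. Everything specific here is an assumption made for the example: a scalar linear predictor $f(w,x)=wx$ with squared loss $\tfrac{1}{2}(wx-y)^2$, a hand-picked four-point dataset, the step size, and the reading of $\nabla\ell(w,S_n)$ as the gradient of the average loss over $S_n$.

```python
def grad_avg_loss(w, data):
    """Gradient in w of (1/n) sum_i 0.5 * (w * x_i - y_i)^2."""
    return sum((w * x - y) * x for x, y in data) / len(data)

def gradient_flow(w0, data, t_end, dt=1e-3):
    """Forward-Euler approximation of dw/dt = -grad(average loss) up to time t_end."""
    w = w0
    for _ in range(int(round(t_end / dt))):
        w -= dt * grad_avg_loss(w, data)
    return w

def leave_one_out(data, i):
    """Dataset with the i-th datum removed (0-indexed here)."""
    return data[:i] + data[i + 1:]

# Hypothetical dataset S_n of (x, y) pairs.
S_n = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (-1.0, -2.2)]

w_init = 0.0  # same initialization for every trajectory
w_n = gradient_flow(w_init, S_n, t_end=5.0)
w_loo = [gradient_flow(w_init, leave_one_out(S_n, i), t_end=5.0)
         for i in range(len(S_n))]
```

For this convex toy problem each trajectory approaches the least-squares solution of its own dataset, so the leave-one-out weights end up close to, but not equal to, the weights trained on the full dataset.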

Generalization gap

We will consider a few different measures of performance of a predictor trained with gradient flow. Given a predictor $f$, a loss function $\ell$, and an initialization $w_{n}(0)$, the generalization loss and the train loss of gradient flow trained on $S_{n}$ at time $t$ are

$$\begin{aligned}
R(S_{n},t)&=\mathbb{E}_{z}[\ell(w_{n}(t),z)],\\
R_{\text{train}}(S_{n},t)&=\bar{\ell}(w_{n}(t),S_{n}).
\end{aligned}$$

Our main quantity of interest is the generalization gap, defined as their difference

$$\delta R(S_{n},t)=R(S_{n},t)-R_{\text{train}}(S_{n},t).$$

The expected values of these quantities will be useful to us in Section 3. They are the expected generalization loss $\mathbb{E}_{S_{n}}\left[R(S_{n},t)\right]$, the expected train loss $\mathbb{E}_{S_{n}}\left[R_{\text{train}}(S_{n},t)\right]$, and the expected generalization gap $\mathbb{E}_{S_{n}}\left[\delta R(S_{n},t)\right]$. The notations $\mathbb{E}_{S_{n}}$ and $\mathbb{E}_{z}$ denote expectations with respect to the random draw of the dataset $S_{n}$ and the sample $z$, from distributions $D^{n}$ and $D$, respectively. We sometimes omit the subscripts $S_{n}$ and $z$ in the following sections.
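The definitions above can be checked numerically on a toy problem where the generalization loss is available in closed form. The sketch below is an illustration under stated assumptions, not the paper's experimental setup: inputs $x\sim\mathcal{N}(0,1)$, labels $y = w^\ast x + \sigma\varepsilon$ with hypothetical values `W_TRUE` and `SIGMA`, a scalar linear predictor with squared loss, and forward Euler steps in place of exact gradient flow. The expected generalization gap is estimated by averaging over independently drawn datasets.

```python
import random

W_TRUE, SIGMA = 2.0, 0.5  # hypothetical ground-truth slope and noise level

def sample_dataset(n, rng):
    """Draw S_n i.i.d.: x ~ N(0, 1), y = W_TRUE * x + SIGMA * noise."""
    data = []
    for _ in range(n):
        x = rng.gauss(0.0, 1.0)
        y = W_TRUE * x + SIGMA * rng.gauss(0.0, 1.0)
        data.append((x, y))
    return data

def train_loss(w, data):
    """R_train: average of 0.5 * (w x - y)^2 over the dataset."""
    return sum(0.5 * (w * x - y) ** 2 for x, y in data) / len(data)

def population_loss(w):
    """R: E_z[0.5 * (w x - y)^2] in closed form, using E[x^2] = 1."""
    return 0.5 * ((w - W_TRUE) ** 2 + SIGMA ** 2)

def gradient_flow(w0, data, t_end, dt=0.02):
    """Forward-Euler approximation of gradient flow on the average loss."""
    sxx = sum(x * x for x, _ in data) / len(data)
    sxy = sum(x * y for x, y in data) / len(data)
    w = w0
    for _ in range(int(round(t_end / dt))):
        w -= dt * (w * sxx - sxy)  # gradient of the average squared loss
    return w

def generalization_gap(w, data):
    """delta R = R - R_train for weights w trained on `data`."""
    return population_loss(w) - train_loss(w, data)

rng = random.Random(0)
gaps = []
for _ in range(200):
    S_n = sample_dataset(10, rng)
    w_t = gradient_flow(0.0, S_n, t_end=10.0)
    gaps.append(generalization_gap(w_t, S_n))
expected_gap = sum(gaps) / len(gaps)  # Monte Carlo estimate of E[delta R]
```

At $t=0$ the weights do not depend on the data, so the expected gap is zero; after training the average gap in this toy setting is positive, since the fitted weights are tuned to the particular draw of $S_n$.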

2.1 Contraction theory

This section introduces some preliminary material on contraction theory (Lohmiller and Slotine, 1998, 2000), which provides a way to analyze solutions of slightly different dynamical systems. Contraction theory recasts Lyapunov theory (Isidori, 1995; Marino and Tomei, 1995) using a quadratic Lyapunov function defined by a Riemannian contraction metric and its uniformly positive definite matrix. It characterizes necessary and sufficient conditions for the exponential convergence of multiple trajectories to each other and for the stability of these trajectories to perturbations of the dynamics. Consider a nonlinear dynamical system

$$\frac{\mathrm{d}\xi}{\mathrm{d}t}=h(\xi,t).\tag{2}$$

The following theorem gives guarantees of the exponential convergence of trajectories with different initializations.

Theorem 1 (Theorem 2.1 from Tsukamoto et al. (2021)).

If there exists a uniformly positive definite matrix $M(\xi,t)\succ 0$ for all $\xi,t$, such that the following condition holds for some $\alpha>0$,

$$\forall\xi,t:\quad\dot{M}+M\nabla_{\xi}h+\nabla_{\xi}h^{\top}M\preceq-2\alpha M,\tag{3}$$

then all trajectories of Eq. 2 converge to a single trajectory under the metric induced by $M$ exponentially fast regardless of their initial conditions, i.e., for all trajectories $\xi$, $\xi^{\prime}$ of Eq. 2, $d(\xi(t),\xi^{\prime}(t))_{M}\leq d(\xi(0),\xi^{\prime}(0))_{M}\,e^{-\alpha t}$, where $d(\cdot,\cdot)_{M}$ denotes the distance under the metric induced by $M$. A dynamical system Eq. 2 satisfying Eq. 3 is said to be “contracting” under the “contraction metric” induced by $M$. The factor $\alpha$ is defined to be the “contraction factor”.
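A small numerical check can make Theorem 1 tangible. The sketch below assumes a hypothetical scalar system $h(\xi,t)=-\xi-\xi^{3}$ with the constant metric $M=1$; then $\dot M = 0$ and $\nabla_\xi h = -(1+3\xi^{2})\le -1$, so Eq. 3 holds with $\alpha=1$. Two trajectories from different initial conditions are integrated with forward Euler steps (an approximation of the continuous-time flow), and their distance is compared against the envelope $d(\xi(0),\xi'(0))\,e^{-\alpha t}$.

```python
import math

ALPHA = 1.0  # contraction factor for this hypothetical system with M = 1

def h(xi, t):
    """Hypothetical contracting vector field: dh/dxi = -(1 + 3 xi^2) <= -ALPHA."""
    return -xi - xi ** 3

def euler_trajectory(xi0, t_end, dt):
    """Forward-Euler samples of d(xi)/dt = h(xi, t) at times 0, dt, 2*dt, ..."""
    steps = int(round(t_end / dt))
    xs = [xi0]
    for k in range(steps):
        xs.append(xs[-1] + dt * h(xs[-1], k * dt))
    return xs

dt = 1e-3
traj_a = euler_trajectory(2.0, 5.0, dt)
traj_b = euler_trajectory(-1.0, 5.0, dt)

d0 = abs(traj_a[0] - traj_b[0])
distances = [abs(a - b) for a, b in zip(traj_a, traj_b)]
envelope = [d0 * math.exp(-ALPHA * k * dt) for k in range(len(distances))]

# True if the distance stays below the exponential envelope at every step
# (a tiny tolerance absorbs floating-point rounding).
bound_holds = all(d <= e + 1e-12 for d, e in zip(distances, envelope))
```

Because the cubic term makes the system contract faster than rate $\alpha=1$ away from the origin, the observed distance typically sits well below the envelope early on and approaches the guaranteed rate only as both trajectories near zero.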

Using Theorem 1, we can also analyze trajectories of a perturbed dynamical system

$$\frac{\mathrm{d}\xi}{\mathrm{d}t}=h(\xi,t)+b(\xi,t).\tag{4}$$

Let $\xi_{0}(t)$, $\xi_{1}(t)$ be solutions of Eq. 2 and Eq. 4, respectively. The next theorem shows that for a contracting system, the solution of the perturbed system does not differ too much from that of the original system, under certain conditions.

Theorem 2 (Theorem 2.3 from Tsukamoto et al. (2021)).

Assume that the dynamical system Eq. 2 is contracting under $M$ with factor $\alpha$. If $\overline{b}=\sup_{x,t}\|b(x,t)\|$ and there exist constants $\underline{m},\overline{m}>0$ such that $\underline{m}I\preceq M(x,t)\preceq\overline{m}I$ for all $x,t$, then we have

$$\begin{aligned}
d(\xi_{1}(t),\xi_{0}(t))&\leq\frac{d(\xi_{1}(0),\xi_{0}(0))}{\sqrt{\underline{m}}}e^{-\alpha t}+\frac{\overline{b}}{\alpha}\sqrt{\frac{\overline{m}}{\underline{m}}}\left(1-e^{-\alpha t}\right),\\
d(\xi_{1}(t),\xi_{0}(t))_{M}&\leq d(\xi_{1}(0),\xi_{0}(0))_{M}\,e^{-\alpha t}+\frac{\overline{b}\sqrt{\overline{m}}}{\alpha}\left(1-e^{-\alpha t}\right),
\end{aligned}$$

where $d(\cdot,\cdot)$ and $d(\cdot,\cdot)_{M}$ denote the distance under the Euclidean metric and the metric induced by $M$, respectively.
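The bound in Theorem 2 can be illustrated on a scalar example. The sketch below assumes a hypothetical linear system $h(\xi,t)=-\alpha\xi$ with $\alpha=1$, the metric $M=1$ (so $\underline{m}=\overline{m}=1$ and the two inequalities coincide), and a bounded perturbation $b(\xi,t)=\overline{b}\sin t$ with $\overline{b}=0.5$. Both systems are integrated with forward Euler steps, which only approximates the continuous-time dynamics.

```python
import math

ALPHA = 1.0  # contraction factor of the unperturbed system (with M = 1)
B_BAR = 0.5  # uniform bound on the perturbation magnitude

def h(xi, t):
    """Hypothetical contracting dynamics."""
    return -ALPHA * xi

def b(xi, t):
    """Hypothetical bounded perturbation, |b| <= B_BAR."""
    return B_BAR * math.sin(t)

def simulate(xi0_start, xi1_start, t_end, dt):
    """Euler-integrate the original (xi0) and perturbed (xi1) systems together."""
    xi0, xi1 = xi0_start, xi1_start
    dists = [abs(xi1 - xi0)]
    for k in range(int(round(t_end / dt))):
        t = k * dt
        xi0 += dt * h(xi0, t)
        xi1 += dt * (h(xi1, t) + b(xi1, t))
        dists.append(abs(xi1 - xi0))
    return dists

def theorem2_bound(d0, t):
    """Right-hand side of the bound when m_lower = m_upper = 1."""
    decay = math.exp(-ALPHA * t)
    return d0 * decay + (B_BAR / ALPHA) * (1.0 - decay)

dt = 1e-3
dists = simulate(0.0, 1.0, t_end=10.0, dt=dt)
bounds = [theorem2_bound(dists[0], k * dt) for k in range(len(dists))]
within_bound = all(d <= u + 1e-9 for d, u in zip(dists, bounds))
```

As $t$ grows, the bound tends to $\overline{b}/\alpha$: the memory of the initial mismatch decays, and what remains is controlled by the size of the perturbation relative to the contraction rate.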

In short, for contracting systems and large times $t$, the bound on the distance between the solutions of the original and the perturbed dynamics is determined by the perturbation of the system, the contraction factor, and the eigenvalues of the metric. In this paper, we will be interested in using these ideas to understand the difference between two trajectories evaluated on certain loss functions that are fitted using slightly different datasets.

Kozachkov et al. (2023) gave a bound on the generalization gap using Theorem 2 by analyzing the difference of gradient flow trajectories trained on datasets with one replaced sample, under the assumption that the dynamics are contracting uniformly on the state space with factor $\alpha$. In Lemma 4, we will define another notion of contraction that does not require a uniform $\alpha$ or the uniform boundedness of $b$. This will enable a more refined analysis of the generalization gap.

3 Methods

We first show that the generalization gap can be approximated by the “averaged loss difference” $\bar{\Delta}_{n}(t)$ defined in Section 3.1. We will compute in Section 3.2 how $\bar{\Delta}_{n}(t)$ evolves over time, and show that its dynamics arise from a contraction factor $\bar{c}_{n}$ and a perturbation factor $\bar{\epsilon}_{n}$. In Section 3.3, we show that a certain “residual” $\vec{r}_{n}(t)$ (precisely, the derivative of the loss with respect to the predictor for each datum) largely controls the perturbation factor. Through the evolution of $\bar{\Delta}_{n}(t)$ and $\vec{r}_{n}(t)$, we can define an “effective Gram matrix” $K_{n}$ and a complexity measure in terms of a quadratic form $\vec{r}_{n}^{\top}K_{n}\vec{r}_{n}$ that characterizes the generalization gap at the end of training (Section 3.4). Proofs of all theorems and lemmas are deferred to Appendix A.

3.1 Approximation of the generalization gap

We first define two quantities pertaining to the difference between perturbed trajectories. Define the pointwise loss difference to be the difference of trajectories $w_n^{-i}(t)$ and $w_n(t)$ evaluated in terms of the loss $\ell(w, z_i)$,

\[
\Delta_n^{-i}(t) = \ell(w_n^{-i}(t), z_i) - \ell(w_n(t), z_i),
\]

and the averaged loss difference is defined to be

\[
\bar{\Delta}_n(t) = \frac{1}{n} \sum_{i=1}^{n} \Delta_n^{-i}(t). \tag{5}
\]

Note that in the averaged loss difference, we use the Leave-One-Out-Cross-Validation (LOOCV) loss as an estimate of the generalization loss. The following lemma shows how the expected generalization gap can be approximated by $\mathbb{E}[\bar{\Delta}_n]$.

Lemma 3.

Assume that the expected generalization loss $\mathbb{E}[R(S_n, t)]$ is non-increasing in $n$, the expected training loss $\mathbb{E}[R_{\text{train}}(S_n, t)]$ is non-decreasing in $n$, and the expected generalization gap $\mathbb{E}[\delta R(S_n, t)]$ is non-negative for all $n, t$. Then

\[
\mathbb{E}[\delta R(S_n, t)] \leq \mathbb{E}[\bar{\Delta}_n(t)] \leq \mathbb{E}[\delta R(S_{n-1}, t)].
\]

If we also have $\mathbb{E}[\delta R(S_n, t)] / \mathbb{E}[\delta R(S_{n-1}, t)] \to 1$ as $n \to \infty$, then

\[
\mathbb{E}[\delta R(S_n, t)] = \mathbb{E}[\bar{\Delta}_n(t)] + o\left(\mathbb{E}[\delta R(S_n, t)]\right).
\]

The concentration of $\bar{\Delta}_n(t)$ around $\mathbb{E}[\bar{\Delta}_n(t)]$ can also be guaranteed if algorithmic stability is assumed (see Lemma 22). Hence, the expected generalization gap $\mathbb{E}_{S_n}[\delta R(S_n, t)]$ can be well approximated by the averaged loss difference $\bar{\Delta}_n(t)$ under certain conditions. See Section 3.1, Table S.1 for numerical results on the generalization gap and the averaged loss difference. We will next study the evolution of $\bar{\Delta}_n(t)$.
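As a concrete illustration of Eq. 5, the sketch below computes the averaged loss difference by brute-force leave-one-out retraining. It is only a toy under stated assumptions: a linear predictor with squared loss, plain gradient descent with a small step size standing in for gradient flow, and synthetic data. The helper names `train_gd` and `averaged_loss_difference` are hypothetical and not taken from the paper's code.

```python
import numpy as np

def train_gd(X, y, steps=500, lr=0.05):
    """Gradient descent on the averaged squared loss 0.5*(x.w - y)^2, starting from w = 0."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        w -= lr * X.T @ (X @ w - y) / len(y)
    return w

def averaged_loss_difference(X, y, steps=500, lr=0.05):
    """Eq. 5: mean over i of loss(w^{-i}, z_i) - loss(w, z_i)."""
    n = len(y)
    w_full = train_gd(X, y, steps, lr)
    diffs = np.empty(n)
    for i in range(n):
        keep = np.arange(n) != i                      # S_n with sample i removed
        w_loo = train_gd(X[keep], y[keep], steps, lr)
        loss_loo = 0.5 * (X[i] @ w_loo - y[i]) ** 2   # held-out loss on z_i
        loss_full = 0.5 * (X[i] @ w_full - y[i]) ** 2 # training loss on z_i
        diffs[i] = loss_loo - loss_full
    return diffs.mean()

# Synthetic regression data (assumed purely for illustration).
rng = np.random.default_rng(0)
n, d = 40, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.5 * rng.normal(size=n)

delta_bar = averaged_loss_difference(X, y)
delta_bar_init = averaged_loss_difference(X, y, steps=0)
print(delta_bar, delta_bar_init)
```

With zero training steps all trajectories share the same weights, so the averaged loss difference is exactly zero, which mirrors the zero generalization gap at initialization. The cost of this brute-force version is $n + 1$ training runs, which is why a tractable description of the dynamics of $\bar{\Delta}_n(t)$ is useful.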

3.2 Evolution of the averaged loss difference

The pointwise loss difference $\Delta_n^{-i}$ describes the difference of two trajectories with slightly perturbed drifts. By deriving differential equations for the evolution of $\Delta_n^{-i}$ and $\bar{\Delta}_n$, we analyze the contraction and perturbation of the trajectories in a way that is non-uniform in both time and space, distinguishing it from classical contraction theory. We first give the following lemma for $\Delta_n^{-i}$.

Lemma 4.

For any loss function $\ell(w, z)$ that is differentiable in $w$ for all $z$,

\[
\frac{\mathrm{d}\Delta_n^{-i}(t)}{\mathrm{d}t} = -c_n^{-i}(t)\, \Delta_n^{-i}(t) + \epsilon_n^{-i}(t),
\]

where the pointwise contraction factor $c_n^{-i}(t)$ is given by

\[
c_n^{-i}(t) = \frac{\nabla \ell(w, z_i) \cdot \nabla \bar{\ell}(w, S_n^{-i}) \,\big|_{w_n(t)}^{w_n^{-i}(t)}}{\Delta_n^{-i}(t)},
\]

and the pointwise perturbation factor $\epsilon_n^{-i}(t)$ is given by

\[
\epsilon_n^{-i}(t) = \nabla \ell(w, z_i) \cdot \left( \nabla \bar{\ell}(w, S_n) - \nabla \bar{\ell}(w, S_n^{-i}) \right) \Big|_{w_n(t)}.
\]

We should note that the lemma can be extended to any piecewise differentiable loss if we define the gradient $\nabla \ell(w, z)$ at each non-differentiable point to be any constant vector with bounded norm. This covers all of the commonly used architectures and activation functions. To give some intuition for the lemma, the contraction factor $c_n^{-i}(t)$ represents a force that pulls two trajectories with the same drift but different values at time $t$ closer together under the loss function $\ell(w, z_i)$, while the perturbation factor quantifies the difference between the two trajectories at time $t$ induced by the gradient divergence $\nabla \bar{\ell}(w, S_n) - \nabla \bar{\ell}(w, S_n^{-i})$.
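The identity in Lemma 4 can be checked numerically on a toy problem. The sketch below assumes a linear predictor with squared loss and an Euler discretisation of the two gradient flows (on $S_n$ and on $S_n^{-i}$); the data and sizes are hypothetical. It evaluates the right-hand side $-c_n^{-i}\Delta_n^{-i} + \epsilon_n^{-i}$ from the formulas of the lemma and compares it with a central finite difference of $\Delta_n^{-i}$ along the two flow velocities.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, i = 20, 3, 0                       # hypothetical sizes; i indexes the held-out sample
X = rng.normal(size=(n, d))
y = rng.normal(size=n)
keep = np.arange(n) != i
X_loo, y_loo = X[keep], y[keep]          # S_n^{-i}

def loss_i(w):                           # loss on z_i
    return 0.5 * (X[i] @ w - y[i]) ** 2

def grad_i(w):                           # gradient of the loss on z_i
    return (X[i] @ w - y[i]) * X[i]

def grad_mean(w, Xs, ys):                # gradient of the averaged loss on a dataset
    return Xs.T @ (Xs @ w - ys) / len(ys)

# Euler discretisation of the two gradient flows from a shared initialisation.
dt = 1e-3
w, w_loo = np.zeros(d), np.zeros(d)
for _ in range(500):
    w, w_loo = (w - dt * grad_mean(w, X, y),
                w_loo - dt * grad_mean(w_loo, X_loo, y_loo))

delta = loss_i(w_loo) - loss_i(w)

# Right-hand side of Lemma 4.
numerator = (grad_i(w_loo) @ grad_mean(w_loo, X_loo, y_loo)
             - grad_i(w) @ grad_mean(w, X_loo, y_loo))
c = numerator / delta
eps = grad_i(w) @ (grad_mean(w, X, y) - grad_mean(w, X_loo, y_loo))
rhs = -c * delta + eps

# Left-hand side: central difference of delta along the flow velocities.
h = 1e-4
v, v_loo = -grad_mean(w, X, y), -grad_mean(w_loo, X_loo, y_loo)
lhs = ((loss_i(w_loo + h * v_loo) - loss_i(w + h * v))
       - (loss_i(w_loo - h * v_loo) - loss_i(w - h * v))) / (2 * h)
print(lhs, rhs, c, eps)
```

Since $c_n^{-i}\Delta_n^{-i}$ is just the numerator of the contraction factor, the check amounts to the chain rule applied to both trajectories; the split into a contraction term and a perturbation term is what the rest of the section builds on.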

Remark 5 (Deviations from classical contraction theory).

In classical contraction theory, $\alpha$ and $\overline{b}$ provide a uniform contraction rate and perturbation magnitude over time and trajectories (Theorems 1 and 2). In comparison, $c_n^{-i}(t)$ and $\epsilon_n^{-i}(t)$ in Lemma 4 are derived directly from the evolution of $\Delta_n^{-i}(t)$, which describes only the contraction and perturbation of $w_n(t)$ and $w_n^{-i}(t)$, and they vary with time. This non-uniformity allows for a more refined analysis of gradient flow for neural networks. Indeed, the energy landscape may not be uniformly good in the entire weight space, but it could be benign along most of the training trajectory. Our development in this paper from here on will therefore diverge significantly from the generalization bounds derived under uniform assumptions on the energy landscape (Kozachkov et al., 2023; Charles and Papailiopoulos, 2018).

By averaging the numerator and the denominator of $c_n^{-i}(t)$ separately over $i$, and averaging $\epsilon_n^{-i}(t)$ over $i$ in Lemma 4, we obtain the following equation for the averaged loss difference $\bar{\Delta}_n(t)$:

\[
\frac{\mathrm{d}\bar{\Delta}_n(t)}{\mathrm{d}t} = -\bar{c}_n(t)\, \bar{\Delta}_n(t) + \bar{\epsilon}_n(t). \tag{6}
\]

The solution of this differential equation can be written in the integral form as

\[
\bar{\Delta}_n(t) = \int_0^t \bar{\epsilon}_n(s) \exp\left( \int_s^t -\bar{c}_n(u)\, \mathrm{d}u \right) \mathrm{d}s \tag{7}
\]

with the assumption that $w_n(0) = w_n^{-i}(0)$ for all $i$. Here $\bar{c}_n(t)$ is defined to be the averaged contraction factor

\[
\bar{c}_n(t) = \frac{\frac{1}{n} \sum_{i=1}^{n} \nabla \ell(w, z_i) \cdot \nabla \bar{\ell}(w, S_n^{-i}) \,\big|_{w_n(t)}^{w_n^{-i}(t)}}{\bar{\Delta}_n(t)}, \tag{8}
\]

and $\bar{\epsilon}_n(t)$ is the averaged perturbation factor

\[
\bar{\epsilon}_n(t) = \frac{\operatorname{tr} \hat{\Sigma}_n(t)}{n-1}, \quad \hat{\Sigma}_n(t) = \operatorname{Cov}_{z \sim \text{Unif}(S_n)} \nabla \ell(w_n(t), z), \tag{9}
\]

where $\hat{\Sigma}_n(t)$ represents the covariance matrix of $\nabla \ell(w_n(t), z)$ for $z$ sampled uniformly from the dataset $S_n$. We should note that $\bar{\epsilon}_n(t)$ is a statistic that depends only on the training samples, while $\bar{c}_n(t)$ depends on both the training samples and the held-out test samples. Note that by taking the expectation of $\bar{\epsilon}_n$, and of the numerator and denominator of $\bar{c}_n$, we get the evolution of $\mathbb{E}[\bar{\Delta}_n(t)]$, which is a better proxy for the generalization gap. See Section A.3 for details.
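The step from the pointwise perturbation factors of Lemma 4 to the trace formula in Eq. 9 follows from $\nabla \bar{\ell}(w, S_n) - \nabla \bar{\ell}(w, S_n^{-i}) = (g_i - \bar{g})/(n-1)$, where $g_i = \nabla \ell(w, z_i)$ and $\bar{g}$ is their mean. The sketch below checks this numerically. The matrix `G` of per-sample gradients is filled with random numbers purely for illustration; in practice its rows would be the gradients at $w_n(t)$.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 50, 8
G = rng.normal(size=(n, p))              # row i: per-sample gradient g_i (hypothetical values)

g_full = G.mean(axis=0)                  # gradient of the averaged loss on S_n
eps_pointwise = np.empty(n)
for i in range(n):
    g_loo = (G.sum(axis=0) - G[i]) / (n - 1)   # gradient of the averaged loss on S_n^{-i}
    eps_pointwise[i] = G[i] @ (g_full - g_loo) # pointwise perturbation factor of Lemma 4

# Covariance under the uniform distribution on S_n, i.e. with 1/n normalisation.
Sigma_hat = np.cov(G, rowvar=False, bias=True)
eps_bar = np.trace(Sigma_hat) / (n - 1)        # Eq. 9

print(eps_pointwise.mean(), eps_bar)
```

The two printed numbers agree up to floating-point error. Note the `bias=True` flag: the covariance in Eq. 9 is taken with respect to the empirical distribution, so it uses the $1/n$ normalisation, and the $1/(n-1)$ factor appears separately.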

Remark 6 (Classical contraction theory with uniform bounds on contraction and perturbation).

With uniform guarantees $\bar{\epsilon}_n(t) \leq \epsilon^*$ and $\bar{c}_n(t) \geq c^*$ for all $t$, for some positive $\epsilon^*, c^*$, we can bound the integral in Eq. 7 to see that

\[
\bar{\Delta}_n(t) \leq \frac{\epsilon^*}{c^*} \left( 1 - \exp(-c^* t) \right),
\]

which recovers a bound similar to the one in Theorem 2.
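The bound in Remark 6 can be illustrated by integrating Eq. 6 numerically. In the sketch below, the time profiles `c_bar` and `eps_bar` and the constants `c_star` and `eps_star` are made-up choices that satisfy the uniform guarantees; they are not measured from any network. A forward Euler scheme is assumed, with the perturbation profile kept strictly below `eps_star` so that discretisation error does not mask the comparison.

```python
import numpy as np

c_star, eps_star = 2.0, 0.5              # hypothetical uniform bounds
dt, T = 1e-3, 5.0
ts = np.arange(0.0, T + dt, dt)

def c_bar(t):                            # any profile with c_bar(t) >= c_star
    return c_star * (1.5 + 0.5 * np.sin(3.0 * t))

def eps_bar(t):                          # any profile with 0 <= eps_bar(t) <= eps_star
    return eps_star * (0.6 + 0.3 * np.cos(t))

# Forward Euler integration of Eq. 6 with delta(0) = 0 (shared initialisation).
delta = np.zeros_like(ts)
for k in range(len(ts) - 1):
    delta[k + 1] = delta[k] + dt * (-c_bar(ts[k]) * delta[k] + eps_bar(ts[k]))

# Closed-form bound obtained from the uniform guarantees.
bound = eps_star / c_star * (1.0 - np.exp(-c_star * ts))
print(delta[-1], bound[-1])
```

The simulated $\bar{\Delta}_n(t)$ stays below the closed-form bound at every time step, and the gap between the two indicates how much a time-varying analysis can gain over uniform constants.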

Remark 7 (Comparing trajectories in terms of their loss vs. weight space difference).

Richards and Kuzborskij (2021); Akbari et al. (2021) analyze the difference in the algorithm's output in the weight space when one sample is replaced, and derive generalization bounds using a Lipschitz assumption on the loss function. However, in most cases, a uniform Lipschitz constant is a poor description of most of the weight space for deep networks. In such cases, the difference in the weights does not provide a tight estimate of the difference of the predictions, and this is the key reason for loose generalization bounds from this kind of analysis. By comparing $w_n^{-i}(t)$ and $w_n(t)$ directly in terms of the loss $\ell(w, z_i)$ in Lemma 4, instead of in the weight space (with or without a modified Euclidean metric) as used in Theorem 1, we can achieve a tighter estimate of the evolution of the generalization gap.

Remark 8 (Relationship to information theoretic generalization bounds).

In Negrea et al. (2019); Neu et al. (2021); Banerjee et al. (2022), the authors derive generalization bounds controlled by the sum of the trace of the gradient covariance along the training trajectory, $\sum_t \operatorname{tr} \hat{\Sigma}_n(t)$. Intuitively, this summation tells us about the size of the tube of trajectories in loss space that arises from training on different datasets. The worse the estimate of this tube, the worse the generalization bound. Our expression in Eq. 7 provides a more general and tighter formulation, where the damping factor $\exp\left( \int_s^t -\bar{c}_n(u)\, \mathrm{d}u \right)$ corrects for the size of this tube. A positive contraction factor brings together more quickly the two trajectories that are being trained on slightly different datasets. Furthermore, our analysis applies to deterministic algorithms, unlike previous works on information-theoretic bounds (Xu and Raginsky, 2017; Mou et al., 2018; Futami and Fujisawa, 2023), which hold only for randomized algorithms because the proof relies on the non-expansiveness of the Kullback-Leibler divergence of non-singular distributions.

Remark 9 (Some intuition on the contraction factor).

If we expand the contraction factor in Lemma 4 using the first-order Taylor expansion in both its numerator and denominator, and by approximating $\ell(w, S_{(m)})$ and $\bar{\ell}(w, S_n^{(m)})$ by the loss on the full dataset $\bar{\ell}(w, S_n)$, we see that

(10)

This is a generalized version of the Rayleigh quotient $x^{\top} A y / x^{\top} y$, where $x \equiv \nabla \bar{\ell}(w_n(t), S_n)$, $y \equiv \mathbb{E}_{(m)}\left[ w_n^{-(m)}(t) - w_n(t) \right]$, and $A \equiv \nabla^2 \bar{\ell}(w_n(t), S_n)$. Intuitively, positive contraction means that applying the Hessian does not flip the sign of the inner product (equivalently, of the cosine of the angle) between the gradient and the averaged difference of trajectories. Fig. 1 compares the true contraction factor and the full-gradient approximation Eq. 10. We can see that the approximated contraction factor is positive (which indicates contractive dynamics) and that it is also close to the true contraction factor. This suggests that the Hessian of the training loss is positive definite along the directions of the gradient and the averaged difference of trajectories, for most of the training time. See Section 4 for the batch version of the contraction factor $\bar{c}_n(t)$, and Section A.8 for the detailed calculations of Eq. 10.

Figure 1: The contraction factor calculated through its analytical expression in Eq. 8 (orange) compared to its approximation using Eq. 10 (blue) for FC trained on MNIST with two selected classes, $n = 1000$, $m = 100$.

3.3 Evolution of the residual and perturbation

In Section 3.2, we have shown that the evolution of $\bar{\Delta}_n$ is controlled by the averaged perturbation $\bar{\epsilon}_n$ and the averaged contraction $\bar{c}_n$, and that $\bar{\epsilon}_n$ is closely related to the trace of the covariance of gradients. We will now introduce the notion of a “residual”. We further show how it relates to the perturbation factor, from which we derive the evolution of $\bar{\epsilon}_n(t)$.

Let $r(w, z) = \frac{\mathrm{d}\ell(w, z)}{\mathrm{d}f(w, z)} \in \mathcal{Y}$ denote the gradient of the loss function with respect to the predictor $f$. Let $r_i(t) \equiv r(w_n(t), z_i)$ denote the predictor gradient on $z_i$, evaluated at the weights $w_n(t)$. We define the residual on dataset $S_n$ at time $t$ to be

\[
\vec{r}_n(t) = \frac{1}{\sqrt{n}} [r_1(t), \dots, r_n(t)]^{\top} \in \mathcal{Y}^n. \tag{11}
\]

The residual is the collection of loss-predictor gradients on the dataset $S_n$. It effectively describes the quality of the weights at time $t$ and indicates the direction in which training progresses in the predictor space. Intuitively, it represents the part of the “task” that remains to be fitted at time $t$. As a special case, if we consider the squared loss $\ell(y, y') = \frac{1}{2}(y - y')^2$ for $y, y' \in \mathbb{R}$, the residual is the normalized displacement vector from the predictor to the target, i.e., $\vec{r}_n(t) = \frac{1}{\sqrt{n}} \left( \vec{f} - \vec{y} \right)$, where $\vec{y} \equiv [y_1, \dots, y_n]^{\top}$, and $\vec{f} \equiv [f(w_n(t), x_1), \dots, f(w_n(t), x_n)]^{\top}$. If we initialize the predictor such that $f(w_n(0), x_i) = 0$ for all $i \in [n]$, then the $\ell_2$-norm of the residual is largest at initialization, where it equals $\|\vec{r}_n(0)\|_2$, and vanishes at interpolation under the gradient flow in Eq. 1. The definition of the residual generalizes the displacement vector in the squared loss case, and can be applied to any loss function with global minimum 0. The factor $1/\sqrt{n}$ will be justified in Remark 20.
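For the squared-loss special case, the residual of Eq. 11 is easy to compute directly. The sketch below assumes a linear predictor $f(w, x) = w^{\top} x$ with more parameters than samples (so that interpolation is attainable), zero initialization, and gradient descent with a small step size as a stand-in for the gradient flow in Eq. 1. The data are synthetic and the sizes are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 10, 30                            # hypothetical sizes; d > n allows interpolation
X = rng.normal(size=(n, d))
y = rng.normal(size=n)

def residual(w):
    """Eq. 11 for the squared loss: (f - y) / sqrt(n)."""
    return (X @ w - y) / np.sqrt(n)

def mean_loss(w):
    """Averaged squared loss on S_n."""
    return 0.5 * np.mean((X @ w - y) ** 2)

w = np.zeros(d)                          # f(w(0), x_i) = 0 for all i
w0 = w.copy()
norms = [np.linalg.norm(residual(w))]
lr = 0.05
for _ in range(5000):
    w -= lr * X.T @ (X @ w - y) / n      # gradient step on the averaged loss
    norms.append(np.linalg.norm(residual(w)))
norms = np.array(norms)

print(norms[0] ** 2, 2 * mean_loss(w0), norms[-1])
```

In this toy setting the squared norm of the residual equals twice the averaged training loss, independently of $n$, and the norm decays monotonically from its value at initialization towards zero as the predictor interpolates the data.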

The evolution of $\vec{r}_n(t)$ is governed by the following equation, derived from the gradient flow in Eq. 1.

\begin{align}
\frac{\mathrm{d}\vec{r}_n(t)}{\mathrm{d}t} &= -\frac{1}{n}P_n(t)\,\vec{r}_n(t), \tag{12}\\
P_n(t) &= \left[\nabla r(w_n(t),z_i)^{\top}\nabla f(w_n(t),x_j)\right]_{i,j\in[n]}. \notag
\end{align}

This is a linear time-varying ordinary differential equation. In general, its solution can be written as

\[
\vec{r}_n(t)=\Omega_n(t_0,t)\,\vec{r}_n(t_0), \tag{13}
\]

where $\Omega_n(t_0,t)$ is called the propagator. The numerical approximation of $\Omega_n(t_0,t)$ will be discussed in Section B.4.
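As a rough illustration of how a propagator can be approximated numerically, one can integrate the matrix equation $\mathrm{d}\Omega/\mathrm{d}t=-\frac{1}{n}P(t)\,\Omega$ with $\Omega(t_0,t_0)=I$ using a Runge-Kutta scheme. This is a generic sketch and not necessarily the scheme of Section B.4. The function name `propagator`, the toy matrix $P(t)=(1+t)S$ and the step count are assumptions made for this example; the toy case is chosen because its propagator has a closed form to compare against.

```python
import numpy as np

def propagator(P, n, t0, t1, steps=2000):
    """Approximate Omega(t0, t1) for dr/dt = -(1/n) P(t) r using classical RK4."""
    omega = np.eye(P(t0).shape[0])          # Omega(t0, t0) = I
    h = (t1 - t0) / steps
    rhs = lambda t, om: -(P(t) @ om) / n    # right-hand side of the matrix ODE
    t = t0
    for _ in range(steps):
        k1 = rhs(t, omega)
        k2 = rhs(t + h / 2, omega + h / 2 * k1)
        k3 = rhs(t + h / 2, omega + h / 2 * k2)
        k4 = rhs(t + h, omega + h * k3)
        omega = omega + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t += h
    return omega

# Toy time-varying matrix P(t) = (1 + t) S with S symmetric PSD (hypothetical).
rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n))
S = A @ A.T
P = lambda t: (1.0 + t) * S

omega_num = propagator(P, n, 0.0, 1.0)

# Because P(t) commutes with itself at all times, the exact propagator is
# exp(-(1/n) * S * int_0^1 (1 + t) dt) = exp(-1.5 S / n).
evals, evecs = np.linalg.eigh(S)
omega_exact = evecs @ np.diag(np.exp(-1.5 * evals / n)) @ evecs.T
max_err = np.abs(omega_num - omega_exact).max()
```

For a network, $P_n(t)$ at different times does not commute in general, so no closed form is available and the step-by-step product is what carries the time ordering.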

Our next goal is to show that the averaged perturbation factor $\bar{\epsilon}_n$ is controlled by the residual. We will do so using the following lemma.

Lemma 10.

The trace of the gradient covariance $\hat{\Sigma}_n(t)$ can be decomposed in terms of two matrices $M_n,H_n\in\mathcal{Y}^{n}\times\mathcal{Y}^{n}$ as

\begin{align}
\operatorname{tr}\hat{\Sigma}_n(t) &= \vec{r}_n(t)^{\top}\left(M_n(t)-\frac{H_n(t)}{n}\right)\vec{r}_n(t), \tag{14}\\
M_n(t) &= \operatorname{diag}\left(\nabla f(w,x_1)^{\top}\nabla f(w,x_1),\dots,\nabla f(w,x_n)^{\top}\nabla f(w,x_n)\right)\Big|_{w_n(t)}, \tag{15}\\
H_n(t) &= \left[\nabla f(w_n(t),x_i)^{\top}\nabla f(w_n(t),x_j)\right]_{i,j\in[n]}. \notag
\end{align}
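The decomposition in Eq. 14 can be checked numerically in a small case. The sketch below uses a linear predictor with the squared loss, and it assumes that $\hat{\Sigma}_n$ is the empirical covariance of the per-sample loss gradients with $1/n$ normalization; that definition lies outside this passage, so it is an assumption of the example, as are the variable names and the random data.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 3
X = rng.standard_normal((n, d))     # rows are the inputs x_i
y = rng.standard_normal(n)          # targets
w = rng.standard_normal(d)          # current weights

# Linear predictor f(w, x) = w^T x, hence grad_w f(w, x_i) = x_i.
grad_f = X
f = X @ w
r = (f - y) / np.sqrt(n)            # residual for the squared loss

# Per-sample loss gradients: grad_w l_i = (f_i - y_i) * grad_w f_i.
grad_loss = (f - y)[:, None] * grad_f

# Assumed definition: (1/n)-normalized empirical covariance of the gradients.
centered = grad_loss - grad_loss.mean(axis=0)
sigma_hat = centered.T @ centered / n

M = np.diag(np.sum(grad_f**2, axis=1))   # Eq. 15: diagonal of squared gradient norms
H = grad_f @ grad_f.T                    # Gram matrix of predictor gradients

lhs = np.trace(sigma_hat)
rhs = r @ (M - H / n) @ r
```

Under these assumptions `lhs` and `rhs` agree up to floating-point error: the $M_n$ term collects the mean squared gradient norm and the $H_n/n$ term subtracts the squared norm of the mean gradient.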
Remark 11.

Let us emphasize that all three quantities, $P_n(t)$, $M_n(t)$ and $H_n(t)$, are elements of $\mathcal{Y}^{n}\times\mathcal{Y}^{n}$. For regression problems, we might have $\mathcal{Y}\subseteq\mathbb{R}$, in which case they are simply matrices in $\mathbb{R}^{n\times n}$. For classification problems with $C$ categories, $\mathcal{Y}\subset\mathbb{R}^{C}$, and therefore these three quantities are four-dimensional tensors. But we can interpret them as elements of $\mathbb{R}^{nC\times nC}$. This amounts to vectorizing the tensor as a matrix. Just like a matrix may be reshaped into a vector, we have reshaped a tensor into a matrix.
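The reshaping described in this remark can be made concrete with a few lines of NumPy. The sketch below is only an illustration: the array `J` stands in for the predictor gradients of a hypothetical $C$-class model with $p$ weights, and all names and sizes are assumptions of the example.

```python
import numpy as np

rng = np.random.default_rng(0)
n, C, p = 3, 2, 5
# J[i, a, :] plays the role of the gradient of output a on sample i w.r.t. the weights.
J = rng.standard_normal((n, C, p))

# Four-dimensional tensor with entries H4[i, a, j, b] = <J[i, a], J[j, b]>.
H4 = np.einsum("iap,jbp->iajb", J, J)

# Flatten the index pairs (i, a) and (j, b) to obtain an (nC x nC) matrix.
H2 = H4.reshape(n * C, n * C)

# A residual with one C-dimensional entry per sample, and its vectorization.
r = rng.standard_normal((n, C))
r_vec = r.reshape(-1)

quad_tensor = np.einsum("ia,iajb,jb->", r, H4, r)   # contraction in tensor form
quad_matrix = r_vec @ H2 @ r_vec                    # same quantity in matrix form
```

Both quadratic forms coincide, and `H2` equals the ordinary Gram matrix of the $nC$ flattened gradient rows, which is why the matrix statements carry over to classification unchanged.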

Eq. 13 and the above lemma together give

\[
\operatorname{tr}\hat{\Sigma}_n(t)=\vec{r}_n(0)^{\top}\,\Omega_n(t)^{\top}\left(M_n(t)-\frac{H_n(t)}{n}\right)\Omega_n(t)\,\vec{r}_n(0), \tag{16}
\]

where we denote $\Omega_n(0,t)$ as $\Omega_n(t)$ for short. In Eq. 14, the term $M_n-H_n/n$ pertains to the covariance of the predictor on the training dataset $S_n$. We can see that the residual $\vec{r}_n(t)$ controls the magnitude of the gradients $\nabla\ell(w,z_i)$ for $i\in[n]$, and hence that of the covariance. For networks that train quickly, the residual norm $\|\vec{r}_n(t)\|_2$ vanishes quickly, leading to a smaller accumulation of the perturbation term $\bar{\epsilon}_n$ above, and hence a smaller generalization gap. This explains the folklore theorem “networks that generalize well also train quickly”.

Remark 12 (Relationship to weight norm based bound).

Arora et al. (2019) also study the residual dynamics to obtain an estimate of the norm of the eventual weights that can be reached by gradient descent. The authors derive a weight-norm-based bound that uses the results of Bartlett et al. (2017) on Rademacher complexity. Similar ideas are adopted in Allen-Zhu et al. (2019) and Cao and Gu (2019), which analyze SGD and online learning of fully connected neural nets respectively. Liu et al. (2022) derive a weight-norm bound under uniform-LGI conditions for general optimization problems. In contrast, we use the evolution of the residual in Eq. 12 to calculate the trace of the gradient covariance. This is directly related to the difference in the loss of networks trained on perturbed datasets, as opposed to the difference in their weights. Our analysis is conducted directly in the prediction space and provides a more direct and refined characterization of generalization for general neural networks.

3.4 Effective Gram matrix for neural networks

We next derive an expression for the averaged loss difference $\bar{\Delta}_n(t)$ in terms of a certain quadratic form of the initial residual and an “effective Gram matrix”, by analyzing the evolution of $\bar{\Delta}_n(t)$ and $\vec{r}_n(t)$ during training. The following theorem combines the solution for $\bar{\Delta}_n(t)$ in Eq. 7 and for $\vec{r}_n(t)$ in Eq. 13, along with the decomposition of $\hat{\Sigma}_n(t)$ in Eq. 14.

Theorem 13.

Assume that the evolution of $w_n(t)$ and $w_n^{-i}(t)$ follows Eq. 1 and the loss function $\ell(w,z)$ is smooth in $w$ for every $z\in\mathcal{Z}$. We have

\[
\bar{\Delta}_n(t)=\vec{r}_n(0)^{\top}K_n(0,t)\,\vec{r}_n(0), \tag{17}
\]

where

\[
K_n(0,t)=\frac{1}{n-1}\int_0^t \Omega_n(s)^{\top}\left(M_n(s)-\frac{H_n(s)}{n}\right)\Omega_n(s)\,\exp\left(-\int_s^t \bar{c}_n(u)\,\mathrm{d}u\right)\mathrm{d}s \tag{18}
\]

is positive semi-definite. Let

\[
K_n\triangleq\lim_{t\to\infty}K_n(0,t)
\]

when the limit exists, then we have

\[
\bar{\Delta}_n(\infty)\triangleq\lim_{t\to\infty}\bar{\Delta}_n(t)=\vec{r}_n(0)^{\top}K_n\,\vec{r}_n(0).
\]

We call $K_n$ the effective Gram matrix of a neural network.

We call $K_n(0,t)$ the “effective Gram matrix” because it is a weighted average of Gram matrices of the form $V^{\top}V$, where $V=\sqrt{\frac{M_n(s)-H_n(s)/n}{n-1}}\;\Omega_n(s)$. (In linear algebra, the Gram matrix of a set of vectors $v_1,\dots,v_n$ is given by $V^{\top}V$, where $v_1,\dots,v_n$ are the columns of the matrix $V$.) We next give conditions that guarantee the existence of $\lim_{t\to\infty}K_n(0,t)$.

Lemma 14.

Let $m(t)=\frac{1}{n-1}\left\|M_n(t)-\frac{H_n(t)}{n}\right\|_2$ and $\omega(t)=\exp\left(-\frac{2}{n}\int_0^t\lambda_{\min}(s)\,\mathrm{d}s\right)$. Let $\lambda_{\max}(t)$ and $\lambda_{\min}(t)$ be the largest and smallest eigenvalues of $(P_n(t)+P_n(t)^{\top})/2$ respectively. If

  (i) $\lim_{t\to\infty}\int_0^t\omega(s)\,m(s)\,\mathrm{d}s$ exists,

  (ii) there exists a constant $B>0$ such that $|\omega(t)\,m(t)|\leq B$ for all $t$, and

  (iii) the contraction factor $\bar{c}_n(t)\geq 0$ for all $t\geq 0$,

then $\lim_{t\to\infty}K_n(0,t)$ exists in $\ell_2$-norm.

Remark 15.

Sometimes the effective Gram matrix calculated from the propagator derived from $P_n(t)$ does not converge. In this case, we can create a perturbed version of $P_n(t)$ with a controlled $\lambda_{\min}(t)$ such that the conditions of Lemma 14 are satisfied. This guarantees the convergence of $\lim_{t\to\infty}K_n(0,t)$ while preserving the trajectory of $\vec{r}_n(t)$ given $\vec{r}_n(0)$. For example, in Section 3.5 we construct $P_n^{\varepsilon}(t)$ as a perturbed version of $P_n(t)$.

We should note that, to analyze the generalization gap meaningfully via the relation between the residual and the effective Gram matrix, $K_n(0,t)$ should correspond to a trajectory that fits the data by time $t$. This is true only when $\bar{\ell}(w_n(t),S_n)=0$, which by Eq. 12 also implies that $\bar{\ell}(w_n(t'),S_n)=0$ for all $t'\geq t$. Hence we only consider $K_n=\lim_{t\to\infty}K_n(0,t)$ in the interpolating regime where $\lim_{t\to\infty}\bar{\ell}(w_n(t),S_n)=0$, instead of a finite-time $K_n(0,t)$. This idea is also reflected in Arora et al. (2019), where the authors consider the NTK regime for infinite time, in which case the training data is fitted perfectly.

Remark 16 (Data and architecture dependent generalization bound).

The quadratic form $\vec{r}_n(0)^{\top}K_n\,\vec{r}_n(0)$ in Theorem 13 gives a data and architecture dependent measure of complexity that characterizes the generalization gap of general deep neural networks. We will also see in the experimental section that this faithfully captures the true generalization gap. Eigenvalues of $K_n$ represent the relative contribution to the generalization gap accumulated in the different subspaces during training. If the initial residual (roughly, the distance to the target) predominantly projects onto the subspace of $K_n$ with small eigenvalues, the training process is benign, resulting in a small eventual generalization gap (as shown in Section 4.2). This is therefore one of the key quantities that we will track in numerical experiments on different architectures and datasets in Section 4.

Remark 17 (Generalization gap of kernel machines).

The generalization loss in kernel ridge regression (Rakhlin and Liang, 2020; Mallinar et al., 2022) can be expressed in terms of quantities that resemble ours, namely, the alignment of the residuals with the Gram matrix, $r(0)^{\top}K\,r(0)$ (Arora et al., 2019; Jacot et al., 2020). Our effective Gram matrix generalizes this type of complexity measure to arbitrary deep neural networks and loss functions, going beyond two-layer neural networks with infinitely many neurons and squared loss. However, unlike kernel ridge regression, where the Gram matrix is derived from a fixed kernel that directly recovers the target function, the effective Gram matrix $K_n$ in our setting varies for different datasets and training regimes and does not necessarily coincide with any fixed kernel.

3.5 An example calculation of the (effective) Gram matrix for linear regression

Assume that the sample space is supported on two points, i.e., $\mathcal{Z}=\{(x_1,y_1),(x_2,y_2)\}$, with orthonormal inputs: $x_1^{\top}x_2=0$ and $\|x_1\|_2=\|x_2\|_2=1$. We choose the predictor to be $f(w,x)=w^{\top}x$, and the loss function to be $\ell(y',y)=(y'-y)^{2}/2$. We therefore have $\ell(w,z)=(w^{\top}x-y)^{2}/2$ for $z=(x,y)$.
Consider the dataset $S_n$ with $n$ even, where $z_i=(x_1,y_1)$ when $i\leq n/2$ and $z_i=(x_2,y_2)$ when $i>n/2$. Assume that $w_n(0)=w_n^{-i}(0)=\vec{0}$ for all $i$, which ensures that the initial residual is simply the vector of ground-truth targets, $\vec{r}_n(0)=y$. The averaged contraction factor in Eq. 8 is

\[
\bar{c}_n(t)=\bar{c}:=\frac{n-2}{2(n-1)},
\]

and we have from Eq. 12 and Eq. 15 that

\[
M_n(t)=I_n,\quad H_n(t)=P_n(t)=\operatorname{diag}\left(\vec{1}\,\vec{1}^{\top},\ \vec{1}\,\vec{1}^{\top}\right),
\]

with $\vec{1}=[1,\dots,1]\in\mathbb{R}^{n/2}$. Note that $P_n(t)$ is not full rank when $n>2$. By Lemma 14, the convergence of $\lim_{t\to\infty}K_n(0,t)$ is largely controlled by the smallest eigenvalue of $P_n(t)$, which cannot be too small. Hence, to ensure convergence, we define a modified version of $P_n(t)$ with a small perturbation $\varepsilon(t)$ on its singular subspace, i.e., $P_n^{\varepsilon}(t)=U\Lambda^{\varepsilon}U^{\top}$, with $\Lambda^{\varepsilon}=\operatorname{diag}\left(n/2,\ n/2,\ n\varepsilon(t)/2,\dots,n\varepsilon(t)/2\right)$ and $U=[u_1,\dots,u_n]$, where

\[
u_1=\sqrt{\frac{2}{n}}\ [1,\dots,1,0,\dots,0],\quad u_2=\sqrt{\frac{2}{n}}\ [0,\dots,0,1,\dots,1].
\]
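The structure of $M_n$, $H_n$ and the leading eigenvectors $u_1,u_2$ in this example is easy to confirm numerically. The sketch below picks $n=6$ and the standard basis vectors of $\mathbb{R}^2$ as the two orthonormal inputs; both choices, and the variable names, are assumptions made only for illustration.

```python
import numpy as np

n = 6
half = n // 2
x1 = np.array([1.0, 0.0])   # any orthonormal pair would do; this one is hypothetical
x2 = np.array([0.0, 1.0])

# First n/2 samples use x1, the remaining n/2 use x2.
X = np.vstack([np.tile(x1, (half, 1)), np.tile(x2, (half, 1))])

# For f(w, x) = w^T x the predictor gradient at sample i is x_i.
M = np.diag(np.sum(X**2, axis=1))   # diagonal of squared gradient norms
H = X @ X.T                         # Gram matrix of predictor gradients

# Expected block structure diag(1 1^T, 1 1^T).
ones = np.ones((half, half))
block = np.block([[ones, np.zeros_like(ones)], [np.zeros_like(ones), ones]])

u1 = np.sqrt(2 / n) * np.concatenate([np.ones(half), np.zeros(half)])
u2 = np.sqrt(2 / n) * np.concatenate([np.zeros(half), np.ones(half)])

eigvals = np.sort(np.linalg.eigvalsh(H))[::-1]   # eigenvalues in descending order
```

With these choices `M` is the identity, `H` matches the block matrix, $u_1$ and $u_2$ are unit eigenvectors with eigenvalue $n/2$, and the remaining $n-2$ eigenvalues are zero, which is the rank deficiency that the perturbation $\varepsilon(t)$ is meant to repair.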

In this case, when $\varepsilon(t)\equiv 0$, we have $P_n^{\varepsilon}(t)\equiv P_n(t)$. The dynamics $\frac{\mathrm{d}\vec{r}_n(t)}{\mathrm{d}t}=-P_n^{\varepsilon}(t)\,\vec{r}_n(t)/n$ give the same trajectory of $\vec{r}_n(t)$ as Eq. 12, since $\vec{r}_n(0)=[y_1,\dots,y_1,y_2,\dots,y_2]\in\text{span}(u_1,u_2)$.
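The claim that the perturbation leaves the residual trajectory unchanged can be illustrated as follows. For simplicity the sketch uses a constant perturbation level `eps` rather than a time-varying $\varepsilon(t)$; since $P_n^{\varepsilon}(t)$ has the same eigenvectors at every time, the argument is the same. The values of `n`, `eps`, `t`, `y1`, `y2` and the QR-based completion of the basis are assumptions of this example.

```python
import numpy as np

n, eps, t = 6, 1e-2, 3.0
half = n // 2
u1 = np.sqrt(2 / n) * np.concatenate([np.ones(half), np.zeros(half)])
u2 = np.sqrt(2 / n) * np.concatenate([np.zeros(half), np.ones(half)])

# Complete (u1, u2) to an orthonormal basis of R^n with a QR factorization.
rng = np.random.default_rng(0)
U, _ = np.linalg.qr(np.column_stack([u1, u2, rng.standard_normal((n, n - 2))]))
U[:, 0], U[:, 1] = u1, u2   # QR may flip signs; restore the intended vectors

def flow(eigenvalues):
    """Propagator exp(-t * U diag(eigenvalues) U^T / n) for a constant matrix."""
    return U @ np.diag(np.exp(-eigenvalues * t / n)) @ U.T

lam_plain = np.array([n / 2, n / 2] + [0.0] * (n - 2))          # spectrum of P_n
lam_eps = np.array([n / 2, n / 2] + [n * eps / 2] * (n - 2))    # spectrum of P_n^eps

# Initial residual of the example: constant on each half of the dataset.
y1, y2 = 1.0, -2.0
r0 = np.concatenate([np.full(half, y1), np.full(half, y2)])

r_plain = flow(lam_plain) @ r0
r_eps = flow(lam_eps) @ r0

# A generic vector outside span(u1, u2) does feel the perturbation.
r_generic = rng.standard_normal(n)
gap_generic = np.abs(flow(lam_plain) @ r_generic - flow(lam_eps) @ r_generic).max()
```

Both flows map `r0` to $e^{-t/2}\,r_0$, because `r0` has no component in the perturbed subspace, whereas a generic initial vector evolves differently under the two matrices.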
By setting ε(t)=ε¯(1[0,1](t)+1[1,](t)/t2)𝜀𝑡¯𝜀subscript101𝑡subscript11𝑡superscript𝑡2\varepsilon(t)=\bar{\varepsilon}\left(1_{[0,1]}(t)+1_{[1,\infty]}(t)/t^{2}\right)italic_ε ( italic_t ) = over¯ start_ARG italic_ε end_ARG ( 1 start_POSTSUBSCRIPT [ 0 , 1 ] end_POSTSUBSCRIPT ( italic_t ) + 1 start_POSTSUBSCRIPT [ 1 , ∞ ] end_POSTSUBSCRIPT ( italic_t ) / italic_t start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ) with ε¯1much-less-than¯𝜀1\bar{\varepsilon}\ll 1over¯ start_ARG italic_ε end_ARG ≪ 1, the effective Gram matrix Kn(0,t)subscript𝐾𝑛0𝑡K_{n}(0,t)italic_K start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( 0 , italic_t ) can be calculated from Eq. 18 as

Kn(0,t)=UΛK(t)Usubscript𝐾𝑛0𝑡𝑈superscriptΛ𝐾𝑡superscript𝑈top\displaystyle K_{n}(0,t)=U\Lambda^{K}(t)U^{\top}italic_K start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( 0 , italic_t ) = italic_U roman_Λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT ( italic_t ) italic_U start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT

where ΛK(t)=[λ1K(t),,λnK(t)]superscriptΛ𝐾𝑡subscriptsuperscript𝜆𝐾1𝑡subscriptsuperscript𝜆𝐾𝑛𝑡\Lambda^{K}(t)=[\lambda^{K}_{1}(t),\dots,\lambda^{K}_{n}(t)]roman_Λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT ( italic_t ) = [ italic_λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( italic_t ) , … , italic_λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) ], and

λ1K(t)=λ2K(t)=Θ(exp(c¯t)),λ3K(t)==λnK(t)=Θ(1),formulae-sequencesubscriptsuperscript𝜆𝐾1𝑡subscriptsuperscript𝜆𝐾2𝑡Θ¯𝑐𝑡subscriptsuperscript𝜆𝐾3𝑡subscriptsuperscript𝜆𝐾𝑛𝑡Θ1\displaystyle\lambda^{K}_{1}(t)=\lambda^{K}_{2}(t)=\Theta(\exp(-\bar{c}t)),% \quad\lambda^{K}_{3}(t)=\dots=\lambda^{K}_{n}(t)=\Theta(1),italic_λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ( italic_t ) = italic_λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ( italic_t ) = roman_Θ ( roman_exp ( start_ARG - over¯ start_ARG italic_c end_ARG italic_t end_ARG ) ) , italic_λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT start_POSTSUBSCRIPT 3 end_POSTSUBSCRIPT ( italic_t ) = ⋯ = italic_λ start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) = roman_Θ ( 1 ) ,

indicating that the initial residual rn(0)subscript𝑟𝑛0\vec{r}_{n}(0)over→ start_ARG italic_r end_ARG start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( 0 ) lies in the subspace of Kn=limtKn(t)subscript𝐾𝑛subscript𝑡subscript𝐾𝑛𝑡K_{n}=\lim_{t\to\infty}K_{n}(t)italic_K start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT = roman_lim start_POSTSUBSCRIPT italic_t → ∞ end_POSTSUBSCRIPT italic_K start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) with zero eigenvalue.

Note that since $\mathcal{Z}$ is supported on only two points, gradient flow $w_{n}(t)$ trained on a dataset containing both these samples generalizes and achieves zero loss for any data distribution $D$ supported on $\mathcal{Z}$. This coincides with the calculation above, where the averaged loss difference $\bar{\Delta}_{n}(t)$, as predicted by the quadratic form $y^{\top}K_{n}y$ in Theorem 13, approaches zero as $t\to\infty$. The calculation holds regardless of what fraction of data in $S_{n}$ comes from either of the two points (so long as both are present). Our theorem correctly predicts that the generalization gap goes to zero. Now if we take an expectation, we have

\[ \mathbb{E}[y^{\top}K_{n}y]=\mathbb{E}[\bar{\Delta}_{n}(t)]=\Theta(2^{-n}) \]

because the dataset $S_{n}$ is supported on only one of the samples with probability $2^{-(n-1)}$. Theorem 13 is therefore providing a tight prediction of the generalization gap.
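The probability $2^{-(n-1)}$ can be checked by brute-force enumeration. The sketch below assumes that the two points are drawn i.i.d. with equal probability, which is what the stated value implies; the function name `prob_single_point` is hypothetical.

```python
from fractions import Fraction
from itertools import product


def prob_single_point(n: int) -> Fraction:
    """Exact probability that n i.i.d. uniform draws from two points all coincide."""
    # Enumerate all 2^n equally likely datasets (assumption: uniform over two points).
    draws = list(product((0, 1), repeat=n))
    # A dataset is "supported on one sample" when every draw is the same point.
    hits = sum(1 for d in draws if len(set(d)) == 1)
    return Fraction(hits, len(draws))


if __name__ == "__main__":
    for n in range(1, 8):
        print(n, prob_single_point(n), Fraction(1, 2 ** (n - 1)))
```

Only two of the $2^{n}$ equally likely datasets contain a single point, which gives $2/2^{n}=2^{-(n-1)}$.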

The solution $w_{n}(t)$ lies in $\text{span}(x_{1},x_{2})$. When trained on the dataset $S_{n}^{-i}$, the progress on the direction $x_{\lceil 2i/n\rceil}$ is slightly less than the other direction, which introduces the non-zero averaged loss difference $\bar{\Delta}_{n}(t)$ during training. We should also note that the calculation of the contraction and perturbation factors depends heavily on the sample-wise loss gradient $\nabla\ell(w,z_{i})$ being supported on $\{x_{1},x_{2}\}$. The clustering of per-sample gradients happens also in the training of neural networks, as shown in Fort and Ganguli (2019). See Section A.7 for details of the above calculation.

Remark 18 (Comparison with Arora et al. (2019)).

Let us use the technique of Arora et al. (2019) for our example. We can bound the generalization gap in terms of the norm of the eventual weights. The Gram matrix of the linear regression described above is $H^{\varepsilon}=P_{n}^{\varepsilon}(t)$ with $\varepsilon(t)=\bar{\varepsilon}$ for some constant $\bar{\varepsilon}\ll 1$ (we choose this perturbed version to guarantee the positive definiteness while not affecting the evolution of the residual). The norm of weights can be bounded by $\sqrt{y^{\top}(H^{\varepsilon})^{-1}y}$, which gives a generalization bound for 1-Lipschitz loss,

\[ \sqrt{\frac{2y^{\top}(H^{\varepsilon})^{-1}y}{n}}=\sqrt{\frac{2(y_{1}^{2}+y_{2}^{2})}{n}}. \]

This is far looser than the actual generalization error, which is $\Theta(2^{-n})$ for 1-Lipschitz loss from the calculation above. The key point to emphasize here is that by characterizing the evolution of the point-wise loss difference using the contraction factor, we can work directly in the prediction space instead of working in the weight space. This is the reason why our estimate of the generalization gap is more accurate.
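To get a feel for how far apart the two rates in this remark are, the following sketch evaluates both for a few sample sizes. The labels $y_{1}=1$, $y_{2}=-1$ are illustrative assumptions, and the unknown constant hidden in $\Theta(2^{-n})$ is set to one purely for display.

```python
import math


def norm_based_bound(y1: float, y2: float, n: int) -> float:
    """Norm-based bound from the remark: sqrt(2 (y1^2 + y2^2) / n)."""
    return math.sqrt(2.0 * (y1 ** 2 + y2 ** 2) / n)


def exponential_rate(n: int) -> float:
    """The Theta(2^{-n}) rate with its constant set to 1 (an assumption)."""
    return 2.0 ** (-n)


if __name__ == "__main__":
    # Illustrative labels (assumption, not from the paper).
    y1, y2 = 1.0, -1.0
    for n in (10, 20, 40):
        print(n, norm_based_bound(y1, y2, n), exponential_rate(n))
```

The norm-based bound decays polynomially as $n^{-1/2}$, whereas the rate from the calculation above decays exponentially in $n$.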

4 Experimental Validation

Datasets

For experimental validation of our theoretical development, we use a number of different datasets and experimental settings.

  • MNIST (LeCun et al., 1998) and CIFAR-10 (Krizhevsky, 2009) classification datasets, both with 10 categories.

  • Synthetic datasets labeled Syn-$(a,b)$ are created by modifying the labeling function of the MNIST dataset as follows. We first project MNIST images onto the subspace of the empirical second moment matrix corresponding to the $a$-th to $b$-th eigenvalues, sorted from the largest to the smallest. We then relabel the MNIST inputs using a fully-connected teacher with random weights, applied to the projected images.

  • Synthetic datasets labeled Gaussian-$\alpha$ are created using Gaussian data with different covariance matrices. Input data in Gaussian-$\alpha$ are sampled from a multivariate Gaussian distribution with covariance matrix $A$, whose $i$-th eigenvalue is $\exp(-\alpha i)$. We then project these inputs onto the subspace of the covariance matrix corresponding to the 10 largest eigenvalues and label the original inputs using a fully connected teacher with random weights, applied to the projected inputs.

The rationale for creating these synthetic datasets will become clear as we discuss the experiments, but in short, we seek to create datasets where the signal-to-noise ratio can be controlled. The smaller the value of $a$ and the larger the value of $\alpha$, the larger the signal-to-noise ratio in these synthetic data.
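A minimal sketch of how a Gaussian-$\alpha$ dataset could be generated from the description above. Several details are assumptions made for illustration and are not specified in this section: the input dimension `d`, the random orthonormal eigenbasis, the one-hidden-layer ReLU teacher, its width `hidden`, and the number of classes. The function name `make_gaussian_alpha` is hypothetical.

```python
import numpy as np


def make_gaussian_alpha(n, d=50, alpha=0.1, k=10, hidden=32, classes=2, seed=0):
    """Sample inputs with covariance eigenvalues exp(-alpha * i) and teacher labels."""
    rng = np.random.default_rng(seed)
    # Random orthonormal eigenbasis of the covariance (assumption).
    Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
    # i-th eigenvalue is exp(-alpha * i), so eigenvalues decrease with i.
    eigvals = np.exp(-alpha * np.arange(1, d + 1))
    # Rows of X have covariance Q diag(eigvals) Q^T.
    X = (rng.standard_normal((n, d)) * np.sqrt(eigvals)) @ Q.T
    # Project onto the eigenvectors of the k largest eigenvalues.
    top = Q[:, :k]
    X_proj = X @ top @ top.T
    # Random fully connected teacher (architecture is an assumption).
    W1 = rng.standard_normal((d, hidden)) / np.sqrt(d)
    W2 = rng.standard_normal((hidden, classes)) / np.sqrt(hidden)
    logits = np.maximum(X_proj @ W1, 0.0) @ W2
    # Labels come from the projected inputs but are attached to the original inputs.
    y = logits.argmax(axis=1)
    return X, y, Q, eigvals


if __name__ == "__main__":
    X, y, Q, eigvals = make_gaussian_alpha(n=1000, alpha=0.2)
    print(X.shape, np.bincount(y), eigvals[:3])
```

With a larger `alpha`, the eigenvalues outside the top-`k` subspace shrink faster, so a smaller share of the input variance lies in directions that the teacher ignores.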

Architectures

We will train FC (fully connected neural networks), LeNet-5 (LeCun et al., 1998) (a network with two convolutional layers and one fully connected layer), and WRN-4-4 (Zagoruyko and Komodakis, 2016) (wide residual network with 4 layers and a widening factor of 4) using (non-stochastic) gradient descent with different numbers of samples drawn from the datasets described above. We use gradient descent as the Euler approximation of gradient flow in our theory. See Appendix B for more details.
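The statement that gradient descent is the Euler approximation of gradient flow can be illustrated on a toy problem where the flow has a closed form. The sketch below uses a diagonal quadratic loss, which is an assumption chosen for illustration and unrelated to the networks used in the experiments; the function names are hypothetical.

```python
import numpy as np


def gradient_descent(grad, w0, lr, steps):
    """Run full-batch gradient descent; each step is an explicit Euler step of dw/dt = -grad(w)."""
    w = np.asarray(w0, dtype=float).copy()
    for _ in range(steps):
        w = w - lr * grad(w)
    return w


def euler_error(lr, T=1.0):
    """Max deviation between gradient descent and exact gradient flow at time T."""
    a = np.array([1.0, 3.0])       # loss is 0.5 * sum(a * w**2)
    w0 = np.array([1.0, -2.0])
    exact = np.exp(-a * T) * w0    # closed-form gradient-flow solution
    approx = gradient_descent(lambda w: a * w, w0, lr, int(round(T / lr)))
    return float(np.max(np.abs(approx - exact)))


if __name__ == "__main__":
    for lr in (0.1, 0.01, 0.001):
        print(lr, euler_error(lr))
```

The deviation from the flow shrinks roughly in proportion to the learning rate, as expected for a first-order integrator.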

Constructing perturbed datasets

The theory in this paper was developed for the case where the modified dataset $S_{n}^{-i}$ with $n-1$ samples is created by omitting the $i$-th sample. For numerical stability and efficiency of the approximation, in the experiments, we create datasets by omitting a batch of $m$ samples. Let $(m)$ denote a subset of $[n]$ with size $m$. Let $S_{(m)}=\{z_{i}=(x_{i},y_{i})_{i\in(m)}\}$. We will conduct experiments using modified datasets $S_{n}^{-(m)}=S_{n}\setminus S_{(m)}$ obtained by removing $S_{(m)}$ from $S_{n}$.
We therefore consider the weight trajectory $w_{n}^{-(m)}(t)$ of the differential equation $\dot{w}=-\nabla\bar{\ell}(w,S^{-(m)}_{n})$. The averaged loss difference is modified in the usual fashion, $\bar{\Delta}_{n}(t)=\mathbb{E}_{(m)}\left[\Delta_{n}^{-(m)}\right]$, with the batch-wise loss difference

\[ \Delta_{n}^{-(m)}(t)=\ell(w_{n}^{-(m)}(t),S_{(m)})-\ell(w_{n}(t),S_{(m)}). \]

Note that $\mathbb{E}_{(m)}$ denotes the expectation taken over the uniform distribution on all possible choices of $(m)$ in $[n]$. The formulae for the averaged contraction and perturbation factors $\bar{c}_{n}(t)$ and $\bar{\epsilon}_{n}(t)$ in this setting are shown in Section B.1.
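The leave-$m$-out construction can be sketched on a toy model. The example below is an illustration under stated assumptions and is not the experimental pipeline of the paper: it uses linear regression with squared loss, evaluates the loss difference only at the end of training, and replaces the expectation over all subsets $(m)$ by a Monte Carlo average over `num_subsets` random subsets. All function names are hypothetical.

```python
import numpy as np


def train_gd(X, y, lr=0.1, steps=2000):
    """Full-batch gradient descent on 0.5 * mean squared error, starting from zero."""
    w = np.zeros(X.shape[1])
    for _ in range(steps):
        w -= lr * X.T @ (X @ w - y) / len(y)
    return w


def mse(w, X, y):
    """Loss 0.5 * mean((Xw - y)^2) of weights w on the samples (X, y)."""
    return 0.5 * np.mean((X @ w - y) ** 2)


def averaged_loss_difference(X, y, m, num_subsets=20, seed=0):
    """Monte Carlo estimate of E_(m)[loss(w^{-(m)}, S_(m)) - loss(w, S_(m))]."""
    rng = np.random.default_rng(seed)
    n = len(y)
    w_full = train_gd(X, y)                          # trained on S_n
    diffs = []
    for _ in range(num_subsets):
        held = rng.choice(n, size=m, replace=False)  # the subset (m)
        keep = np.setdiff1d(np.arange(n), held)      # indices of S_n^{-(m)}
        w_minus = train_gd(X[keep], y[keep])         # trained on S_n^{-(m)}
        diffs.append(mse(w_minus, X[held], y[held]) - mse(w_full, X[held], y[held]))
    return float(np.mean(diffs))


if __name__ == "__main__":
    rng = np.random.default_rng(1)
    X = rng.standard_normal((60, 5))
    y = X @ np.arange(1.0, 6.0) + 0.5 * rng.standard_normal(60)
    print(averaged_loss_difference(X, y, m=6))
```

Both trajectories start from the same initialization, so the difference in loss on the held-out batch comes only from whether that batch was part of the training set.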

Calculating quantities that pertain to the generalization gap

We are interested in calculating the effective Gram matrix $K_{n}$ for different configurations of neural network training. To do this, we approximate the gradient flow Eq. 1 by gradient descent with different learning rates. We calculate the averaged contraction factor $\bar{c}_{n}(t)$, the averaged perturbation factor $\bar{\epsilon}_{n}(t)$, and the decomposition of the trace of the gradient covariance, $M_{n}(t)$ and $H_{n}(t)$, using Eq. 8, Eq. 9 and Eq. 15 respectively (or their alternatives for the setting that omits $m$ samples, as in Appendix B). The propagator $\Omega_{n}(t)$ is approximated by product methods as described in Section B.4, where we compared it with the Magnus expansion approximation. The integrals for $\bar{\Delta}_{n}(t)$ in Eq. 6 and $K_{n}$ in Eq. 18 are approximated by the trapezoidal method. We will use the statistics in Table 1 to characterize the relation of the initial residual $\vec{r}_{n}(0)$ and the effective Gram matrix $K_{n}$.
The rationale for defining these quantities comes from Theorem 13 which shows that the eventual generalization gap after training is a quadratic form that depends upon the effective Gram matrix and the initial residual. We are interested in understanding how different subspaces of the effective Gram matrix contribute to this quadratic form.
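For readers unfamiliar with product approximations of a propagator, the following sketch shows one common form for a generic linear system $\dot{r}=-A(t)\,r$ with symmetric $A(t)$: a time-ordered product of matrix exponentials evaluated at interval midpoints. This is an assumption about the general technique; the exact scheme of Section B.4 is not reproduced in this section and may differ. The function names are hypothetical.

```python
import numpy as np


def expm_symmetric(A):
    """Matrix exponential of a symmetric matrix via its eigendecomposition."""
    vals, vecs = np.linalg.eigh(A)
    return (vecs * np.exp(vals)) @ vecs.T


def product_propagator(A_of_t, t_grid):
    """Approximate the propagator of dr/dt = -A(t) r over t_grid by a time-ordered product."""
    n = A_of_t(t_grid[0]).shape[0]
    omega = np.eye(n)
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        mid = 0.5 * (t0 + t1)
        # Later time steps multiply on the left (time ordering).
        omega = expm_symmetric(-A_of_t(mid) * (t1 - t0)) @ omega
    return omega


if __name__ == "__main__":
    A0 = np.array([[2.0, 1.0], [1.0, 3.0]])
    grid = np.linspace(0.0, 1.0, 201)
    omega = product_propagator(lambda t: (1.0 + t) * A0, grid)
    # For this commuting family the exact propagator is exp(-1.5 * A0).
    print(np.max(np.abs(omega - expm_symmetric(-1.5 * A0))))
```

When the matrices $A(t)$ at different times commute, the product collapses to a single exponential of the integrated matrix; in general it does not, which is why the ordering of the factors matters.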

| Notation | Definition |
|---|---|
| $E(K),\sigma(K)$ | The eigenspace and eigenspectrum of a symmetric matrix $K$ with eigenvalue decomposition $K=E(K)\operatorname{diag}(\sigma(K))E(K)^{\top}$, where $\sigma(K)$ is the vector of eigenvalues in ascending order. |
| $\bar{\sigma}(K)$ | The mean of the eigenspectrum of a symmetric matrix $K$, $\bar{\sigma}(K)=\sum_{i}\sigma(K)_{i}/n$. |
| $U_{k},U_{1:k}$ | The $k$-th column of $U$, and the first $k$ columns of $U$. |
| $P(r,U)$ | Normalized projection of a vector $r$ onto the space $U$ (with orthonormal columns); the $k$-th element is $P(r,U)_{k}=\lvert r^{\top}U_{k}\rvert/\lVert r\rVert_{2}$. |
| $M(r,U)$ | “Explained magnitude” of a vector $r$ in the space $U_{1:k}$ (with orthonormal columns); the $k$-th element is $M(r,U)_{k}=\lVert r^{\top}U_{1:k}\rVert^{2}_{2}/\lVert r\rVert_{2}^{2}$. |
| $M(K)$ | “Explained magnitude” of a symmetric matrix $K$ in its eigenspace $E(K)$; the $k$-th element is $M(K)_{k}=\sum_{i=1}^{k}\sigma(K)_{i}/\sum_{i=1}^{n}\sigma(K)_{i}$ for $K\in\mathbb{R}^{n\times n}$. |
| $R(\text{Idx})$ | “Relative index” of the index vector $\text{Idx}=[1,2,\dots,l]$, where $R(\text{Idx})=[1/n,2/n,\dots,1]$. |

Table 1: Statistics characterizing the initial residual and effective Gram matrix.
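The statistics in Table 1 are straightforward to compute once a symmetric matrix $K$ and a vector $r$ are available. The following NumPy sketch is one possible implementation; the function names are hypothetical and are not taken from the authors' code release.

```python
import numpy as np


def eig_ascending(K):
    """Return (E(K), sigma(K)) with eigenvalues in ascending order."""
    sigma, E = np.linalg.eigh(K)  # eigh returns ascending eigenvalues
    return E, sigma


def normalized_projection(r, U):
    """P(r, U): k-th entry is |r^T U_k| / ||r||_2."""
    return np.abs(r @ U) / np.linalg.norm(r)


def explained_magnitude_vector(r, U):
    """M(r, U): k-th entry is ||r^T U_{1:k}||_2^2 / ||r||_2^2."""
    return np.cumsum((r @ U) ** 2) / np.dot(r, r)


def explained_magnitude_matrix(K):
    """M(K): k-th entry is the sum of the first k eigenvalues over the sum of all."""
    _, sigma = eig_ascending(K)
    return np.cumsum(sigma) / np.sum(sigma)


def relative_index(n):
    """R(Idx): the index vector [1, ..., n] rescaled to [1/n, ..., 1]."""
    return np.arange(1, n + 1) / n


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    A = rng.standard_normal((6, 6))
    K = A @ A.T                      # random symmetric PSD matrix for illustration
    r = rng.standard_normal(6)
    E, sigma = eig_ascending(K)
    print(normalized_projection(r, E))
    print(explained_magnitude_vector(r, E))
    print(explained_magnitude_matrix(K))
    print(relative_index(6))
```

Because the eigenvalues are sorted in ascending order, the first entries of `explained_magnitude_vector(r, E)` measure how much of $r$ lies in the subspace with the smallest eigenvalues.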

4.1 Theorem 13 leads to a good approximation of the generalization gap

Consider Fig. 2. Observe that the true generalization gap $\delta R(S_{n},t)$, averaged loss difference $\bar{\Delta}_{n}(t)$, and the gap $\mathbb{E}_{(m)}\left[\delta R(S_{n}^{-(m)},t)\right]$ (denoted by $\delta\bar{R}(\cdot)$ in the plot) are all close to each other throughout training. This indicates that the generalization gap can be well approximated by $\bar{\Delta}_{n}(t)$.

Figure 2: Left: FC trained on MNIST with all 10 classes, with $n=1100$ samples and statistics computed over datasets perturbed by $m=100$ samples. Right: LeNet-5 trained on CIFAR-10 with 2 selected classes, $n=1100$, $m=100$. We choose fewer samples $n$ than the full dataset to be able to interpolate the data using gradient descent.

We calculate two numerical approximations of $\bar{\Delta}_{n}(t)$: the quantity $\bar{\Delta}_{n}(c,\epsilon,t)$ computed with the true perturbation factor from Eq. 9, and $\bar{\Delta}_{n}(c,\hat{\epsilon},t)$ with an approximate perturbation factor derived from Eq. 16 with the propagator given by the product approximation Eq. 21. First note that $\bar{\Delta}_{n}(c,\epsilon,t)$ is close to $\bar{\Delta}_{n}(t)$, which indicates that the gradient descent approximation of Eq. 1 and the trapezoidal approximation of Eq. 7 are good. Second, the similarity of $\bar{\Delta}_{n}(c,\hat{\epsilon},t)$ and $\bar{\Delta}_{n}(c,\epsilon,t)$ indicates that the product approximation in Eq. 21 is working well. The results on CIFAR-10 using a convolutional network are largely similar, with slightly less accurate estimates of the generalization gap.

Note that $\bar{\Delta}_{n}(c,\hat{\epsilon},t)=\vec{r}_{n}(0)^{\top}K_{n}(0,t)\vec{r}_{n}(0)$ for numerically approximated effective Gram matrix $K_{n}(0,t)$. Hence, the good approximation of the generalization gap by $\bar{\Delta}_{n}(c,\hat{\epsilon},t)$ indicates that the numerically approximated effective Gram matrix $K_{n}$ (which is calculated in the following sections) is a good quantity to use for understanding generalization. We should note that $K_{n}$ refers to the numerical approximations in the following subsections. In Table S.1, we provide a complete list of results of generalization gap approximation for all the experiments.

| Paper | Architecture | Dataset | # samples | Training method | Bound on Test Data | Actual Value | Relative inaccuracy |
|---|---|---|---|---|---|---|---|
| Arora et al. (2019) | FC | MNIST-2 | 10,000 | GD (second layer fixed) | 0.05 ($\ell_{1}$ loss) | $<$ 0.01 ($\ell_{1}$ loss) | $>$ 4 |
| Dziugaite and Roy (2017) | FC | MNIST-2 | 55,000 | SGD | 0.161 (error) | 0.018 (error) | 7.9 |
| Wang and Ma (2022) | FC | MNIST-2 | 55,000 | SGD | 0.25 (CE loss) | | |
| Ours | FC | MNIST-10 | 1,100 | GD | 0.47 (CE loss) | 0.45 (CE loss) | 0.05 |
| Ours | LeNet-5 | MNIST-10 | 1,100 | GD | 0.24 (CE loss) | 0.20 (CE loss) | 0.18 |
| Negrea et al. (2019) | CNN | MNIST-10 | 55,000 | SGLD | 0.25 (CE loss) | 0.02 (error) | |
| Mou et al. (2018) | CNN | MNIST-10 | 55,000 | SGLD | 1.25 (CE loss) | 0.02 (error) | |
| Ours | FC | CIFAR-2 | 1,100 | GD | 0.34 (CE loss) | 0.41 (CE loss) | 0.17 |
| Arora et al. (2019) | FC | CIFAR-2 | 10,000 | GD (second layer fixed) | 0.6 ($\ell_{1}$ loss) | 0.45 ($\ell_{1}$ loss) | 0.33 |
| Ours | LeNet-5 | CIFAR-2 | 1,100 | GD | 0.46 (CE loss) | 0.49 (CE loss) | 0.06 |
| Ours | WRN-4-4 | CIFAR-2 | 1,100 | GD | 0.111 (CE loss) | 0.107 (CE loss) | 0.04 |

Table 2: Comparison with previous results in terms of the relative accuracy of the estimate of the generalization error. See Appendix B for the details of the datasets and architectures. CE loss indicates cross-entropy loss. (S)GD indicates (stochastic) gradient descent, SGLD is stochastic gradient Langevin dynamics. “Bound” in this table refers to the numerical value of the generalization bound. “Actual Value” is test loss or error on held-out test data. “Relative inaccuracy” equals $\lvert\text{Bound}-\text{Actual Value}\rvert/\text{Actual Value}$. This characterizes the quality of these estimates. Different papers make different assumptions, apply to quite different models of neural networks, loss functions, training methods, and use different techniques. One must therefore be careful while interpreting this table. Note that, to be consistent with the calculations, all our experiments are conducted with gradient descent, not stochastic gradient descent. Practically, this means that in order to get the network to fit the training data well enough, we need to use small sample sizes.

Table 2 compares our estimates with previously reported generalization bounds. The small relative inaccuracy of our estimates indicates that our approximations are of good quality.
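The last column of Table 2 follows directly from the two columns before it. The sketch below recomputes it for a few rows using the values printed in the table; since those values are rounded, a recomputed ratio can differ slightly from the tabulated one. The function name `relative_inaccuracy` is hypothetical.

```python
def relative_inaccuracy(bound: float, actual: float) -> float:
    """|Bound - Actual Value| / Actual Value, as defined in the caption of Table 2."""
    return abs(bound - actual) / actual


if __name__ == "__main__":
    # (bound, actual) pairs copied from Table 2.
    rows = {
        "Dziugaite and Roy (2017), FC, MNIST-2": (0.161, 0.018),
        "Arora et al. (2019), FC, CIFAR-2": (0.6, 0.45),
        "Ours, FC, CIFAR-2": (0.34, 0.41),
        "Ours, LeNet-5, CIFAR-2": (0.46, 0.49),
        "Ours, WRN-4-4, CIFAR-2": (0.111, 0.107),
    }
    for name, (bound, actual) in rows.items():
        print(f"{name}: {relative_inaccuracy(bound, actual):.2f}")
```

Rows whose bound and actual value are reported in different units (for example CE loss against classification error) have no meaningful ratio, which is consistent with the empty entries in the table.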

Remark 19.

While we have tabulated the results above, we should emphasize the following three points for interpreting these results. First, it is not meaningful to compare different theories in Table 2 to find a superior theory. One upper bound being better than another numerically says little about the quantity they both bound. If one simply wanted to predict the generalization gap well, one would be content with using just cross-validation; see Kawaguchi et al. (2018, Section 4). However, this does not mean that there is no need for theory. The goal of work on generalization is to understand what properties of data, architectures and training lead to good generalization. Each work answers a different facet of this question, which is why there is such a diversity of assumptions, techniques and conclusions. Second, one cannot compare these methods directly against each other since they make different assumptions on the architecture and training method. Hence, in Table 2, we compare the relative inaccuracy, which is $\lvert\text{approximation}-\text{actual}\rvert/\text{actual}$. This is a reasonable way to compare these approaches, and our approach does well on this measure. Third, our theory is for gradient flow. Thus, our implementation uses gradient descent, not stochastic gradient descent. In practice, we cannot reach a small training error on the full datasets with gradient descent, which is why we use fewer samples.

4.2 Initial residual lies primarily in the subspace of effective Gram matrix with small eigenvalues

Fig. 3 (left) shows that for MNIST, the initial residual $\vec{r}_{n}(0)$ lies primarily in the subspace of $K_{n}$ with small eigenvalues, while for the random task, the initial residual puts more weight on the subspace with larger eigenvalues, where the projection is not negligible even for the head eigenvalues. In Fig. 3 (right), the tail subspace of $K_{n}$ with less than 3% of the eigenvalues recovers 98% of $\vec{r}_{n}(0)$. This shows that if the task that we need to fit is simple, in the sense that the initial residual predominantly lies in the tail subspace of $K_{n}$, then the eventual generalization gap is small (the generalization gaps for MNIST and the random task are 0.47 and 3.27 respectively). This indicates a benign training process, i.e., the generalization loss accumulates slowly.

Remark 20 (Comparing the statistics for different numbers of samples).

The effective Gram matrix $K_n \in \mathcal{Y}^n \times \mathcal{Y}^n$ lies in a different space when neural networks are trained with different numbers of samples $n$. Therefore, to compare quantities like $\sigma(K_n)$, $M(\vec{r}_n, E(K_n))$ and $M(K_n)$ for different $n$ and the same $\mathcal{Y}$, we use a “relative index” as described in Table 1: we rescale the original index vector to have indices from zero to one. We emphasize that because the residual is normalized by $\sqrt{n}$ in Eq. 11, the $\ell_2$-norm of the initial residual $\|\vec{r}_n(0)\|_2$ is similar when $\mathcal{Y}$ is the same, even if $n$ is different.
Note that the estimated generalization gap $\vec{r}_n(0)^\top K_n \vec{r}_n(0)$ is the average of the eigenvalues $\sigma(K_n)_i$, each weighted by the squared projected residual $P(\vec{r}_n(0), E(K_n))_i^2$. We therefore also compute $\bar{\sigma}(K_n)$ to understand the effect of $K_n$. Table S.1 details the numerical values of these quantities for different datasets and architectures.
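
The quantities in this remark can be sketched numerically as follows. This is only an illustration under stated assumptions, not the authors' implementation: the definitions of the normalized projection $P$, the explained magnitude $M$, and the ordering of eigenvalues (ascending here, so that the tail subspace comes first) are guesses, since Table 1 is not reproduced in this section. The function name `spectral_statistics` and the random test matrix are hypothetical.

```python
import numpy as np

def spectral_statistics(K, r):
    """Eigen-statistics of a symmetric PSD matrix K and a vector r.

    Assumed definitions (hypothetical, for illustration only):
      P_i  = |<e_i, r>| / ||r||_2        normalized projection of r
      M_r  = cumulative sum of P_i^2     explained magnitude of r
      M_K  = cumulative eigenvalue mass  explained magnitude of K
    """
    # eigh returns eigenvalues in ascending order, so index 0 is the tail.
    sigma, E = np.linalg.eigh(K)
    coeffs = E.T @ r                          # coordinates of r in the eigenbasis
    P = np.abs(coeffs) / np.linalg.norm(r)    # entries satisfy sum(P**2) == 1
    M_r = np.cumsum(P ** 2)
    M_K = np.cumsum(sigma) / np.sum(sigma)
    return sigma, P, M_r, M_K

# Synthetic stand-ins; these are NOT the effective Gram matrix or residual
# from the paper, just a random PSD matrix and a random vector.
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 50))
K = A @ A.T / 50
r = rng.standard_normal(50) / np.sqrt(50)

sigma, P, M_r, M_K = spectral_statistics(K, r)

# The quadratic form equals the eigenvalues weighted by squared projections,
# up to the factor ||r||^2 that the normalization of P removes.
quad = r @ K @ r
weighted = np.dot(r, r) * np.sum(sigma * P ** 2)
print(quad, weighted)
```

Plotting `M_r` against `M_K` would give a curve of the kind shown in the explained-magnitude panels, under the assumed definitions.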

Figure 3: Statistics of the residual $\vec{r}_n$ and effective Gram matrix $K_n$ for two different tasks. Benign task: FC trained on MNIST with all 10 classes, $n=1000$, $m=100$. Random task: FC trained on MNIST with 10 randomly assigned classes, $n=50$, $m=5$. Left: Eigenspectrum of the Gram matrix $\sigma(K_n)$ and the normalized projection of the initial residual $P(\vec{r}_n(0), E(K_n))$ for benign and random tasks. Right: Explained magnitude of the initial residual $M(\vec{r}_n(0), E(K_n))$ for benign and random tasks.

4.3 As training proceeds, the residual projects more into the principal subspace of the effective Gram matrix

We next consider the training process starting from different times $t_0$ instead of $t_0 = 0$. Analogously to Section 3, the increment of the averaged loss difference $\bar{\Delta}_n(\infty) - \bar{\Delta}_n(t_0)$ from time $t_0$ to the end of training can be approximated by $\vec{r}_n(t_0)^\top K_n(t_0) \vec{r}_n(t_0)$. The effective Gram matrix $K_n(t_0)$ for the training process starting from $t_0$ can be calculated using revised contraction and perturbation factors $\bar{c}_n$ and $\bar{\epsilon}_n$. The detailed calculation is given in Section B.2.

Figure 4: Explained magnitude of the residual $M(\vec{r}_n(t_0), E(K_n(t_0)))$ (Y-axis) as a function of the explained magnitude of the effective Gram matrix $M(K_n(t_0))$ (X-axis) for FC trained on MNIST with all 10 classes, with $n=1100$ and $m=100$, but computed for different times $t_0$. We see that as the number of training iterations increases, the explained magnitude of the residuals in the subspace of the effective Gram matrix with a small explained magnitude, i.e., the non-principal subspace, decreases. Residuals at later training times project more and more predominantly in the principal subspace of the effective Gram matrix.

From Fig. 4, as training proceeds, the residual $\vec{r}_n(t_0)$ aligns more and more with the subspace of the effective Gram matrix $K_n(t_0)$ with large eigenvalues. This is because in the initial phases of training, the residual is first fitted in the subspace with small eigenvalues, and this accumulates the generalization gap slowly. As training proceeds, to reduce the training loss, the network updates the residual to lie in less benign subspaces, those with larger eigenvalues.

(a) The true generalization gaps of syn-$(a,b)$ ($a$ from small to large) are 0.16, 0.37, 0.42, 0.53, 0.55, respectively. Left: Explained magnitude of the initial residual trends towards the top-left when we reduce $a$ for a larger signal-to-noise ratio. Right: Eigenspectra of the effective Gram matrix $\sigma(K_n)$ for datasets syn-$(a,b)$ have similar shapes, although their mean $\bar{\sigma}(K_n)$ increases as $a$ becomes larger (4.4, 5.0, 6.8, 8.6, 10.6, $a$ from small to large), indicating a larger accumulation of generalization gap in all subspaces.
(b) The true generalization gaps of Gaussian-$\alpha$ ($\alpha$ from large to small) are 0.06, 0.12, 0.23, 0.28, 0.52 respectively. Left: Explained magnitude of the initial residual trends towards the top-left when we increase $\alpha$ for a larger signal-to-noise ratio. Right: Eigenspectra of the effective Gram matrix $\sigma(K_n)$ for datasets Gaussian-$\alpha$ have similar magnitudes ($\bar{\sigma}(K_n)$ are 2.7, 4.5, 3.6, 3.5, 3.3 respectively, $\alpha$ from large to small).
Figure 6: Evaluation on synthetic datasets

4.4 Effective Gram matrix for different datasets

Fig. 6 compares the normalized projection of the residual and the eigenspectra of the effective Gram matrix for different synthetic datasets. From the classical analysis of linear regression, we know that data is more difficult to learn when labels are correlated with features that account for a smaller proportion of the eigenvalues of the input correlation matrix. In the relabeled MNIST datasets, syn-$(a,b)$ with larger $a$ assigns labels using less prominent features, and in the Gaussian datasets, Gaussian-$\alpha$ with smaller $\alpha$ puts less weight on the top eigenvalues, as shown in Yang et al. (2022). In both cases, we manually created difficult tasks. Using experiments on synthetic datasets with different levels of difficulty, we see that for difficult tasks, the residual projects more onto the subspace corresponding to larger eigenvalues, and the effective Gram matrix $K_n$ has a larger magnitude. Together, these lead to a larger generalization gap predicted by our theory. Indeed, the true generalization gap corroborates this trend.

Fig. 7 compares training on MNIST and CIFAR-10. The initial residual of MNIST projects more onto the eigenspace of the effective Gram matrix with small eigenvalues, and the eigenvalues of $K_n$ for CIFAR-10 are uniformly larger than those for MNIST. This shows that both good task-Gram matrix alignment and a small magnitude of the eigenvalues of $K_n$ are necessary for a “benign training process” and a good eventual generalization gap.

Figure 7: Residuals $\vec{r}_n(0)$ and effective Gram matrix $K_n$ for a fully connected network trained on MNIST and CIFAR-10 with $n=1100$ and $m=100$. For this experiment, we created a two-class classification problem for both datasets, instead of the original 10 classes. The generalization gaps for MNIST and CIFAR-10 are 0.02 and 0.34 respectively. Left: Explained magnitude of the initial residual $M(\vec{r}_n(0), E(K_n))$ for CIFAR-10 has a larger overlap with the principal subspace of the effective Gram matrix compared to MNIST. This indicates that the generalization gap on CIFAR-10 of the trained network is larger than that on MNIST, which is corroborated by the numerical estimates of the generalization gap in our experiments. Right: Eigenvalues of the effective Gram matrix $\sigma(K_n)$ for MNIST and CIFAR-10 have quite different magnitudes ($\bar{\sigma}(K_n)$ are 0.60 for MNIST and 10.06 for CIFAR-10).

4.5 Effective Gram matrix for different architectures

Fig. 9 compares the normalized projection of the residual and the eigenvalues of the effective Gram matrix for MNIST and CIFAR when trained using different models (FC, LeNet-5 and WRN-4-4). The eigenspectrum $\sigma(K_n)$ of FC is uniformly larger than that of LeNet-5 when trained on MNIST. Similarly, $\sigma(K_n)$ of FC and LeNet-5 is larger than that of WRN-4-4 when trained on CIFAR. The large magnitude of the effective Gram matrix leads to a large accumulation of the generalization gap in all subspaces, resulting in worse generalization.

(a) Residual $\vec{r}_n(0)$ and effective Gram matrix $K_n$ for MNIST with all 10 classes trained with FC (blue) and LeNet-5 (orange) with $n=1100$ and $m=100$. The generalization gaps for FC and LeNet-5 are 0.48 and 0.23 respectively. Left: Explained magnitude $M(\vec{r}_n(0), E(K_n))$ is rather similar for both networks. Right: Eigenvalues $\sigma(K_n)$ for FC are larger than those of LeNet-5. The means $\bar{\sigma}(K_n)$ are 22.12 and 6.58 respectively.
(b) Residual $\vec{r}_n(0)$ and effective Gram matrix $K_n$ for CIFAR with 2 selected classes trained with FC (blue), LeNet-5 (orange) and WRN-4-4 (green) with $n=1100$ and $m=100$. The generalization gaps are 0.34, 0.37, 0.11 respectively. Left: Explained magnitude $M(\vec{r}_n(0), E(K_n))$ is similar for LeNet-5 and WRN-4-4. Right: Eigenvalues $\sigma(K_n)$ for FC and LeNet-5 are larger than those of WRN-4-4. The means $\bar{\sigma}(K_n)$ are 10.06, 8.44 and 0.53 respectively.
Figure 9: Evaluation using different architectures.

To demonstrate that our theory also applies to models other than neural networks, in Fig. 10, we fit ridgeless kernel regression (Rakhlin and Liang, 2020) with the neural tangent kernel $K_{t_{\text{ker}}}(x, x') = \nabla f(w_n(t), x)^\top \nabla f(w_n(t), x')$ using cross-entropy loss for different times $t_{\text{ker}}$ (we manually choose the time points $t_{\text{ker}}$ so that they spread out over the full training process). Note that the kernel here is the standard NTK, which is not related to our effective Gram matrix. The evolution of the predictor is

\[
\frac{\mathrm{d} f_t(x)}{\mathrm{d} t} = -\frac{1}{n} \sum_{i=1}^{n} K_{t_{\text{ker}}}(x, x_i)\, r_t(z_i), \quad r_t(z_i) = \frac{\mathrm{d}\, \ell(f_t(x), y)}{\mathrm{d} f_t(x)}.
\]

Using a fixed Jacobian at initialization leads to larger eigenvalues and more projection of the residual onto the stiff subspaces of the effective Gram matrix $K_n$, i.e., a larger eventual generalization gap, as is widely known (Fort et al., 2020).
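
The predictor dynamics above can be simulated on the training points with a forward-Euler discretization. The sketch below is only illustrative and makes several assumptions that are not from the paper: the kernel is a fixed scalar Gram matrix shared across classes, it is built from a random matrix `J` that stands in for a network Jacobian, and the step size and number of steps are arbitrary. For cross-entropy loss with one-hot labels, the derivative of the loss with respect to the logits is the softmax output minus the label, which plays the role of $r_t(z_i)$.

```python
import numpy as np

def softmax(f):
    """Row-wise softmax with the usual max-subtraction for stability."""
    z = f - f.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(f, Y):
    """Mean cross-entropy between logits f and one-hot labels Y."""
    p = softmax(f)
    return -np.mean(np.sum(Y * np.log(p + 1e-12), axis=1))

def kernel_gradient_flow(K, Y, step=0.1, num_steps=500):
    """Forward-Euler discretization of df/dt = -(1/n) K r on the training set.

    K is a fixed (n, n) kernel Gram matrix and Y holds (n, C) one-hot labels.
    """
    n, C = Y.shape
    f = np.zeros((n, C))                # predictor starts at zero logits
    losses = []
    for _ in range(num_steps):
        r = softmax(f) - Y              # derivative of cross-entropy w.r.t. logits
        f = f - step * (K @ r) / n      # one Euler step of the flow
        losses.append(cross_entropy(f, Y))
    return f, np.array(losses)

# Hypothetical toy problem: a random "Jacobian" J gives a PSD kernel J J^T.
rng = np.random.default_rng(0)
n, d, C = 40, 10, 3
J = rng.standard_normal((n, d))
K = J @ J.T
Y = np.eye(C)[rng.integers(0, C, size=n)]

f, losses = kernel_gradient_flow(K, Y)
print(losses[0], losses[-1])
```

Swapping `K` for an empirical NTK Gram matrix evaluated at a chosen time would give the kind of fixed-kernel predictor compared in Fig. 10, although the matrix-valued structure of the NTK for multi-class outputs is ignored here.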

Figure 10: This plot compares ridgeless kernel regression using NTK at different times. The generalization gaps are 0.67, 0.65, 0.49, 0.47, 0.43, 0.39, 0.27 respectively, $t_{\text{ker}}$ from small to large. Left: Explained magnitude $M(\vec{r}_n(0), E(K_n))$ for kernels corresponding to different times. Right: Eigenvalues $\sigma(K_n)$ for kernels corresponding to different times.
Figure 11: Residuals $\vec{r}_n(0)$ and effective Gram matrix $K_n$ for FC trained on MNIST with different numbers of samples. For this experiment, we created a 5-class classification problem, instead of the original 10 classes. The generalization gaps are 0.09, 0.13, 0.14, 0.23, 0.27 for $n$ from small to large. Left: Explained magnitude of the initial residual $M(\vec{r}_n(0), E(K_n))$ has a similar shape for all $n$, but the overlap with the principal subspace of the effective Gram matrix is larger for smaller $n$, which is corroborated by the numerical estimates of the generalization gap in our experiments. Right: The tail eigenvalues of the effective Gram matrix $\sigma(K_n)$ decrease as $n$ increases.

4.6 Effective Gram matrix for different number of samples

Fig. 11 compares training on datasets of different sizes. When $n$ becomes larger, the initial residual of MNIST projects more onto the tail subspaces of the effective Gram matrix $K_n$, and the tail eigenvalues of $K_n$ become smaller. This coincides with the smaller generalization gap as we train with more samples.
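
Because curves computed for different $n$ have different lengths, comparing them requires placing them on a shared axis. One way to do this, in the spirit of the relative index that rescales indices to lie between zero and one, is sketched below. The helper name `on_common_grid`, the use of linear interpolation, and the exponential spectra are all assumptions made for illustration; the paper's own procedure may differ.

```python
import numpy as np

def on_common_grid(curve, grid):
    """Resample a length-n curve onto a shared relative index in [0, 1]."""
    rel_index = np.linspace(0.0, 1.0, len(curve))   # original indices rescaled
    return np.interp(grid, rel_index, curve)

grid = np.linspace(0.0, 1.0, 101)
rng = np.random.default_rng(0)

curves = {}
for n in (50, 200):
    # Synthetic ascending spectrum; NOT data from the paper.
    sigma = np.sort(rng.exponential(size=n))
    explained = np.cumsum(sigma) / sigma.sum()      # cumulative eigenvalue mass
    curves[n] = on_common_grid(explained, grid)

# Both curves now have the same length and can be compared pointwise.
gap = np.max(np.abs(curves[50] - curves[200]))
print(gap)
```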

5 Conclusion

We identified key quantities in the training process that control the generalization gap, namely, a contraction factor that brings trajectories on different datasets together, and a perturbation factor that arises from the differences in the sample sets. The merit of our analysis is that it can succinctly and faithfully characterize the generalization gap of general neural networks. The expression in Theorem 13 depends only on the initial residual $\vec{r}_n(0)$ and the effective kernel $K_n(0,t)$. It is important to emphasize that this effective kernel is designed to understand the generalization gap, not the training dynamics. The existence and utility of this kernel indicate that we might be able to fruitfully think of deep networks alongside other models in a machine learning practitioner’s toolkit; perhaps they are not as anomalous as they appear to be.

Acknowledgment

This work was funded by grants provided by the National Science Foundation (IIS-2145164, CCF-2212519).

References

  • Akbari et al. (2021) A. Akbari, M. Awais, M. Bashar, and J. Kittler. How does loss function affect generalization performance of deep learning? application to human age estimation. In Proceedings of the 38th International Conference on Machine Learning, volume 139 of Proceedings of Machine Learning Research, pages 141–151. PMLR, 2021.
  • Allen-Zhu et al. (2019) Z. Allen-Zhu, Y. Li, and Y. Liang. Learning and generalization in overparameterized neural networks, going beyond two layers. In Advances in Neural Information Processing Systems, volume 32, pages 6158–6169, 2019.
  • Arora et al. (2019) S. Arora, S. S. Du, W. Hu, Z. Li, and R. Wang. Fine-grained analysis of optimization and generalization for overparameterized two-layer neural networks. In Proceedings of the 36th International Conference on Machine Learning, 2019.
  • Ba et al. (2016) J. L. Ba, J. R. Kiros, and G. E. Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016.
  • Banerjee et al. (2022) A. Banerjee, T. Chen, X. Li, and Y. Zhou. Stability based generalization bounds for exponential family Langevin dynamics. In Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 1412–1449. PMLR, 17–23 Jul 2022.
  • Bartlett et al. (2017) P. L. Bartlett, D. J. Foster, and M. Telgarsky. Spectrally-normalized margin bounds for neural networks. In Advances in Neural Information Processing Systems, volume 30, pages 6240–6249, 2017.
  • Bartlett et al. (2020) P. L. Bartlett, P. M. Long, G. Lugosi, and A. Tsigler. Benign overfitting in linear regression. Proceedings of the National Academy of Sciences, 117(48):30063–30070, 2020. doi: 10.1073/pnas.1907378117.
  • Belkin et al. (2018) M. Belkin, S. Ma, and S. Mandal. To understand deep learning we need to understand kernel learning. In Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 541–549. PMLR, 10–15 Jul 2018.
  • Belkin et al. (2020) M. Belkin, D. Hsu, and J. Xu. Two models of double descent for weak features. SIAM Journal on Mathematics of Data Science, 2(4):1167–1180, 2020. doi: 10.1137/20M1336072.
  • Bousquet and Elisseeff (2002) O. Bousquet and A. Elisseeff. Stability and generalization. Journal of Machine Learning Research, 2:499–526, 2002.
  • Bowman and Montúfar (2022) B. Bowman and G. Montúfar. Spectral bias outside the training set for deep networks in the kernel regime. In Advances in Neural Information Processing Systems, volume 35, 2022.
  • Cao and Gu (2019) Y. Cao and Q. Gu. Generalization bounds of stochastic gradient descent for wide and deep neural networks. In Advances in Neural Information Processing Systems, volume 32, pages 10836–10846, 2019.
  • Charles and Papailiopoulos (2018) Z. Charles and D. Papailiopoulos. Stability and generalization of learning algorithms that converge to global optima. In Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 745–754. PMLR, 10–15 Jul 2018.
  • Chizat and Bach (2018) L. Chizat and F. Bach. On the global convergence of gradient descent for over-parameterized models using optimal transport. In Advances in Neural Information Processing Systems, volume 31, pages 3036–3046. Curran Associates, Inc., 2018.
  • Chu and Raginsky (2023) Y. Chu and M. Raginsky. A unified framework for information-theoretic generalization bounds. In Advances in Neural Information Processing Systems, volume 36, pages 79260–79278, 2023.
  • Chuang et al. (2021) C.-Y. Chuang, Y. Mroueh, K. Greenewald, A. Torralba, and S. Jegelka. Measuring generalization with optimal transport. In Advances in Neural Information Processing Systems, volume 34, pages 3031–3044, 2021.
  • Du et al. (2019) S. S. Du, X. Zhai, B. Póczos, and A. Singh. Gradient descent provably optimizes over-parameterized neural networks. In International Conference on Learning Representations, 2019.
  • Dziugaite and Roy (2017) G. K. Dziugaite and D. M. Roy. Computing nonvacuous generalization bounds for deep (stochastic) neural networks with many more parameters than training data. In Proceedings of the 33rd Conference on Uncertainty in Artificial Intelligence, 2017.
  • Fort and Ganguli (2019) S. Fort and S. Ganguli. Emergent properties of the local geometry of neural loss landscapes. arXiv preprint arXiv:1910.05929, 2019.
  • Fort et al. (2020) S. Fort, G. K. Dziugaite, M. Paul, S. Kharaghani, D. M. Roy, and S. Ganguli. Deep learning versus kernel learning: An empirical study of loss landscape geometry and the time evolution of the neural tangent kernel. In Advances in Neural Information Processing Systems, volume 33, pages 5850–5861, 2020.
  • Futami and Fujisawa (2023) F. Futami and M. Fujisawa. Time-independent information-theoretic generalization bounds for SGLD. arXiv preprint arXiv:2311.01046, 2023.
  • Hafez-Kolahi et al. (2020) H. Hafez-Kolahi, Z. Golgooni, S. Kasaei, and M. Soleymani. Conditioning and processing: Techniques to improve information-theoretic generalization bounds. In Advances in Neural Information Processing Systems, volume 33, pages 3492–3503. Curran Associates, Inc., 2020.
  • Hardt et al. (2016) M. Hardt, B. Recht, and Y. Singer. Train faster, generalize better: Stability of stochastic gradient descent. In Proceedings of The 33rd International Conference on Machine Learning, volume 48 of Proceedings of Machine Learning Research, pages 1225–1234, New York, NY, USA, June 20–22 2016. PMLR.
  • Hastie et al. (2022) T. Hastie, A. Montanari, S. Rosset, and R. J. Tibshirani. Surprises in high-dimensional ridgeless least squares interpolation. The Annals of Statistics, 50(2):949–986, 2022. doi: 10.1214/21-AOS2133.
  • Isidori (1995) A. Isidori. Nonlinear Control Systems. Communications and Control Engineering. Springer-Verlag, London, 3rd edition, 1995. doi: 10.1007/978-1-84628-615-5.
  • Jacot et al. (2018) A. Jacot, F. Gabriel, and C. Hongler. Neural tangent kernel: Convergence and generalization in neural networks. In Advances in Neural Information Processing Systems, volume 31, pages 8571–8580. Curran Associates, Inc., 2018.
  • Jacot et al. (2020) A. Jacot, B. Şimşek, F. Spadaro, C. Hongler, and F. Gabriel. Kernel alignment risk estimator: Risk prediction from training data. In Advances in Neural Information Processing Systems, volume 33, pages 15568–15578, 2020.
  • Kawaguchi et al. (2018) K. Kawaguchi, L. P. Kaelbling, and Y. Bengio. Generalization in deep learning. Technical report, Massachusetts Institute of Technology, 2018.
  • Kawaguchi et al. (2022) K. Kawaguchi, Z. Deng, K. Luh, and J. Huang. Robustness implies generalization via data-dependent generalization bounds. In Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 10866–10894. PMLR, 17–23 Jul 2022.
  • Kozachkov et al. (2023) L. Kozachkov, P. M. Wensing, and J.-J. E. Slotine. Generalization as dynamical robustness: the role of Riemannian contraction in supervised learning. Transactions on Machine Learning Research, 2023.
  • Krizhevsky (2009) A. Krizhevsky. Learning multiple layers of features from tiny images. Technical report, University of Toronto, 2009.
  • Laurent and von Brecht (2018) T. Laurent and J. von Brecht. Deep linear networks with arbitrary loss: All local minima are global. In Proceedings of the 35th International Conference on Machine Learning, volume 80 of Proceedings of Machine Learning Research, pages 2902–2907. PMLR, 2018.
  • LeCun et al. (1998) Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11):2278–2324, 1998.
  • Li and Liang (2018) Y. Li and Y. Liang. Learning overparameterized neural networks via stochastic gradient descent on structured data. In Advances in Neural Information Processing Systems, volume 31, pages 8157–8166. Curran Associates, Inc., 2018.
  • Liu et al. (2022) F. Liu, H. Yang, S. Hayou, and Q. Li. From optimization dynamics to generalization bounds via Łojasiewicz gradient inequality. Transactions on Machine Learning Research, 2022. Accepted; to appear.
  • Lohmiller and Slotine (1998) W. Lohmiller and J.-J. E. Slotine. On contraction analysis for non-linear systems. Automatica, 34(6):683–696, 1998. ISSN 0005-1098.
  • Lohmiller and Slotine (2000) W. Lohmiller and J.-J. E. Slotine. Nonlinear process control using contraction theory. AIChE Journal, 46(3):588–596, March 2000. doi: 10.1002/aic.690460316.
  • Lugosi and Neu (2022) G. Lugosi and G. Neu. Generalization bounds via convex analysis. In Proceedings of the 35th Conference on Learning Theory (COLT), volume 178 of Proceedings of Machine Learning Research, pages 3523–3544. PMLR, 2022.
  • Magnus (1954) W. Magnus. On the exponential solution of differential equations. Communications on Pure and Applied Mathematics, 7(4):649–673, 1954. doi: 10.1002/cpa.3160070404.
  • Mallinar et al. (2022) N. Mallinar, J. B. Simon, A. Abedsoltan, P. Pandit, M. Belkin, and P. Nakkiran. Benign, tempered, or catastrophic: A taxonomy of overfitting. In Advances in Neural Information Processing Systems, volume 35, pages 29912–29925. Curran Associates, Inc., 2022.
  • Mao et al. (2024) J. Mao, I. Griniasty, H. K. Teoh, R. Ramesh, R. Yang, M. K. Transtrum, J. P. Sethna, and P. Chaudhari. The training process of many deep networks explores the same low-dimensional manifold. Proceedings of the National Academy of Sciences, 121(12):e2310002121, 2024. doi: 10.1073/pnas.2310002121.
  • Marino and Tomei (1995) R. Marino and P. Tomei. Nonlinear Control Design: Geometric, Adaptive, and Robust. Prentice Hall, London, 1995. ISBN 978-0133426359.
  • McAllester (1999) D. A. McAllester. PAC-Bayesian model averaging. In Proceedings of the Twelfth Annual Conference on Computational Learning Theory (COLT), pages 164–170. ACM, 1999. doi: 10.1145/307400.307435.
  • McDiarmid (1989) C. McDiarmid. On the method of bounded differences. In Surveys in Combinatorics, 1989 (Norwich, 1989), volume 141 of London Mathematical Society Lecture Note Series, pages 148–188. Cambridge University Press, 1989.
  • Mei and Montanari (2022) S. Mei and A. Montanari. The generalization error of random features regression: Precise asymptotics and the double descent curve. Communications on Pure and Applied Mathematics, 75(4):667–766, 2022. doi: 10.1002/cpa.22008.
  • Mei et al. (2019) S. Mei, T. Misiakiewicz, and A. Montanari. Mean-field theory of two-layers neural networks: dimension-free bounds and kernel limit. In Proceedings of the 32nd Conference on Learning Theory, volume 99 of Proceedings of Machine Learning Research, pages 2388–2464. PMLR, 2019.
  • Mou et al. (2018) W. Mou, L. Wang, X. Zhai, and K. Zheng. Generalization bounds of SGLD for non-convex learning: Two theoretical viewpoints. In Proceedings of the 31st Conference On Learning Theory, volume 75 of Proceedings of Machine Learning Research, pages 605–638. PMLR, 06–09 Jul 2018.
  • Negrea et al. (2019) J. Negrea, M. Haghifam, G. K. Dziugaite, A. Khisti, and D. M. Roy. Information-theoretic generalization bounds for SGLD via data-dependent estimates. In Advances in Neural Information Processing Systems, volume 32, 2019.
  • Neu et al. (2021) G. Neu, G. K. Dziugaite, M. Haghifam, and D. M. Roy. Information-theoretic generalization bounds for stochastic gradient descent. In Proceedings of the 34th Annual Conference on Learning Theory, volume 134 of Proceedings of Machine Learning Research, pages 3526–3546. PMLR, 2021.
  • Neyshabur et al. (2018) B. Neyshabur, S. Bhojanapalli, and N. Srebro. A PAC-Bayesian approach to spectrally-normalized margin bounds for neural networks. In International Conference on Learning Representations, 2018.
  • Rakhlin and Liang (2020) A. Rakhlin and T. Liang. Just interpolate: Kernel “ridgeless” regression can generalize. Annals of Statistics, 48(3):1329–1347, 2020. doi: 10.1214/19-AOS1849.
  • Ramesh et al. (2024) R. Ramesh, A. Bisulco, R. W. DiTullio, L. Wei, V. Balasubramanian, K. Daniilidis, and P. Chaudhari. Many perception tasks are highly redundant functions of their input data. arXiv preprint arXiv:2407.13841, 2024.
  • Richards and Kuzborskij (2021) D. Richards and I. Kuzborskij. Stability & generalisation of gradient descent for shallow neural networks without the neural tangent kernel. In Advances in Neural Information Processing Systems, volume 34, pages 21812–21823, 2021.
  • Steinke and Zakynthinou (2020) T. Steinke and L. Zakynthinou. Reasoning about generalization via conditional mutual information. In Proceedings of the 33rd Conference on Learning Theory, volume 125 of Proceedings of Machine Learning Research, pages 3437–3452. PMLR, 2020.
  • Tsukamoto et al. (2021) H. Tsukamoto, S.-J. Chung, and J.-J. Slotine. Contraction theory for nonlinear stability analysis and learning-based control: A tutorial overview. Annual Reviews in Control, 52:135–148, October 2021. doi: 10.1016/j.arcontrol.2021.10.001.
  • Valiant (1984) L. G. Valiant. A theory of the learnable. Communications of the ACM, 27(11):1134–1142, 1984. doi: 10.1145/1968.1972.
  • Wang and Ma (2022) M. Wang and C. Ma. Generalization error bounds for deep neural networks trained by SGD. arXiv preprint arXiv:2206.03299, 2022.
  • Xu and Raginsky (2017) A. Xu and M. Raginsky. Information-theoretic analysis of generalization capability of learning algorithms. In Advances in Neural Information Processing Systems, volume 30, pages 2524–2533, 2017.
  • Xu and Mannor (2012) H. Xu and S. Mannor. Robustness and generalization. Machine Learning, 86(3):391–423, 2012. doi: 10.1007/s10994-011-5268-1.
  • Yang et al. (2022) R. Yang, J. Mao, and P. Chaudhari. Does the data induce capacity control in deep learning? In Proceedings of the 39th International Conference on Machine Learning, volume 162 of Proceedings of Machine Learning Research, pages 25348–25368. PMLR, 2022.
  • Zagoruyko and Komodakis (2016) S. Zagoruyko and N. Komodakis. Wide residual networks. In Proceedings of the British Machine Vision Conference (BMVC). BMVA Press, 2016.

Appendix A Proofs and Calculations in Section 3

A.1 Proof of Lemma 3

By the definition of the averaged loss difference $\bar{\Delta}_{n}(t)$,

$$\bar{\Delta}_{n}(t)=\frac{1}{n}\left(\sum_{i=1}^{n}\ell(w_{n}^{-i}(t),z_{i})\right)-\bar{\ell}(w_{n}(t),S_{n}).$$

Taking the expectation on both sides, and using the fact that $z_i$ is independent of the dataset $S_n^{-i}$ on which $w_n^{-i}(t)$ is trained, we have

$$\mathbb{E}\left[\bar{\Delta}_{n}(t)\right]=\mathbb{E}\left[R(S_{n-1},t)\right]-\mathbb{E}\left[R_{\text{train}}(S_{n},t)\right].$$

By assumption, we have

$$\mathbb{E}\left[R(S_{n},t)\right]\leq\mathbb{E}\left[R(S_{n-1},t)\right],\quad\mathbb{E}\left[R_{\text{train}}(S_{n-1},t)\right]\leq\mathbb{E}\left[R_{\text{train}}(S_{n},t)\right].$$

Therefore,

$$\mathbb{E}\left[\delta R(S_{n},t)\right]\leq\mathbb{E}\left[\bar{\Delta}_{n}(t)\right]\leq\mathbb{E}\left[\delta R(S_{n-1},t)\right].$$

By the assumption that $\mathbb{E}[\delta R(S_{n},t)]/\mathbb{E}[\delta R(S_{n-1},t)]\to 1$ as $n\to\infty$, we have that

$$\frac{\mathbb{E}[\delta R(S_{n-1},t)]-\mathbb{E}[\delta R(S_{n},t)]}{\mathbb{E}[\delta R(S_{n},t)]}\to 0$$

as $n\to\infty$. The sandwich above shows that $\mathbb{E}[\bar{\Delta}_{n}(t)]-\mathbb{E}[\delta R(S_{n},t)]$ lies between $0$ and $\mathbb{E}[\delta R(S_{n-1},t)]-\mathbb{E}[\delta R(S_{n},t)]$. Hence,

$$\mathbb{E}\left[\delta R(S_{n},t)\right]=\mathbb{E}\left[\bar{\Delta}_{n}(t)\right]+o\left(\mathbb{E}\left[\delta R(S_{n},t)\right]\right).$$
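To make the definition of $\bar{\Delta}_{n}(t)$ used in this proof concrete, the following is a minimal sketch, not the authors' implementation. It assumes a linear least-squares model, approximates gradient flow by plain gradient descent with a small step size, and starts every leave-one-out run from the same initialization. All function names and constants are hypothetical and chosen only for illustration.

```python
import numpy as np

def squared_loss(w, X, y):
    """Per-sample squared loss 0.5 * (x^T w - y)^2 (an assumed choice of loss)."""
    return 0.5 * (X @ w - y) ** 2

def train(X, y, w0, steps, lr):
    """Gradient descent on the mean loss; an Euler discretization of gradient flow."""
    w = w0.copy()
    for _ in range(steps):
        w -= lr * X.T @ (X @ w - y) / len(y)
    return w

def averaged_loss_difference(X, y, w0, steps, lr=0.05):
    """Mean leave-one-out loss on the held-out samples minus the training loss."""
    n = len(y)
    w_full = train(X, y, w0, steps, lr)
    train_loss = squared_loss(w_full, X, y).mean()
    loo_losses = np.empty(n)
    for i in range(n):
        keep = np.arange(n) != i              # drop the i-th sample
        w_loo = train(X[keep], y[keep], w0, steps, lr)
        loo_losses[i] = squared_loss(w_loo, X[i:i + 1], y[i:i + 1])[0]
    return loo_losses.mean() - train_loss

rng = np.random.default_rng(0)
n, d = 30, 5
X = rng.normal(size=(n, d))
y = X @ rng.normal(size=d) + 0.1 * rng.normal(size=n)
w0 = np.zeros(d)

delta_at_init = averaged_loss_difference(X, y, w0, steps=0)
delta_after = averaged_loss_difference(X, y, w0, steps=2000)
print(delta_at_init, delta_after)
```

At initialization all trajectories share the same weights, so the averaged loss difference vanishes. After training it is positive in this toy problem, which is the quantity that Lemma 3 relates to the expected generalization gap.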

Scenarios in which the assumptions of Lemma 3 hold

  • The expected generalization loss $\mathbb{E}[R(S_{n},t)]$ is non-increasing in $n$: This holds for ridgeless linear regression without label noise. When label noise is non-zero, the generalization loss decays monotonically when the number of samples is greater than the number of features. This result also holds for ridge regression when the ridge coefficient $\lambda$ decays with $n$, but not too fast, i.e., $\lambda>\frac{\sigma^{2}}{\sigma^{2}+\|\theta^{*}\|^{2}}\cdot\frac{1}{n}$, where $\sigma^{2}$ is the noise variance and $\theta^{*}$ is the true regressor; see Hastie et al. (2022). Similar results hold for kernel regression and random feature regression-based architectures (Mei and Montanari, 2022), both of which are widely used models in the analysis of neural networks. For consistent estimators, the generalization loss converges to the Bayes risk asymptotically, although this decrease need not be strictly monotonic. Estimators such as empirical risk minimization (ERM) and structural risk minimization (SRM) are consistent under mild assumptions on the hypothesis class, e.g., finite capacity.

  • The expected training loss $\mathbb{E}[R_{\text{train}}(S_{n},t)]$ is non-decreasing in $n$: This holds for ridgeless linear regression in general, and therefore for kernel regression and random feature-based models of neural networks. Note that this assumption can be relaxed slightly to $\mathbb{E}[R_{\text{train}}(S_{n-1},t)]\leq\mathbb{E}[R_{\text{train}}(S_{n},t)]+B/n$. This relaxed condition holds in general for empirical risk minimization (ERM) with a bounded loss $|\ell(w,z)|\leq B$. The left-hand side of the inequality in Lemma 3 correspondingly gets an additive term of $B/n$. The rest of our calculations stay as they are.

  • The expected generalization gap $\mathbb{E}[\delta R(S_{n},t)]$ is non-negative: This holds for empirical risk minimization (ERM), in general.

Concentration of $\bar{\Delta}_{n}(t)$ around $\mathbb{E}[\bar{\Delta}_{n}(t)]$

We first define the notion of stability for a deterministic algorithm $\mathcal{A}$ that maps the space of datasets to the weight space, i.e., $\mathcal{A}:\cup_{n=0}^{\infty}\mathcal{Z}^{n}\to\mathcal{W}$.

Definition 21.

An algorithm $\mathcal{A}$ is uniformly $\varepsilon$-stable if for all datasets $S$, $S^{\prime}$ differing in at most one sample, we have

$$\sup_{z}\left|\ell(\mathcal{A}(S),z)-\ell(\mathcal{A}(S^{\prime}),z)\right|\leq\varepsilon.$$

Now we define the set of algorithms $\Gamma$ that map a dataset $S$ to points on the gradient flow trajectory trained on $S$ at certain time points:

$$\Gamma=\bigg\{\mathcal{A}:\mathcal{A}(S)=w(t),\ t\geq 0,\ \text{$w$ satisfies } \frac{\mathrm{d}w}{\mathrm{d}t}=-\nabla\bar{\ell}(w,S),\ w(0)\in\mathcal{W},\ S\in\cup_{n\in\mathbb{N}}\mathcal{Z}^{n}\bigg\}.$$

Lemma 22.

Assume that (1) $|\ell(w,z)|\leq B$ for all $w\in\mathcal{W},z\in\mathcal{Z}$, and (2) every $\mathcal{A}\in\Gamma$ is uniformly $\varepsilon$-stable. Then for all $t>0$, with probability at least $1-\delta$,

$$\left|\bar{\Delta}_{n}(t)-\mathbb{E}[\bar{\Delta}_{n}(t)]\right|\leq(n\varepsilon+2B)\sqrt{\frac{2\log(2/\delta)}{n}}.$$
Proof.

Let $\tilde{S}_{n}$ denote the dataset obtained from $S_{n}$ by replacing the sample $z_{j}$ with a different sample $\tilde{z}_{j}$. Let $\tilde{w}_{n}(t)$ and $\tilde{w}_{n}^{-i}(t)$ be the corresponding trajectories trained with $\tilde{S}_{n}$ and $\tilde{S}_{n}^{-i}$ (the dataset $\tilde{S}_{n}$ with its $i$-th sample removed). Note that $\tilde{w}_{n}^{-j}(t)=w_{n}^{-j}(t)$. Let $\bar{\Delta}(\tilde{S}_{n},t)$ and $\bar{\Delta}(S_{n},t)$ be the averaged loss differences calculated on $\tilde{S}_{n}$ and $S_{n}$ respectively. By assumptions (1) and (2), we have

$$\begin{aligned}
\left|\bar{\ell}(\tilde{w}_{n}(t),\tilde{S}_{n})-\bar{\ell}(w_{n}(t),S_{n})\right| &\leq\frac{(n-1)\varepsilon}{n}+\frac{2B}{n}\leq\varepsilon+\frac{2B}{n},\\
\left|\ell(\tilde{w}_{n}^{-i}(t),z_{i})-\ell(w_{n}^{-i}(t),z_{i})\right| &\leq\varepsilon\quad\forall i\neq j,\\
\frac{1}{n}\left|\ell(\tilde{w}_{n}^{-j}(t),\tilde{z}_{j})-\ell(w_{n}^{-j}(t),z_{j})\right| &\leq\frac{2B}{n},
\end{aligned}$$

where the last line is the contribution of the $j$-th term of the average, which is evaluated at $\tilde{z}_{j}$ for $\tilde{S}_{n}$ and at $z_{j}$ for $S_{n}$. Hence we have

$$\left|\bar{\Delta}(\tilde{S}_{n},t)-\bar{\Delta}(S_{n},t)\right|\leq 2\varepsilon+\frac{4B}{n}. \tag{19}$$

Eq. 19 bounds the change in $\bar{\Delta}_{n}$ when one sample is replaced. Hence, by McDiarmid’s inequality (McDiarmid, 1989), we have the following concentration inequality:

$$\mathbb{P}_{S_{n}}\left[\left|\bar{\Delta}_{n}-\mathbb{E}[\bar{\Delta}_{n}]\right|\geq a\right]\leq 2\exp\left(-\frac{2a^{2}}{n(2\varepsilon+4B/n)^{2}}\right).$$

Setting the right-hand side to $\delta$, we have with probability at least $1-\delta$,

$$\left|\bar{\Delta}_{n}-\mathbb{E}[\bar{\Delta}_{n}]\right|\leq(n\varepsilon+2B)\cdot\sqrt{\frac{2\log(2/\delta)}{n}}.$$
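The inversion of the tail bound in the last step of this proof can be checked numerically. The sketch below is illustrative only: the function names are hypothetical, and the stability level $\varepsilon=1/(n-1)$ is an assumed value (a rate of order $1/(n-1)$ with the constant set to one), not a quantity measured in the paper.

```python
import math

def deviation_bound(n, eps, B, delta):
    """Deviation (n*eps + 2B) * sqrt(2 log(2/delta) / n) from Lemma 22."""
    return (n * eps + 2.0 * B) * math.sqrt(2.0 * math.log(2.0 / delta) / n)

def mcdiarmid_tail(n, eps, B, a):
    """Tail bound 2 exp(-2 a^2 / (n c^2)) with per-sample sensitivity c = 2 eps + 4B/n."""
    c = 2.0 * eps + 4.0 * B / n
    return 2.0 * math.exp(-2.0 * a ** 2 / (n * c ** 2))

n, B, delta = 10_000, 1.0, 0.05
eps = 1.0 / (n - 1)          # assumed stability level, for illustration only
a = deviation_bound(n, eps, B, delta)

# Plugging the deviation back into the tail bound should recover delta.
recovered_delta = mcdiarmid_tail(n, eps, B, a)
print(a, recovered_delta)
```

The bound shrinks with $n$ only if $n\varepsilon$ stays bounded, i.e., if the stability level decays at least as fast as $1/n$; for a constant $\varepsilon$ the factor $n\varepsilon$ makes the right-hand side grow like $\sqrt{n}$.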

Remark 23.

In general, the convergence of $\bar{\Delta}_{n}$ to $\mathbb{E}\left[\bar{\Delta}_{n}\right]$ can be guaranteed by different versions of algorithmic stability (e.g., hypothesis stability, pointwise hypothesis stability and uniform stability (Bousquet and Elisseeff, 2002)). Charles and Papailiopoulos (2018) show that the algorithm $\mathcal{A}$ is $C(L,\mu)/(n-1)$-uniformly stable if $\ell(w,z)$ is $L$-Lipschitz in $w$ and $\bar{\ell}(w,S)$ is $\mu$-PL (Polyak-Łojasiewicz) in $w$, where $C(L,\mu)$ is a constant depending on $L$ and $\mu$. Other versions of stability can also be guaranteed by the PL and QG (quadratic growth) conditions, as shown in Charles and Papailiopoulos (2018).

A.2 Proof of Lemma 4

By taking the derivative of the pointwise loss difference Δni(t)superscriptsubscriptΔ𝑛𝑖𝑡\Delta_{n}^{-i}(t)roman_Δ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ), we have,

dΔni(t)dtsuperscriptsubscriptΔ𝑛𝑖𝑡𝑡\displaystyle\frac{\differential\Delta_{n}^{-i}(t)}{\differential t}divide start_ARG start_DIFFOP roman_d end_DIFFOP roman_Δ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) end_ARG start_ARG start_DIFFOP roman_d end_DIFFOP italic_t end_ARG =d((wni(t),zi)(wn(t),zi))dtabsentsuperscriptsubscript𝑤𝑛𝑖𝑡subscript𝑧𝑖subscript𝑤𝑛𝑡subscript𝑧𝑖𝑡\displaystyle=\frac{\differential\left(\ell(w_{n}^{-i}(t),z_{i})-\ell(w_{n}(t)% ,z_{i})\right)}{\differential t}= divide start_ARG start_DIFFOP roman_d end_DIFFOP ( roman_ℓ ( italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) - roman_ℓ ( italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ) end_ARG start_ARG start_DIFFOP roman_d end_DIFFOP italic_t end_ARG
=(w,zi)¯(w,Sni)|wni(t)((w,zi)¯(w,Sn)|wn(t))absentevaluated-at𝑤subscript𝑧𝑖¯𝑤superscriptsubscript𝑆𝑛𝑖superscriptsubscript𝑤𝑛𝑖𝑡evaluated-at𝑤subscript𝑧𝑖¯𝑤subscript𝑆𝑛subscript𝑤𝑛𝑡\displaystyle=-\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n}^{-i})\big{|}_{% w_{n}^{-i}(t)}-\left(-\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n})\big{|}% _{w_{n}(t)}\right)= - ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) end_POSTSUBSCRIPT - ( - ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT )
=((w,zi)¯(w,Sni)|wni(t)(w,zi)¯(w,Sni)|wn(t))absentevaluated-at𝑤subscript𝑧𝑖¯𝑤superscriptsubscript𝑆𝑛𝑖superscriptsubscript𝑤𝑛𝑖𝑡evaluated-at𝑤subscript𝑧𝑖¯𝑤superscriptsubscript𝑆𝑛𝑖subscript𝑤𝑛𝑡\displaystyle=-\left(\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n}^{-i})% \big{|}_{w_{n}^{-i}(t)}-\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n}^{-i})% \big{|}_{w_{n}(t)}\right)= - ( ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) end_POSTSUBSCRIPT - ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT )
+((w,zi)¯(w,Sni)|wn(t)+(w,zi)¯(w,Sn)|wn(t))evaluated-at𝑤subscript𝑧𝑖¯𝑤superscriptsubscript𝑆𝑛𝑖subscript𝑤𝑛𝑡evaluated-at𝑤subscript𝑧𝑖¯𝑤subscript𝑆𝑛subscript𝑤𝑛𝑡\displaystyle+\left(-\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n}^{-i})% \big{|}_{w_{n}(t)}+\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n})\big{|}_{w% _{n}(t)}\right)+ ( - ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT + ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT )
=((w,zi)¯(w,Sni)|wni(t)(w,zi)¯(w,Sni)|wn(t))absentevaluated-at𝑤subscript𝑧𝑖¯𝑤superscriptsubscript𝑆𝑛𝑖superscriptsubscript𝑤𝑛𝑖𝑡evaluated-at𝑤subscript𝑧𝑖¯𝑤superscriptsubscript𝑆𝑛𝑖subscript𝑤𝑛𝑡\displaystyle=-\left(\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n}^{-i})% \big{|}_{w_{n}^{-i}(t)}-\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S_{n}^{-i})% \big{|}_{w_{n}(t)}\right)= - ( ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) end_POSTSUBSCRIPT - ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT )
+(w,zi)(¯(w,Sn)¯(w,Sni))|wn(t)evaluated-at𝑤subscript𝑧𝑖¯𝑤subscript𝑆𝑛¯𝑤superscriptsubscript𝑆𝑛𝑖subscript𝑤𝑛𝑡\displaystyle+\nabla\ell(w,z_{i})\left(\nabla\bar{\ell}(w,S_{n})-\nabla\bar{% \ell}(w,S_{n}^{-i})\right)\big{|}_{w_{n}(t)}+ ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ( ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ) - ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) ) | start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT

Hence,

dΔni(t)dt=cni(t)Δni(t)+ϵni(t),derivative𝑡subscriptsuperscriptΔ𝑖𝑛𝑡superscriptsubscript𝑐𝑛𝑖𝑡superscriptsubscriptΔ𝑛𝑖𝑡superscriptsubscriptitalic-ϵ𝑛𝑖𝑡\displaystyle\derivative{\Delta^{-i}_{n}(t)}{t}=-c_{n}^{-i}(t)\Delta_{n}^{-i}(% t)+\epsilon_{n}^{-i}(t),divide start_ARG roman_d start_ARG roman_Δ start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_ARG end_ARG start_ARG roman_d start_ARG italic_t end_ARG end_ARG = - italic_c start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) roman_Δ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) + italic_ϵ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) ,

where

cni(t)=(w,zi)¯(w,Sni)|wn(t)wni(t)Δni(t),superscriptsubscript𝑐𝑛𝑖𝑡evaluated-at𝑤subscript𝑧𝑖¯𝑤superscriptsubscript𝑆𝑛𝑖subscript𝑤𝑛𝑡superscriptsubscript𝑤𝑛𝑖𝑡superscriptsubscriptΔ𝑛𝑖𝑡\displaystyle c_{n}^{-i}(t)=\frac{\nabla\ell(w,z_{i})\cdot\nabla\bar{\ell}(w,S% _{n}^{-i})\big{|}^{w_{n}^{-i}(t)}_{w_{n}(t)}}{\Delta_{n}^{-i}(t)},italic_c start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) = divide start_ARG ∇ roman_ℓ ( italic_w , italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ⋅ ∇ over¯ start_ARG roman_ℓ end_ARG ( italic_w , italic_S start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ) | start_POSTSUPERSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT end_ARG start_ARG roman_Δ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - italic_i end_POSTSUPERSCRIPT ( italic_t ) end_ARG ,

and

\[
\epsilon_n^{-i}(t) = \nabla\ell(w,z_i)\cdot\left(\nabla\bar{\ell}(w,S_n)-\nabla\bar{\ell}(w,S_n^{-i})\right)\bigg|_{w_n(t)}.
\]

A.3 Evolution of $\bar{\Delta}_n(t)$

The evolution of $\bar{\Delta}_n(t)$ follows from that of $\Delta_n^{-i}(t)$ by averaging over $i$:

\[
\begin{aligned}
\frac{\mathrm{d}\bar{\Delta}_n(t)}{\mathrm{d}t}
&= \frac{1}{n}\sum_{i=1}^{n}\frac{\mathrm{d}\Delta_n^{-i}(t)}{\mathrm{d}t} \\
&= \frac{1}{n}\sum_{i=1}^{n}\left(-c_n^{-i}(t)\,\Delta_n^{-i}(t)+\epsilon_n^{-i}(t)\right) \\
&= -\frac{\frac{1}{n}\sum_{i=1}^{n}c_n^{-i}(t)\,\Delta_n^{-i}(t)}{\bar{\Delta}_n(t)}\,\bar{\Delta}_n(t)+\frac{1}{n}\sum_{i=1}^{n}\epsilon_n^{-i}(t) \\
&= -\bar{c}_n(t)\,\bar{\Delta}_n(t)+\bar{\epsilon}_n(t).
\end{aligned}
\]

Here we have,

\[
\begin{aligned}
\bar{c}_n(t)
&= \frac{\frac{1}{n}\sum_{i=1}^{n}c_n^{-i}(t)\,\Delta_n^{-i}(t)}{\bar{\Delta}_n(t)} \\
&= \frac{\frac{1}{n}\sum_{i=1}^{n}\nabla\ell(w,z_i)\cdot\nabla\bar{\ell}(w,S_n^{-i})\Big|_{w_n(t)}^{w_n^{-i}(t)}}{\bar{\Delta}_n(t)},
\end{aligned}
\]

and

\[
\begin{aligned}
\bar{\epsilon}_n(t)
&= \frac{1}{n}\sum_{i=1}^{n}\epsilon_n^{-i}(t) \\
&= \frac{1}{n}\sum_{i=1}^{n}\nabla\ell_i^\top\left(\frac{1}{n}\sum_{j=1}^{n}\nabla\ell_j-\frac{1}{n-1}\sum_{j\neq i}\nabla\ell_j\right) \\
&= \frac{1}{n}\sum_{i=1}^{n}\nabla\ell_i^\top\left(\frac{1}{n}\nabla\ell_i-\frac{1}{n(n-1)}\sum_{j\neq i}\nabla\ell_j\right) \\
&= \frac{1}{n^2}\sum_{i=1}^{n}\nabla\ell_i^\top\nabla\ell_i-\frac{1}{n^2(n-1)}\sum_{i\neq j}\nabla\ell_i^\top\nabla\ell_j.
\end{aligned}
\]

In the calculation above, we use $\nabla\ell_i$ and $\nabla\bar{\ell}$ as abbreviations for $\nabla\ell(w_n(t),z_i)$ and $\nabla\bar{\ell}(w_n(t),S_n)$ respectively. Notice that we have the following decomposition of the gradient covariance matrix $\hat{\Sigma}(t)$:

\[
\begin{aligned}
\hat{\Sigma}(t)
&= \frac{1}{n}\sum_{i=1}^{n}\left(\nabla\ell_i-\nabla\bar{\ell}\right)\left(\nabla\ell_i-\nabla\bar{\ell}\right)^\top \\
&= \frac{1}{n}\sum_{i=1}^{n}\nabla\ell_i\nabla\ell_i^\top-\frac{1}{n^2}\left(\sum_{i=1}^{n}\nabla\ell_i\right)\left(\sum_{i=1}^{n}\nabla\ell_i\right)^\top \\
&= \frac{n-1}{n^2}\sum_{i=1}^{n}\nabla\ell_i\nabla\ell_i^\top-\frac{1}{n^2}\sum_{i\neq j}\nabla\ell_i\nabla\ell_j^\top.
\end{aligned}
\]

Hence, we have

\[
\bar{\epsilon}_n(t) = \frac{\operatorname{tr}\hat{\Sigma}(t)}{n-1}, \quad \hat{\Sigma}_n(t) = \operatorname{Cov}_{z\sim\mathrm{Unif}(S_n)}\nabla\ell(w_n(t),z),
\]

where $\hat{\Sigma}_n(t)$, written $\hat{\Sigma}(t)$ above for brevity, is the covariance matrix of $\nabla\ell(w_n(t),z)$ for $z$ sampled uniformly from the dataset $S_n$.
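The identity $\bar{\epsilon}_n(t) = \operatorname{tr}\hat{\Sigma}(t)/(n-1)$ can be checked numerically. The following is a minimal sketch, assuming random vectors stand in for the per-sample gradients $\nabla\ell_i$; the array `G` and both function names are hypothetical and are not taken from the paper's code.

```python
import numpy as np

def eps_bar_leave_one_out(G):
    """Average over i of g_i . (mean over S_n - mean over S_n^{-i})."""
    n = G.shape[0]
    full_mean = G.mean(axis=0)
    total = G.sum(axis=0)
    eps = []
    for i in range(n):
        loo_mean = (total - G[i]) / (n - 1)  # mean gradient without sample i
        eps.append(G[i] @ (full_mean - loo_mean))
    return float(np.mean(eps))

def eps_bar_trace(G):
    """tr(Sigma_hat) / (n - 1), with Sigma_hat the 1/n-normalized covariance."""
    n = G.shape[0]
    centered = G - G.mean(axis=0)
    trace_cov = np.sum(centered ** 2) / n
    return float(trace_cov / (n - 1))

# Hypothetical per-sample gradients: n = 8 samples, 5 parameters.
rng = np.random.default_rng(0)
G = rng.normal(size=(8, 5))
eps_loo = eps_bar_leave_one_out(G)
eps_tr = eps_bar_trace(G)
```

The two values agree up to floating-point error for any choice of `G`, since the identity is purely algebraic.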

Evolution of $\mathbb{E}\left[\bar{\Delta}_n(t)\right]$:

Replacing $\bar{c}_n$ and $\bar{\epsilon}_n$ by the modified versions

\[
\bar{c}_n(t) = \frac{\mathbb{E}\left[\frac{1}{n}\sum_{i=1}^{n}\nabla\ell(w,z_i)\cdot\nabla\bar{\ell}(w,S_n^{-i})\Big|_{w_n(t)}^{w_n^{-i}(t)}\right]}{\mathbb{E}\left[\bar{\Delta}_n(t)\right]}, \quad \bar{\epsilon}_n(t) = \frac{\mathbb{E}\left[\operatorname{tr}\hat{\Sigma}(t)\right]}{n-1},
\]

gives the evolution of $\mathbb{E}\left[\bar{\Delta}_n(t)\right]$:

\[
\frac{\mathrm{d}\,\mathbb{E}\left[\bar{\Delta}_n(t)\right]}{\mathrm{d}t} = -\bar{c}_n(t)\,\mathbb{E}\left[\bar{\Delta}_n(t)\right]+\bar{\epsilon}_n(t).
\]
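The equation above is a scalar linear ODE. With a zero initial gap, the standard variation-of-constants formula gives the solution $\int_0^t \bar{\epsilon}_n(s)\exp\left(-\int_s^t \bar{c}_n(u)\,\mathrm{d}u\right)\mathrm{d}s$. The sketch below compares this formula with direct numerical integration. The functions `c_bar` and `eps_bar` are made-up illustrative choices, not quantities measured from a network.

```python
import numpy as np

def c_bar(t):
    """Hypothetical contraction factor (illustrative only)."""
    return 1.0 + 0.5 * t

def eps_bar(t):
    """Hypothetical perturbation factor (illustrative only)."""
    return np.exp(-t)

def solve_rk4(T, steps):
    """Integrate d(delta)/dt = -c_bar(t) * delta + eps_bar(t), delta(0) = 0."""
    h = T / steps
    t, delta = 0.0, 0.0
    rhs = lambda t, d: -c_bar(t) * d + eps_bar(t)
    for _ in range(steps):
        k1 = rhs(t, delta)
        k2 = rhs(t + h / 2, delta + h * k1 / 2)
        k3 = rhs(t + h / 2, delta + h * k2 / 2)
        k4 = rhs(t + h, delta + h * k3)
        delta += h * (k1 + 2 * k2 + 2 * k3 + k4) / 6
        t += h
    return float(delta)

def trapezoid(y, x):
    """Composite trapezoid rule for samples y on grid x."""
    return float(np.sum((y[1:] + y[:-1]) * np.diff(x)) / 2)

def solve_integral(T, steps):
    """Evaluate int_0^T eps_bar(s) * exp(-int_s^T c_bar(u) du) ds on a grid."""
    s = np.linspace(0.0, T, steps + 1)
    c = c_bar(s)
    # C[k] approximates int_0^{s_k} c_bar(u) du (cumulative trapezoid rule).
    C = np.concatenate([[0.0], np.cumsum((c[1:] + c[:-1]) * np.diff(s) / 2)])
    integrand = eps_bar(s) * np.exp(-(C[-1] - C))
    return trapezoid(integrand, s)

delta_ode = solve_rk4(3.0, 2000)
delta_formula = solve_integral(3.0, 2000)
```

A larger `c_bar` shrinks the weight on early perturbations, which is the sense in which the contraction factor counteracts the accumulation of `eps_bar`.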

A.4 Evolution of $\vec{r}_n(t)$

We derive the equation governing the evolution of $\vec{r}_n(t)$ by calculating its time derivative:

\[
\begin{aligned}
\frac{\mathrm{d}\vec{r}_n(t)}{\mathrm{d}t}
&= \frac{1}{\sqrt{n}}\begin{bmatrix}\frac{\mathrm{d}r(w_n(t),z_1)}{\mathrm{d}t}\\ \vdots\\ \frac{\mathrm{d}r(w_n(t),z_n)}{\mathrm{d}t}\end{bmatrix}
= \frac{1}{\sqrt{n}}\begin{bmatrix}\nabla r(w_n(t),z_1)^\top\frac{\mathrm{d}w_n(t)}{\mathrm{d}t}\\ \vdots\\ \nabla r(w_n(t),z_n)^\top\frac{\mathrm{d}w_n(t)}{\mathrm{d}t}\end{bmatrix} \\
&= \frac{1}{\sqrt{n}}\begin{bmatrix}-\frac{1}{n}\sum_{j=1}^{n}\nabla r(w_n(t),z_1)^\top\nabla f(w_n(t),x_j)\cdot r(w_n(t),z_j)\\ \vdots\\ -\frac{1}{n}\sum_{j=1}^{n}\nabla r(w_n(t),z_n)^\top\nabla f(w_n(t),x_j)\cdot r(w_n(t),z_j)\end{bmatrix}
= -\frac{1}{n}P_n(t)\,\vec{r}_n(t).
\end{aligned}
\]

Here the third equality follows from the evolution of $w_n(t)$:

\[
\begin{aligned}
\frac{\mathrm{d}w_n(t)}{\mathrm{d}t}
&= -\nabla\bar{\ell}(w_n(t),S_n) \\
&= -\frac{1}{n}\sum_{i=1}^{n}\nabla f(w_n(t),x_i)\,\frac{\mathrm{d}\ell(f(w_n(t),x_i),y_i)}{\mathrm{d}f(w_n(t),x_i)} \\
&= -\frac{1}{n}\sum_{i=1}^{n}\nabla f(w_n(t),x_i)\,r(w_n(t),z_i).
\end{aligned}
\]

Note that

\[
\begin{aligned}
\nabla r(w,z_i),\ \nabla f(w,x_i) &\in \mathcal{W}\times\mathcal{Y}, \\
P_n(t)=\left[\nabla r(w_n(t),z_i)^\top\nabla f(w_n(t),x_j)\right]_{i,j\in[n]} &\in \mathcal{Y}^n\times\mathcal{Y}^n.
\end{aligned}
\]
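As a concrete instance of the residual dynamics $\mathrm{d}\vec{r}_n/\mathrm{d}t = -\frac{1}{n}P_n\vec{r}_n$, the sketch below uses a linear model $f(w,x)=w^\top x$ with the squared loss $\frac{1}{2}(f-y)^2$. Both are assumptions made only for illustration. In this case $r = f - y$, $\nabla r = \nabla f = x$, and $P_n = XX^\top$. Because the residual is linear in $w$, a finite-difference quotient along the gradient-flow direction matches the right-hand side exactly, up to rounding.

```python
import numpy as np

# Hypothetical data: n = 6 samples, d = 4 parameters.
rng = np.random.default_rng(0)
n, d = 6, 4
X = rng.normal(size=(n, d))
y = rng.normal(size=n)
w = rng.normal(size=d)

def residual_vec(w):
    """Scaled residual vector (1/sqrt(n)) [f(w, x_i) - y_i]_i."""
    return (X @ w - y) / np.sqrt(n)

# P_ij = grad r(w, z_i) . grad f(w, x_j) = x_i . x_j for this model.
P = X @ X.T

# Gradient of the average loss, i.e. minus the gradient-flow velocity of w.
grad = X.T @ (X @ w - y) / n

# Finite-difference time derivative of the residual along the flow.
h = 1e-3
lhs = (residual_vec(w - h * grad) - residual_vec(w)) / h
rhs = -P @ residual_vec(w) / n
```

For a nonlinear network the two sides would agree only in the limit of a small step `h`, and `P` would change along the trajectory.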

A.5 Proof of Lemma 10

We have the following decomposition of the gradient covariance $\hat{\Sigma}_n(t)$:

\[
\begin{aligned}
\hat{\Sigma}_n(t)
&= \frac{1}{n}\sum_{i=1}^{n}\left(\nabla\ell_i-\nabla\bar{\ell}\right)\left(\nabla\ell_i-\nabla\bar{\ell}\right)^\top \\
&= \frac{1}{n}\sum_{i=1}^{n}\nabla\ell_i\nabla\ell_i^\top-\nabla\bar{\ell}\,\nabla\bar{\ell}^\top,
\end{aligned}
\]

where

\[
\begin{aligned}
\nabla\ell_i &= \nabla f(w_n(t),x_i)\,r_i(t), \\
\nabla\bar{\ell} &= \frac{1}{n}\sum_{i=1}^{n}\nabla f(w_n(t),x_i)\,r_i(t).
\end{aligned}
\]

Hence, we have

\[
\begin{aligned}
\operatorname{tr}\hat{\Sigma}_n(t)
&= \vec{r}_n(t)^\top M_n(t)\,\vec{r}_n(t)-\frac{1}{n}\vec{r}_n(t)^\top H_n(t)\,\vec{r}_n(t) \\
&= \vec{r}_n(t)^\top\left(M_n(t)-\frac{H_n(t)}{n}\right)\vec{r}_n(t),
\end{aligned}
\]

where

\[
\begin{aligned}
M_n(t) &= \operatorname{diag}\left(\nabla f(w,x_1)^\top\nabla f(w,x_1),\dots,\nabla f(w,x_n)^\top\nabla f(w,x_n)\right)\bigg|_{w_n(t)}, \\
H_n(t) &= \left[\nabla f(w_n(t),x_i)^\top\nabla f(w_n(t),x_j)\right]_{i,j\in[n]}.
\end{aligned}
\]

A.6 Proof of Theorem 13

By Eq. 13 and Lemma 10,

$$
\operatorname{tr} \hat{\Sigma}_{n}(t) = \vec{r}_{n}(0)^{\top} \Omega_{n}(t)^{\top} \left( M_{n}(t) - \frac{H_{n}(t)}{n} \right) \Omega_{n}(t)\, \vec{r}_{n}(0). \tag{20}
$$

Combining this with the solution for $\bar{\Delta}_{n}(t)$ in Eq. 7, we have

$$
\bar{\Delta}_{n}(t) = \vec{r}_{n}(0)^{\top} \left( \frac{\int_{0}^{t} \Omega_{n}(s)^{\top} \left( M_{n}(s) - \frac{H_{n}(s)}{n} \right) \Omega_{n}(s) \exp\left( -\int_{s}^{t} \bar{c}_{n}(u)\, \mathrm{d}u \right) \mathrm{d}s}{n-1} \right) \vec{r}_{n}(0).
$$

Hence, we have

$$
\bar{\Delta}_{n}(t) = \vec{r}_{n}(0)^{\top} K_{n}(0,t)\, \vec{r}_{n}(0),
$$

where

$$
K_{n}(0,t) = \frac{\int_{0}^{t} \Omega_{n}(s)^{\top} \left( M_{n}(s) - \frac{H_{n}(s)}{n} \right) \Omega_{n}(s) \exp\left( -\int_{s}^{t} \bar{c}_{n}(u)\, \mathrm{d}u \right) \mathrm{d}s}{n-1}.
$$

Now we prove that $K_{n}(0,t)$ is positive semi-definite (PSD) by showing that $M_{n}(s) - H_{n}(s)/n$ is PSD. For any vector $r \in \mathcal{Y}^{n}$, write $r = [r_{1}, \dots, r_{n}]$, where $r_{i} \in \mathcal{Y}$ for all $i \in [n]$. Then $r^{\top}(M_{n} - H_{n}/n)r$ is, up to a positive normalization factor, the trace of the covariance of the set of vectors $\{\nabla f(w,x_{i}) r_{i}\}_{i \in [n]}$ and is hence non-negative, which implies that $M_{n}(s) - H_{n}(s)/n$ is PSD for all $s$. Congruence by $\Omega_{n}(s)$ preserves positive semi-definiteness and the exponential weight is positive, so the integrand is PSD for every $s$. Hence, the matrix $K_{n}(0,t)$ is PSD as an integral of PSD matrices.
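The identity behind this PSD argument can be checked numerically. The sketch below is illustrative only and assumes scalar outputs ($\mathcal{Y} = \mathbb{R}$), so that each $\nabla f(w,x_{i})$ is a vector stored as a row of a hypothetical Jacobian `J`; the random values stand in for real gradients and residuals. It compares $r^{\top}(M_{n} - H_{n}/n)r$ with the summed squared deviation of the vectors $\nabla f(w,x_{i}) r_{i}$ from their mean, which is $n$ times the trace of their $1/n$-normalized covariance.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 6, 4                      # hypothetical sample count and parameter count
J = rng.normal(size=(n, p))      # row i stands in for grad f(w, x_i) (scalar output assumed)
r = rng.normal(size=n)           # stand-in residual vector

H = J @ J.T                      # Gram matrix H_n: inner products of all gradient pairs
M = np.diag(np.diag(H))          # M_n keeps only the diagonal of H_n
B = M - H / n

# Left-hand side: the quadratic form r^T (M_n - H_n / n) r.
lhs = r @ B @ r

# Right-hand side: summed squared deviation of g_i = grad f(w, x_i) * r_i from their mean.
G = J * r[:, None]
rhs = np.sum((G - G.mean(axis=0)) ** 2)

# Smallest eigenvalue of the symmetric matrix M_n - H_n / n (non-negative up to round-off).
min_eig = np.linalg.eigvalsh(B).min()

print(lhs, rhs, min_eig)
```

Because the right-hand side is a sum of squares, the quadratic form is non-negative for every `r`, which is exactly the PSD property used above.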

  • Proof of Lemma 14.

    Let $\sigma_{\max}(t)$ be the largest singular value of the propagator $\Omega_{n}(t)$, and let $u(t)$ and $v(t)$ be the corresponding left and right singular vectors, respectively, i.e.,

    $$
    u(t)^{\top} \Omega(t)\, v(t) = \sigma_{\max}(t),
    $$

    where $u(t)^{\top} u(t) = v(t)^{\top} v(t) = 1$. We first bound $\sigma_{\max}(t)$ through its evolution.

    $$
    \begin{aligned}
    \frac{\mathrm{d} \sigma^{2}_{\max}(t)}{\mathrm{d} t} &= v(t)^{\top} \frac{\mathrm{d} \left( \Omega^{\top}(t) \Omega(t) \right)}{\mathrm{d} t} v(t) + 2\, \frac{\mathrm{d} v(t)^{\top}}{\mathrm{d} t}\, \Omega^{\top}(t) \Omega(t)\, v(t) \\
    &= v(t)^{\top} \frac{\mathrm{d} \left( \Omega^{\top}(t) \Omega(t) \right)}{\mathrm{d} t} v(t) \\
    &= -\frac{1}{n}\, v(t)^{\top} \Omega(t)^{\top} \left( P_{n}(t) + P_{n}(t)^{\top} \right) \Omega(t)\, v(t) \\
    &= -\frac{\sigma^{2}_{\max}(t)}{n}\, u(t)^{\top} \left( P_{n}(t) + P_{n}(t)^{\top} \right) u(t) \\
    &\leq -\frac{2 \sigma^{2}_{\max}(t) \lambda_{\min}(t)}{n}.
    \end{aligned}
    $$

    In the second equality, since $\Omega(t) v(t) = \sigma_{\max}(t) u(t)$ and $\Omega(t)^{\top} u(t) = \sigma_{\max}(t) v(t)$, we have

    $$
    \frac{\mathrm{d} v(t)^{\top}}{\mathrm{d} t}\, \Omega^{\top}(t) \Omega(t)\, v(t) = \sigma^{2}_{\max}(t)\, \frac{\mathrm{d} v(t)^{\top}}{\mathrm{d} t}\, v(t) = \frac{1}{2} \sigma^{2}_{\max}(t)\, \frac{\mathrm{d} \left( v(t)^{\top} v(t) \right)}{\mathrm{d} t} = 0.
    $$

    The third equality follows from the evolution of the propagator, $\mathrm{d}\Omega(t)/\mathrm{d}t = -P_{n}(t)\Omega(t)/n$. The final inequality holds because $u(t)$ is a unit vector and $\lambda_{\min}(t)$ is the smallest eigenvalue of $(P_{n}(t) + P_{n}(t)^{\top})/2$. Note that $\Omega(0) = I$, which implies that $\sigma_{\max}(0) = 1$; hence, integrating the differential inequality, we have

    $$
    \sigma^{2}_{\max}(t) \leq \exp\left( -2 \int_{0}^{t} \frac{\lambda_{\min}(s)}{n}\, \mathrm{d}s \right).
    $$

    Let $A(t) = \frac{\Omega_{n}(t)^{\top} \left( M_{n}(t) - H_{n}(t)/n \right) \Omega_{n}(t)}{n-1}$ and $\tilde{c}(s,t) = \exp\left( -\int_{s}^{t} \bar{c}(u)\, \mathrm{d}u \right)$. Then $\|A(t)\|_{2} \leq \sigma^{2}_{\max}(t)\, m(t) \leq \omega(t)\, m(t)$. Hence, for $t_{1} \leq t_{2}$, we have

    $$
    \begin{aligned}
    \left\| K_{n}(0,t_{2}) - K_{n}(0,t_{1}) \right\|_{2} &= \left\| \int_{0}^{t_{1}} A(s) \left( \tilde{c}(s,t_{2}) - \tilde{c}(s,t_{1}) \right) \mathrm{d}s + \int_{t_{1}}^{t_{2}} A(s)\, \tilde{c}(s,t_{2})\, \mathrm{d}s \right\|_{2} \\
    &\leq \int_{0}^{t_{1}} \left\| A(s) \left( \tilde{c}(s,t_{2}) - \tilde{c}(s,t_{1}) \right) \right\|_{2} \mathrm{d}s + \int_{t_{1}}^{t_{2}} \left\| A(s)\, \tilde{c}(s,t_{2}) \right\|_{2} \mathrm{d}s \\
    &\leq \int_{0}^{t_{1}} \omega(s)\, m(s) \left( \tilde{c}(s,t_{1}) - \tilde{c}(s,t_{2}) \right) \mathrm{d}s + \int_{t_{1}}^{t_{2}} \omega(s)\, m(s)\, \tilde{c}(s,t_{2})\, \mathrm{d}s.
    \end{aligned}
    $$

    For the first term, $\left| \omega(s)\, m(s) \left( \tilde{c}(s,t_{2}) - \tilde{c}(s,t_{1}) \right) \right| \leq 2\, \omega(s)\, m(s)$. By the integrability of $\omega(s)\, m(s)$ and the dominated convergence theorem (DCT),

    $$
    \lim_{t_{1},t_{2} \to \infty} \int_{0}^{t_{1}} \omega(s)\, m(s) \left( \tilde{c}(s,t_{1}) - \tilde{c}(s,t_{2}) \right) \mathrm{d}s = \int_{0}^{\infty} \lim_{t_{1},t_{2} \to \infty} \omega(s)\, m(s) \left( \tilde{c}(s,t_{1}) - \tilde{c}(s,t_{2}) \right) 1_{[0,t_{1}]}(s)\, \mathrm{d}s = 0.
    $$

    Here the existence of $\lim_{t \to \infty} \tilde{c}(s,t)$, which is guaranteed by $\bar{c}_{n}(t) \geq 0$, together with the uniform boundedness of $\omega(t)\, m(t)$, implies that the limit of the integrand is 0 in the second equality. For the second term,

    $$
    \int_{t_{1}}^{t_{2}} \omega(s)\, m(s)\, \tilde{c}(s,t_{2})\, \mathrm{d}s \leq \int_{t_{1}}^{t_{2}} \omega(s)\, m(s)\, \mathrm{d}s \to 0
    $$

    as $t_{1}, t_{2} \to \infty$ by condition (1). Hence, $\left\| K_{n}(0,t_{2}) - K_{n}(0,t_{1}) \right\|_{2} \to 0$ as $t_{1}, t_{2} \to \infty$, which, by the Cauchy criterion, shows that $\lim_{t \to \infty} K_{n}(0,t)$ exists in the matrix 2-norm.
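The singular-value bound $\sigma^{2}_{\max}(t) \leq \exp\left( -2 \int_{0}^{t} \lambda_{\min}(s)/n\, \mathrm{d}s \right)$ used in this proof can be illustrated numerically. The sketch below is not taken from the paper: the generator `P(t)` is a made-up example (a time-varying PSD symmetric part plus a skew-symmetric part), the propagator equation $\mathrm{d}\Omega/\mathrm{d}t = -P_{n}(t)\Omega/n$ is integrated with a standard fourth-order Runge-Kutta scheme, and $\lambda_{\min}(t)$ is taken to be the smallest eigenvalue of $(P_{n}(t) + P_{n}(t)^{\top})/2$.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
A0 = rng.normal(size=(n, n))
A1 = rng.normal(size=(n, n))

def P(t):
    # Hypothetical generator: time-varying PSD symmetric part plus a skew-symmetric part.
    return (1.0 + 0.5 * np.cos(t)) * (A0 @ A0.T) + np.sin(t) * (A1 - A1.T)

def lam_min(t):
    # Smallest eigenvalue of the symmetric part (P + P^T) / 2.
    Pt = P(t)
    return np.linalg.eigvalsh((Pt + Pt.T) / 2.0)[0]

def rhs(t, Omega):
    # Right-hand side of the propagator ODE dOmega/dt = -P(t) Omega / n.
    return -P(t) @ Omega / n

def propagate(T, steps):
    """Return Omega(T) (RK4) and the integral of lam_min(s)/n over [0, T] (Simpson's rule)."""
    dt = T / steps
    Omega = np.eye(n)            # Omega(0) = I
    integral = 0.0
    for k in range(steps):
        t = k * dt
        k1 = rhs(t, Omega)
        k2 = rhs(t + dt / 2, Omega + dt / 2 * k1)
        k3 = rhs(t + dt / 2, Omega + dt / 2 * k2)
        k4 = rhs(t + dt, Omega + dt * k3)
        Omega = Omega + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        integral += dt / 6 * (lam_min(t) + 4 * lam_min(t + dt / 2) + lam_min(t + dt)) / n
    return Omega, integral

Omega_T, integral = propagate(T=2.0, steps=2000)
sigma_max_sq = np.linalg.norm(Omega_T, 2) ** 2   # squared largest singular value
bound = np.exp(-2.0 * integral)

print(sigma_max_sq, bound)
```

The skew-symmetric part of `P(t)` cancels in $P_{n}(t) + P_{n}(t)^{\top}$, so only the symmetric part enters the bound, which matches the role of $\lambda_{\min}(t)$ in the derivation.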

Remark 24.

Sometimes the effective Gram matrix computed from the propagator derived from $P_{n}(t)$ does not converge. In this case, we can construct a perturbed version of $P_{n}(t)$ in which the smallest eigenvalue of $(P_{n}(t) + P_{n}(t)^{\top})/2$ is controlled; this guarantees the convergence of $\lim_{t \to \infty} K_{n}(0,t)$ while preserving the trajectory of $\vec{r}_{n}(t)$ given $\vec{r}_{n}(0)$. For example, in Section 3.5 we construct $P_{n}^{\varepsilon}(t)$ as a perturbed version of $P_{n}(t)$.

A.7 Calculations for the regression example

The gradients of the averaged losses $\bar{\ell}(w,S_{n})$ and $\bar{\ell}(w,S_{n}^{-i})$ are

$$
\begin{aligned}
\nabla \bar{\ell}(w,S_{n}) &= \frac{1}{2} \left( w^{\top} x_{1} - y_{1} \right) x_{1} + \frac{1}{2} \left( w^{\top} x_{2} - y_{2} \right) x_{2}, \\
\nabla \bar{\ell}(w,S_{n}^{-1}) &= \frac{n-2}{2(n-1)} \left( w^{\top} x_{1} - y_{1} \right) x_{1} + \frac{n}{2(n-1)} \left( w^{\top} x_{2} - y_{2} \right) x_{2}, \\
\nabla \bar{\ell}(w,S_{n}^{-2}) &= \frac{n}{2(n-1)} \left( w^{\top} x_{1} - y_{1} \right) x_{1} + \frac{n-2}{2(n-1)} \left( w^{\top} x_{2} - y_{2} \right) x_{2}.
\end{aligned}
$$
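These coefficients can be reproduced with a small numerical check. The sketch below rests on assumptions that are not restated in this appendix: the dataset $S_{n}$ is taken to consist of $n/2$ copies of $z_{1} = (x_{1}, y_{1})$ and $n/2$ copies of $z_{2} = (x_{2}, y_{2})$, the per-sample loss is taken to be $\ell(w,z) = \frac{1}{2}(w^{\top}x - y)^{2}$, and the numerical values of `x1`, `y1`, `x2`, `y2`, and `w` are arbitrary placeholders.

```python
import numpy as np

def avg_grad(w, X, Y):
    """Gradient of the averaged loss (1/|S|) * sum_i 0.5 * (w^T x_i - y_i)^2."""
    residuals = X @ w - Y
    return X.T @ residuals / len(Y)

n = 8                                        # hypothetical (even) dataset size
x1, y1 = np.array([1.0, 0.0]), 0.7           # placeholder sample z_1
x2, y2 = np.array([0.0, 1.0]), -0.3          # placeholder sample z_2
X = np.array([x1] * (n // 2) + [x2] * (n // 2))
Y = np.array([y1] * (n // 2) + [y2] * (n // 2))
w = np.array([0.2, 0.5])                     # arbitrary weight vector

# Per-sample gradients (w^T x - y) x for the two distinct samples.
g1 = (w @ x1 - y1) * x1
g2 = (w @ x2 - y2) * x2

full = avg_grad(w, X, Y)                     # gradient on S_n
loo_1 = avg_grad(w, X[1:], Y[1:])            # one copy of z_1 removed
loo_2 = avg_grad(w, X[:-1], Y[:-1])          # one copy of z_2 removed

c_small = (n - 2) / (2 * (n - 1))            # weight of the sample that lost a copy
c_large = n / (2 * (n - 1))                  # weight of the other sample

print(np.allclose(full, 0.5 * g1 + 0.5 * g2))
print(np.allclose(loo_1, c_small * g1 + c_large * g2))
print(np.allclose(loo_2, c_large * g1 + c_small * g2))
```

Removing one copy of $z_{i}$ leaves $n/2 - 1$ copies among $n - 1$ samples, which is where the weight $(n-2)/(2(n-1))$ comes from; the other sample keeps $n/2$ copies and gets weight $n/(2(n-1))$.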

The averaged contraction factor is

$$
\begin{aligned}
\bar{c}_{n}(t) &= \frac{\frac{1}{n} \sum_{i=1}^{n} \nabla \ell(w,z_{i}) \cdot \nabla \bar{\ell}(w,S_{n}^{-i}) \Big|_{w_{n}(t)}^{w_{n}^{-i}(t)}}{\bar{\Delta}_{n}(t)} \\
&= \frac{\frac{1}{2} \left( \nabla \ell(w,z_{1}) \cdot \nabla \bar{\ell}(w,S_{n}^{-1}) \Big|_{w_{n}(t)}^{w_{n}^{-1}(t)} + \nabla \ell(w,z_{2}) \cdot \nabla \bar{\ell}(w,S_{n}^{-2}) \Big|_{w_{n}(t)}^{w_{n}^{-2}(t)} \right)}{\frac{1}{2} \left( \frac{1}{2} \left( w^{\top} x_{1} - y_{1} \right)^{2} \Big|_{w_{n}(t)}^{w_{n}^{-1}(t)} + \frac{1}{2} \left( w^{\top} x_{2} - y_{2} \right)^{2} \Big|_{w_{n}(t)}^{w_{n}^{-2}(t)} \right)} \\
&= \frac{\frac{1}{2} \left( \frac{n-2}{2(n-1)} \left( w^{\top} x_{1} - y_{1} \right)^{2} \Big|_{w_{n}(t)}^{w_{n}^{-1}(t)} + \frac{n-2}{2(n-1)} \left( w^{\top} x_{2} - y_{2} \right)^{2} \Big|_{w_{n}(t)}^{w_{n}^{-2}(t)} \right)}{\frac{1}{2} \left( \frac{1}{2} \left( w^{\top} x_{1} - y_{1} \right)^{2} \Big|_{w_{n}(t)}^{w_{n}^{-1}(t)} + \frac{1}{2} \left( w^{\top} x_{2} - y_{2} \right)^{2} \Big|_{w_{n}(t)}^{w_{n}^{-2}(t)} \right)} = \frac{n-2}{2(n-1)}.
\end{aligned}
$$
italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT ) end_ARG start_ARG divide start_ARG 1 end_ARG start_ARG 2 end_ARG ( divide start_ARG 1 end_ARG start_ARG 2 end_ARG ( italic_w start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_x start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT - italic_y start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT | start_POSTSUPERSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT + divide start_ARG 1 end_ARG start_ARG 2 end_ARG ( italic_w start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT italic_x start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT - italic_y start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ) start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT | start_POSTSUPERSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT start_POSTSUPERSCRIPT - 2 end_POSTSUPERSCRIPT ( italic_t ) end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_w start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT ( italic_t ) end_POSTSUBSCRIPT ) end_ARG = divide start_ARG italic_n - 2 end_ARG start_ARG 2 ( italic_n - 1 ) end_ARG .

The propagator $\Omega_{n}^{\varepsilon}(t)$ of the evolution $\mathrm{d}\vec{r}_{n}(t)/\mathrm{d}t=-P_{n}^{\varepsilon}(t)\vec{r}_{n}(t)/n$ is

\begin{align*}
\Omega_{n}^{\varepsilon}(t)=\exp\left(-\frac{\int_{0}^{t}P_{n}^{\varepsilon}(s)\,\mathrm{d}s}{n}\right)=U\exp\left(-\frac{\int_{0}^{t}\Lambda^{\varepsilon}(s)\,\mathrm{d}s}{n}\right)U^{\top}.
\end{align*}

Hence, the effective Gram matrix $K_{n}(0,t)$ can be calculated easily as

\begin{align*}
K_{n}(0,t) &= \frac{\int_{0}^{t}U\exp\left(-\frac{2\int_{0}^{s}\Lambda^{\varepsilon}(u)\,\mathrm{d}u}{n}\right)\left(I-\frac{\Lambda^{\varepsilon}(s)}{n}\right)U^{\top}\exp\left(-(t-s)\bar{c}\right)\mathrm{d}s}{n-1}\\
&= U\Lambda^{K}(t)U^{\top},
\end{align*}

where $\Lambda^{K}(t)=[\lambda^{K}_{1}(t),\dots,\lambda^{K}_{n}(t)]$, and we have

\begin{align*}
\lambda^{K}_{1}(t)=\lambda^{K}_{2}(t)=\lambda(t) &:= \frac{(1-\bar{c})^{-1}/2}{n-1}\left(\exp(-\bar{c}t)-\exp(-t)\right),\\
\lambda^{K}_{3}(t)=\dots=\lambda^{K}_{n}(t)=\lambda'(t) &:= \frac{\int_{0}^{t}\exp\left(-\int_{0}^{s}\varepsilon(u)\,\mathrm{d}u\right)\cdot\exp\left(-(t-s)\bar{c}\right)\cdot\left(1-\frac{\varepsilon(s)}{2}\right)\mathrm{d}s}{n-1}.
\end{align*}

For $\varepsilon(t)=\bar{\varepsilon}\left(1_{[0,1]}(t)+1_{[1,\infty]}(t)/t^{2}\right)$, we have $\exp\left(-\int_{0}^{s}\varepsilon(u)\,\mathrm{d}u\right)\in[\exp(-2\bar{\varepsilon}),1]$ and $1-\varepsilon(s)/2\in[1-\bar{\varepsilon}/2,1]$. Hence, for $\bar{\varepsilon}$ small enough, $\frac{1-\exp(-\bar{c}t)}{2\bar{c}(n-1)}\leq\lambda'(t)\leq\frac{1-\exp(-\bar{c}t)}{\bar{c}(n-1)}$. Therefore, $\lambda(t)=\Theta(\exp(-\bar{c}t))$ and $\lambda'(t)=\Theta(1)$.
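As a numerical sanity check of these two scalings, the sketch below compares a midpoint-rule evaluation of the integral for the first two eigenvalues with the closed form $\frac{(1-\bar{c})^{-1}/2}{n-1}\left(\exp(-\bar{c}t)-\exp(-t)\right)$, and checks the two-sided bound on $\lambda'(t)$. The values `n = 100` and `eps_bar = 0.01`, the helper `midpoint`, and the integrand used for $\lambda(t)$ (which assumes the first two eigenvalues of $\Lambda^{\varepsilon}$ equal $n/2$) are illustrative assumptions, not part of the derivation.

```python
import math

def midpoint(f, a, b, steps=4000):
    """Midpoint-rule approximation of the integral of f over [a, b]."""
    h = (b - a) / steps
    return h * sum(f(a + (j + 0.5) * h) for j in range(steps))

n = 100                              # illustrative sample size
c_bar = (n - 2) / (2 * (n - 1))      # contraction factor of the toy example
eps_bar = 0.01                       # illustrative value of bar-epsilon

def eps(t):
    """eps(t) = eps_bar on [0, 1] and eps_bar / t^2 afterwards."""
    return eps_bar if t <= 1.0 else eps_bar / t**2

def eps_integral(s):
    """Closed form of the integral of eps over [0, s]."""
    return eps_bar * s if s <= 1.0 else eps_bar * (2.0 - 1.0 / s)

def lam_closed(t):
    """Closed form of lambda(t)."""
    return 0.5 / ((1 - c_bar) * (n - 1)) * (math.exp(-c_bar * t) - math.exp(-t))

def lam_numeric(t):
    """lambda(t) by quadrature, assuming the first two eigenvalues of Lambda are n/2."""
    f = lambda s: math.exp(-s) * 0.5 * math.exp(-(t - s) * c_bar)
    return midpoint(f, 0.0, t) / (n - 1)

def lam_prime(t):
    """lambda'(t) by quadrature of its defining integral."""
    f = lambda s: (math.exp(-eps_integral(s)) * math.exp(-(t - s) * c_bar)
                   * (1 - eps(s) / 2))
    return midpoint(f, 0.0, t) / (n - 1)
```

Under these assumptions, `lam_numeric(t)` and `lam_closed(t)` should agree up to quadrature error, and `lam_prime(t)` should lie between the two bounds stated above.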

Here we also calculate the solutions $w_{n}(t)$ and $w_{n}^{-i}(t)$, although they are not required in the derivation of the effective Gram matrix.

\begin{align*}
w_{n}(t) &= \int_{0}^{t}\exp\left(-(t-s)\left(\frac{1}{2}x_{1}x_{1}^{\top}+\frac{1}{2}x_{2}x_{2}^{\top}\right)\right)\left(\frac{1}{2}y_{1}x_{1}+\frac{1}{2}y_{2}x_{2}\right)\mathrm{d}s\\
&= \left(1-\exp(-t/2)\right)\cdot(y_{1}x_{1})+\left(1-\exp(-t/2)\right)\cdot(y_{2}x_{2}).
\end{align*}

Similarly, we have

\begin{align*}
w_{n}^{-i}(t) &= \left(1-\exp\left(-\frac{n-2}{2(n-1)}t\right)\right)\cdot(y_{k}x_{k})+\left(1-\exp\left(-\frac{n}{2(n-1)}t\right)\right)\cdot(y_{l}x_{l}),
\end{align*}

where $k=\lceil 2i/n\rceil$ and $l\in\{1,2\}\setminus\{k\}$. We can see that when training on the dataset $S_{n}^{-i}$, progress along the direction $x_{\lceil 2i/n\rceil}$ is slightly slower than along the other direction, which introduces the non-zero averaged loss difference $\bar{\Delta}_{n}(t)$ during training.
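The two closed-form trajectories can be checked with a small simulation. The sketch below is a minimal illustration under assumptions that are not spelled out in this part of the text: the inputs are taken to be the standard basis vectors $x_1=e_1$, $x_2=e_2$, the dataset holds $n/2$ copies of each sample, the loss is $\frac{1}{2}(w^{\top}x-y)^2$, and training starts from $w(0)=0$. The function `gradient_flow`, the step size and the label values are hypothetical choices.

```python
import math

def gradient_flow(counts, ys, t_end, dt=1e-3):
    """Euler-integrate dw/dt = -grad of the mean squared loss, starting at w(0) = 0.

    counts[j] is the number of copies of the sample (e_j, ys[j]), with e_j the
    j-th standard basis vector, so coordinate j of w evolves independently.
    """
    total = sum(counts)
    w = [0.0, 0.0]
    for _ in range(int(round(t_end / dt))):
        # coordinate j of the mean gradient is counts[j] / total * (w[j] - ys[j])
        w = [w[j] - dt * counts[j] / total * (w[j] - ys[j]) for j in range(2)]
    return w

n = 10                  # illustrative dataset size (even)
ys = [1.0, -2.0]        # illustrative labels y_1, y_2
t_end = 3.0

# Full dataset: n/2 copies of each sample.
w_full = gradient_flow([n // 2, n // 2], ys, t_end)

# Leave-one-out dataset with a sample of type k = 1 removed.
w_loo = gradient_flow([n // 2 - 1, n // 2], ys, t_end)
```

With these choices, `w_full` should approach $y_j(1-\exp(-t/2))$ in each coordinate, while the first coordinate of `w_loo` follows the slower rate $\frac{n-2}{2(n-1)}$ and the second the faster rate $\frac{n}{2(n-1)}$.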

A.8 Approximation of the contraction factor $\bar{c}_{n}$

In this section, we analyze the averaged version of the batch-wise contraction factor introduced in Section B.1. By the fundamental theorem of calculus, the numerator and denominator of the averaged contraction factor $\bar{c}_{n}$ can be written as follows.

\begin{align*}
&\nabla\bar{\ell}(w,S_{(m)})\cdot\nabla\bar{\ell}(w,S_{n}^{-(m)})\big|^{w_{n}^{-(m)}(t)}_{w_{n}(t)}\\
&\quad= \int_{0}^{1}\left(\nabla\bar{\ell}(w,S_{n}^{-(m)})\nabla^{2}\bar{\ell}(w,S_{(m)})+\nabla\bar{\ell}(w,S_{(m)})\nabla^{2}\bar{\ell}(w,S_{n}^{-(m)})\right)\Big|_{w=h(u)}\cdot\left(w_{n}^{-(m)}(t)-w_{n}(t)\right)\mathrm{d}u,\\
&\bar{\Delta}_{n}^{-(m)}(t)=\bar{\ell}(w,S_{(m)})\big|^{w_{n}^{-(m)}(t)}_{w_{n}(t)}=\int_{0}^{1}\nabla\bar{\ell}(w,S_{(m)})\big|_{w=h(u)}\cdot\left(w_{n}^{-(m)}(t)-w_{n}(t)\right)\mathrm{d}u,
\end{align*}

where $h(u)=w_{n}(t)+u\left(w_{n}^{-(m)}(t)-w_{n}(t)\right)$, $u\in[0,1]$, is the line segment connecting $w_{n}(t)$ and $w_{n}^{-(m)}(t)$.

We approximate $\nabla\bar{\ell}(w,S_{(m)})$ and $\nabla\bar{\ell}(w,S_{n}^{-(m)})$ by $\nabla\bar{\ell}(w,S_{n})$, and $\nabla^{2}\bar{\ell}(w,S_{(m)})$ and $\nabla^{2}\bar{\ell}(w,S_{n}^{-(m)})$ by $\nabla^{2}\bar{\ell}(w,S_{n})$. We further approximate each integral by the value of its integrand at $u=0$, where $h(u)=w_{n}(t)$. This gives the following approximation of $\bar{c}_{n}$.

\begin{align*}
\bar{c}_{n}(t)\approx\frac{\nabla\bar{\ell}(w_{n}(t),S_{n})^{\top}\nabla^{2}\bar{\ell}(w_{n}(t),S_{n})\,\mathbb{E}_{(m)}\left[w_{n}^{-(m)}(t)-w_{n}(t)\right]}{\nabla\bar{\ell}(w_{n}(t),S_{n})^{\top}\mathbb{E}_{(m)}\left[w_{n}^{-(m)}(t)-w_{n}(t)\right]}.
\end{align*}
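The approximation is a ratio of two inner products, which the following sketch evaluates for a small linear regression problem. It is only an illustration: the function name `approx_contraction`, the random data, and the use of the gradient itself as a stand-in for the averaged weight difference $\mathbb{E}_{(m)}[w_{n}^{-(m)}(t)-w_{n}(t)]$ are assumptions made here, not choices taken from the experiments. For a large network one would presumably use a Hessian-vector product instead of forming the Hessian explicitly.

```python
import numpy as np

def approx_contraction(grad, hess, mean_delta):
    """Ratio g^T H d / g^T d.

    grad:       gradient of the full-data loss at w_n(t), shape (p,)
    hess:       Hessian of the full-data loss at w_n(t), shape (p, p)
    mean_delta: averaged weight difference E_(m)[w_n^{-(m)}(t) - w_n(t)], shape (p,)
    """
    num = grad @ hess @ mean_delta
    den = grad @ mean_delta
    return num / den

# Illustrative linear regression with loss 0.5 * (w^T x - y)^2 averaged over samples.
rng = np.random.default_rng(0)
X = rng.standard_normal((50, 3))
y = rng.standard_normal(50)
w = rng.standard_normal(3)

grad = X.T @ (X @ w - y) / len(y)   # full-data gradient
hess = X.T @ X / len(y)             # full-data Hessian

# Stand-in for the averaged weight difference (hypothetical choice).
c_est = approx_contraction(grad, hess, grad)
```

With this stand-in the ratio is a Rayleigh quotient of the Hessian, so `c_est` lies between its smallest and largest eigenvalues.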

Appendix B Experimental Details

Dataset

We use the MNIST and CIFAR10 datasets for the experiments in Section 4. We run experiments on 10-class and 2-class (classes 0 and 3) problems on both MNIST and CIFAR10 (denoted MNIST-10, MNIST-2, CIFAR-10 and CIFAR-2, respectively), and on a 5-class problem (classes 0, 1, 2, 3, 4) on MNIST (denoted MNIST0-5). For all experiments, we choose $n/m=10$.

Architectures

We use LeNet-5 (a network with two convolutional layers of 20 and 50 channels respectively, both with 5×5 kernels, and a fully-connected layer with 500 hidden neurons), LeNet-5-GS (the original LeNet-5 with an additional gray-scale layer), WRN-4-4 (a wide residual network with 4 layers and a widening factor of 4, in which all batch normalization layers are replaced with layer normalization layers; Ba et al. (2016)) and FC (a two-layer fully-connected network) for training. We use a two-layer fully-connected network for synthetic data generation.

Synthetic data generation

We created three types of synthetic datasets. 1) Datasets Syn-$(a,b)$ are created by modifying the labeling regime of the MNIST dataset.

  • Approximate the second moment matrix of the input $\mathbb{E}[xx^{\top}]$ by its empirical version $X^{\top}X/n$, calculated from 10000 samples of the original MNIST training set.

  • Compute the eigenvalue decomposition of the empirical second moment matrix, $X^{\top}X/n=Q\operatorname{diag}(L)Q^{\top}$, where $L$ denotes the eigenspectrum sorted from largest to smallest.

  • Project the inputs of the MNIST training set (except for the samples used to calculate the empirical second moment matrix) and of the validation set onto $Q_{a:b}$. Whiten each pixel of the projection.

  • Relabel the original input by a teacher network with random weights applied to the projected input.
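The four steps above can be sketched as follows. This is a minimal illustration, not the released code: the function `make_syn`, the ReLU activation and hidden width of the teacher, the reading of "whiten" as per-coordinate standardization, the 1-indexed inclusive reading of $Q_{a:b}$, and the random arrays standing in for MNIST are all assumptions.

```python
import numpy as np

def make_syn(X_stat, X_rest, a, b, hidden=32, n_classes=10, seed=0):
    """Sketch of the Syn-(a, b) construction.

    X_stat: samples used only to estimate the second moment matrix, shape (n, d)
    X_rest: remaining inputs that receive new labels, shape (N, d)
    Returns the whitened projection and the teacher labels for X_rest.
    """
    rng = np.random.default_rng(seed)

    # Step 1: empirical second moment matrix X^T X / n.
    M = X_stat.T @ X_stat / X_stat.shape[0]

    # Step 2: eigendecomposition, eigenvalues sorted from largest to smallest.
    L, Q = np.linalg.eigh(M)
    order = np.argsort(L)[::-1]
    L, Q = L[order], Q[:, order]

    # Step 3: project onto eigenvectors a..b (assumed 1-indexed, inclusive),
    # then standardize each coordinate (assumed meaning of "whiten").
    Z = X_rest @ Q[:, a - 1:b]
    Z = (Z - Z.mean(axis=0)) / (Z.std(axis=0) + 1e-8)

    # Step 4: labels from a random two-layer teacher applied to the projection.
    W1 = rng.standard_normal((Z.shape[1], hidden))
    W2 = rng.standard_normal((hidden, n_classes))
    labels = np.argmax(np.maximum(Z @ W1, 0.0) @ W2, axis=1)
    return Z, labels

# Random data standing in for MNIST inputs (illustration only).
X_demo = np.random.default_rng(1).standard_normal((200, 20))
Z_demo, labels_demo = make_syn(X_demo[:100], X_demo[100:], a=3, b=7)
```

The synthetic dataset then pairs the original inputs `X_rest` with `labels`, so the network sees the full input while the labels depend only on the selected eigen-directions.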

2) Datasets Gaussian-$\alpha$ are created from Gaussian data with different covariance matrices, labeled by a teacher network with random weights.

  • Create a covariance matrix $A$ whose $i$-th eigenvalue is $\exp(-\alpha i)$. The eigenvalue decomposition of $A$ is $A=Q\operatorname{diag}(L)Q^{\top}$.

  • Sample the input from the multivariate Gaussian distribution $N(0,A)$.

  • Project the input onto $Q_{1:10}$. Whiten each element of the projection.

  • Label the original input by a teacher network with random weights applied to the projected input.
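A corresponding sketch for the Gaussian-$\alpha$ construction is given below. The function `make_gaussian_alpha`, the input dimension, the random orthogonal basis $Q$, the eigenvalue index starting at $i=1$, and the ReLU teacher are assumptions chosen for illustration; the text does not specify them.

```python
import numpy as np

def make_gaussian_alpha(alpha, dim=50, n_samples=1000, k=10,
                        hidden=32, n_classes=10, seed=0):
    """Sketch of the Gaussian-alpha construction; returns inputs and teacher labels."""
    rng = np.random.default_rng(seed)

    # Covariance A = Q diag(L) Q^T with i-th eigenvalue exp(-alpha * i), i = 1..dim.
    L = np.exp(-alpha * np.arange(1, dim + 1))
    Q, _ = np.linalg.qr(rng.standard_normal((dim, dim)))  # random orthogonal basis

    # Sample x ~ N(0, A) as x = Q diag(sqrt(L)) g with g ~ N(0, I).
    X = (rng.standard_normal((n_samples, dim)) * np.sqrt(L)) @ Q.T

    # Project onto the first k eigenvectors and standardize each coordinate.
    Z = X @ Q[:, :k]
    Z = (Z - Z.mean(axis=0)) / (Z.std(axis=0) + 1e-8)

    # Label the original inputs with a random two-layer teacher on the projection.
    W1 = rng.standard_normal((k, hidden))
    W2 = rng.standard_normal((hidden, n_classes))
    labels = np.argmax(np.maximum(Z @ W1, 0.0) @ W2, axis=1)
    return X, labels

X_gauss, y_gauss = make_gaussian_alpha(alpha=0.1, dim=20, n_samples=500)
```

Larger values of `alpha` make the spectrum decay faster, which concentrates the input variance in the leading directions that the teacher uses.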

3) Dataset MNIST(random label) is created by randomly assigning labels to the original MNIST inputs, according to a uniform distribution on the ten classes {0,1,2,3,4,5,6,7,8,9}.

B.1 Contraction and perturbation factors for the omitting-$m$-samples setting in Section 4

In the omitting-$m$-samples setting, by calculations similar to those in Appendix A, the batch-wise contraction and perturbation factors are

\begin{align*}
c_{n}^{-(m)}(t) &= \frac{\nabla\bar{\ell}(w,S_{(m)})\cdot\nabla\bar{\ell}(w,S_{n}^{-(m)})\big|^{w_{n}^{-(m)}(t)}_{w_{n}(t)}}{\Delta_{n}^{-(m)}(t)},\\
\epsilon_{n}^{-(m)}(t) &= \nabla\bar{\ell}(w,S_{(m)})\cdot\left(\nabla\bar{\ell}(w,S_{n})-\nabla\bar{\ell}(w,S_{n}^{-(m)})\right)\bigg|_{w_{n}(t)}.
\end{align*}

The averaged contraction and perturbation factors are

\begin{align*}
\bar{c}_{n}(t) &= \frac{\mathbb{E}_{(m)}\left[\nabla\bar{\ell}(w,S_{(m)})\cdot\nabla\bar{\ell}(w,S_{n}^{-(m)})\big|^{w_{n}^{-(m)}(t)}_{w_{n}(t)}\right]}{\bar{\Delta}_{n}(t)},\\
\bar{\epsilon}_{n}(t) &= \frac{\operatorname{tr}\hat{\Sigma}_{n}(t)}{n-1},\quad\hat{\Sigma}_{n}(t)=\operatorname{Cov}_{z\sim\text{Unif}(S_{n})}\nabla\ell(w_{n}(t),z),
\end{align*}

where $\hat{\Sigma}_n(t)$ is the covariance matrix of $\nabla \ell(w_n(t), z)$ for $z$ sampled uniformly from the dataset $S_n$. Note that the averaged perturbation factor $\bar{\epsilon}_n(t)$ does not depend on $m$, so it is the same for all removed-$m$-samples settings.
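As a rough illustration, the perturbation factor can be computed directly from per-sample gradients. The sketch below assumes the gradients $\nabla \ell(w_n(t), z_i)$ have already been stacked into an array `grads` of shape `(n, p)`; the function name and the array layout are illustrative assumptions and are not taken from the paper's code.

```python
import numpy as np

def perturbation_factor(grads: np.ndarray) -> float:
    """Return tr(Sigma_hat) / (n - 1) for per-sample gradients of shape (n, p).

    Sigma_hat is the covariance under the uniform distribution on the n
    samples, so it is normalized by n (not by n - 1).
    """
    n = grads.shape[0]
    centered = grads - grads.mean(axis=0, keepdims=True)
    # The trace of the covariance is the mean squared norm of the centered gradients.
    trace_cov = float(np.sum(centered ** 2) / n)
    return trace_cov / (n - 1)
```

Working with the trace avoids forming the full $p \times p$ covariance matrix, which would be impractical when $p$ is the number of network weights.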

B.2 The analysis of the increment of averaged loss difference $\bar{\Delta}_n(t) - \bar{\Delta}_n(t_0)$

In this section, we consider the training process starting from time $t_0$. The trajectories $w_n^{-(m)}(\cdot)$ and $w_n(\cdot)$ differ at time $t_0$, so the batchwise loss difference $\Delta^{-(m)}_n(t_0)$ and the averaged loss difference $\bar{\Delta}_n(t_0)$ are nonzero in general. We now consider the increment $\bar{\Delta}_n(t) - \bar{\Delta}_n(t_0)$ for $t > t_0$. Its evolution is

$$
\frac{\mathrm{d}\left(\bar{\Delta}_n(t) - \bar{\Delta}_n(t_0)\right)}{\mathrm{d}t} = -\bar{c}_n(t)\left(\bar{\Delta}_n(t) - \bar{\Delta}_n(t_0)\right) + \bar{\epsilon}_n(t),
$$

where by revising the denominator in Eq. 8 and Eq. 9, we have

$$
\begin{aligned}
\bar{c}_n(t) &= \frac{\mathbb{E}_{(m)}\left[\nabla \bar{\ell}(w, S_{(m)}) \cdot \nabla \bar{\ell}(w, S_n^{-(m)}) \Big|^{w_n^{-(m)}(t)}_{w_n(t)}\right]}{\bar{\Delta}_n(t) - \bar{\Delta}_n(t_0)}, \\
\bar{\epsilon}_n(t) &= \frac{\operatorname{tr} \hat{\Sigma}_n(t)}{n-1}, \quad \hat{\Sigma}_n(t) = \operatorname{Cov}_{z \sim \text{Unif}(S_n)} \nabla \ell(w_n(t), z).
\end{aligned}
$$

The evolution of the residual starting from $t_0$ is

$$
\vec{r}_n(t) = \Omega_n(t_0, t)\, \vec{r}_n(t_0).
$$

Combining with Lemma 10, we have the following decomposition for the covariance trace,

$$
\operatorname{tr} \hat{\Sigma}_n(t) = \frac{1}{n}\, \vec{r}_n(t_0)^{\top} \Omega_n(t_0, t)^{\top} \left(M_n(t) - \frac{H_n(t)}{n}\right) \Omega_n(t_0, t)\, \vec{r}_n(t_0).
$$

By arguments similar to those in Theorem 13, we obtain a quadratic form expression for the increment of the averaged loss difference $\bar{\Delta}_n(t) - \bar{\Delta}_n(t_0)$,

$$
\bar{\Delta}_n(t) - \bar{\Delta}_n(t_0) = \vec{r}_n(t_0)^{\top} K_n(t_0, t)\, \vec{r}_n(t_0),
$$

where

$$
K_n(t_0, t) = \frac{1}{n-1} \int_{t_0}^{t} \Omega_n(t_0, s)^{\top} \left(M_n(s) - \frac{H_n(s)}{n}\right) \Omega_n(t_0, s) \exp\left(-\int_s^t \bar{c}_n(u)\, \mathrm{d}u\right) \mathrm{d}s.
$$
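The exponential weight inside $K_n(t_0, t)$ comes from solving the scalar linear equation for the increment with an integrating factor. The following is a short sketch of that step, treating $\bar{c}_n$ and $\bar{\epsilon}_n$ as given integrable functions of time (an assumption made here only for illustration) and writing $D(t)$ as shorthand for the increment:

```latex
\begin{aligned}
D(t) &:= \bar{\Delta}_n(t) - \bar{\Delta}_n(t_0), \qquad D(t_0) = 0, \\
\frac{\mathrm{d}}{\mathrm{d}t}\left[ D(t) \exp\left(\int_{t_0}^{t} \bar{c}_n(u)\,\mathrm{d}u\right) \right]
  &= \bar{\epsilon}_n(t) \exp\left(\int_{t_0}^{t} \bar{c}_n(u)\,\mathrm{d}u\right), \\
D(t) &= \int_{t_0}^{t} \bar{\epsilon}_n(s) \exp\left(-\int_{s}^{t} \bar{c}_n(u)\,\mathrm{d}u\right) \mathrm{d}s .
\end{aligned}
```

Substituting $\bar{\epsilon}_n(s) = \operatorname{tr}\hat{\Sigma}_n(s)/(n-1)$ and the trace decomposition into the last line yields a quadratic form in $\vec{r}_n(t_0)$, with normalization constants following the paper's conventions for the residual.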

Let

$$
K_n(t_0) \triangleq \lim_{t \to \infty} K_n(t_0, t)
$$

when the limit exists. We then have

$$
\bar{\Delta}_n(\infty) - \bar{\Delta}_n(t_0) = \vec{r}_n(t_0)^{\top} K_n(t_0)\, \vec{r}_n(t_0),
$$

where $\bar{\Delta}_n(\infty) := \lim_{t \to \infty} \bar{\Delta}_n(t)$, and this limit exists. We call $K_n(t_0)$ the effective Gram matrix of a neural network starting from $t_0$.
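One way to read this quadratic form is through the eigendecomposition of $K_n(t_0)$: the increment is a sum of eigenvalues weighted by the squared projections of the residual onto the corresponding eigenvectors. The sketch below assumes that $K_n(t_0)$ is available as a symmetric NumPy array `K` and the residual as a vector `r`; these names and the helper function are hypothetical.

```python
import numpy as np

def quadratic_form_by_eigenmode(K: np.ndarray, r: np.ndarray):
    """Split r^T K r into per-eigenvector contributions (K assumed symmetric)."""
    K_sym = 0.5 * (K + K.T)                    # guard against small asymmetries
    eigvals, eigvecs = np.linalg.eigh(K_sym)   # eigenvalues in ascending order
    proj_sq = (eigvecs.T @ r) ** 2             # squared projection on each eigenvector
    contributions = eigvals * proj_sq          # lambda_i * (u_i^T r)^2
    return eigvals, proj_sq, contributions
```

The contributions sum to `r @ K @ r`, so a residual concentrated on eigenvectors with small eigenvalues gives a small increment even when some eigenvalues of `K` are large.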

B.3 The approximation of generalization gap

Table S.1 compares the generalization gaps, the averaged loss difference, and its approximations for a variety of architectures and datasets. In almost all cases these quantities are very close, indicating that the approximation of the averaged loss difference represents the generalization gap well. The small generalization errors indicate that the models used are of good quality, and the small training loss shows that the models are trained to near interpolation. The second-to-last column, $\bar{\sigma}(K_n)$, shows the estimates of the kernel magnitude. The last column shows that, for the same dataset, the norm of the initial residual is almost the same even when the number of samples differs (e.g., rows 2-6 of the table, which report MNIST-5 with different numbers of samples). This justifies normalizing the initial residual by $1/\sqrt{n}$ and shows that the normalization makes the effective Gram matrix decompositions of datasets with different numbers of samples comparable.

| Architecture | Dataset | # samples | $\bar{\Delta}_n(c,\hat{\epsilon},t)$ | $\bar{\Delta}_n(c,\epsilon,t)$ | $\bar{\Delta}_n(t)$ | $\delta R(S_n,t)$ | $\delta\bar{R}(S_n^{-(m)},t)$ | Generalization error | $R_{\text{train}}(S_n,t)$ | $\bar{\sigma}(K_n)$ | $\lVert\vec{r}_n(0)\rVert_2^2$ |
|---|---|---|---|---|---|---|---|---|---|---|---|
| FC | MNIST-2 | 1100 | 0.018 | 0.018 | 0.021 | 0.051 | 0.053 | 0.028 | 0.034 | 0.595644 | 0.513084 |
| FC | MNIST-5 | 55 | 0.255 | 0.266 | 0.271 | 0.336 | 0.360 | 0.134 | 0.036 | 7.374783 | 0.808355 |
| FC | MNIST-5 | 110 | 0.221 | 0.224 | 0.232 | 0.247 | 0.258 | 0.092 | 0.038 | 6.687616 | 0.820471 |
| FC | MNIST-5 | 550 | 0.127 | 0.128 | 0.142 | 0.095 | 0.101 | 0.040 | 0.038 | 4.664132 | 0.821945 |
| FC | MNIST-5 | 1100 | 0.111 | 0.112 | 0.128 | 0.084 | 0.090 | 0.033 | 0.038 | 4.567158 | 0.823081 |
| FC | MNIST-5 | 2200 | 0.093 | 0.089 | 0.092 | 0.070 | 0.076 | 0.028 | 0.036 | 3.780354 | 0.824114 |
| FC | MNIST-10 | 1100 | 0.469 | 0.472 | 0.476 | 0.448 | 0.462 | 0.134 | 0.036 | 22.119434 | 0.911234 |
| LENET-5 | MNIST-10 | 1100 | 0.241 | 0.243 | 0.228 | 0.203 | 0.221 | 0.075 | 0.048 | 6.582657 | 0.899971 |
| FC | CIFAR-2 | 1100 | 0.346 | 0.446 | 0.342 | 0.420 | 0.433 | 0.154 | 0.037 | 10.059950 | 0.513471 |
| LENET-5-GS | CIFAR-2 | 2200 | 0.461 | 0.438 | 0.461 | 0.495 | 0.553 | 0.140 | 0.042 | 5.696342 | 0.509838 |
| LENET-5 | CIFAR-2 | 1100 | 0.627 | 0.642 | 0.375 | 0.450 | 0.504 | 0.137 | 0.043 | 8.446142 | 0.496051 |
| WRN-4-4 | CIFAR-2 | 1100 | 0.111 | 0.147 | 0.110 | 0.187 | 0.204 | 0.107 | 0.090 | 0.535692 | 0.498821 |
| FC | syn-(1,10) | 1100 | 0.221 | 0.237 | 0.212 | 0.160 | 0.180 | 0.084 | 0.038 | 4.419841 | 0.571913 |
| FC | syn-(11,20) | 1100 | 0.234 | 0.293 | 0.371 | 0.368 | 0.399 | 0.137 | 0.039 | 5.008938 | 0.512861 |
| FC | syn-(21,30) | 1100 | 0.235 | 0.391 | 0.484 | 0.421 | 0.440 | 0.159 | 0.040 | 6.834943 | 0.501674 |
| FC | syn-(31,40) | 1100 | 0.347 | 0.583 | 0.546 | 0.529 | 0.554 | 0.192 | 0.041 | 8.572351 | 0.483528 |
| FC | syn-(41,50) | 1100 | 0.373 | 0.682 | 0.599 | 0.549 | 0.584 | 0.192 | 0.037 | 10.590858 | 0.520236 |
| FC | MNIST(random label) | 55 | 3.265 | 3.789 | 3.276 | 4.831 | 4.770 | 0.902 | 0.041 | 55.534283 | 0.906837 |
| FC | Gaussian-1 | 1100 | 0.062 | 0.057 | 0.053 | 0.063 | 0.068 | 0.039 | 0.038 | 2.711061 | 0.601203 |
| FC | Gaussian-0.5 | 1100 | 0.087 | 0.115 | 0.126 | 0.122 | 0.135 | 0.064 | 0.038 | 4.457899 | 0.569315 |
| FC | Gaussian-0.1 | 1100 | 0.188 | 0.198 | 0.213 | 0.227 | 0.239 | 0.110 | 0.038 | 3.605311 | 0.517485 |
| FC | Gaussian-0.05 | 1100 | 0.251 | 0.259 | 0.257 | 0.280 | 0.295 | 0.124 | 0.035 | 3.497570 | 0.513735 |
| FC | Gaussian-0.01 | 1100 | 0.488 | 0.502 | 0.488 | 0.518 | 0.533 | 0.220 | 0.039 | 3.283361 | 0.518636 |
Table S.1: Statistics of the effective Gram matrix approximation for a variety of architectures and datasets. See Appendix B for the details of the datasets and architectures. See Section 4.1 for the definitions of the generalization gaps $\delta R(S_n, t)$ and $\delta\bar{R}(S_n^{-(m)}, t)$, the averaged loss difference $\bar{\Delta}_n(t)$, and its approximations $\bar{\Delta}_n(c, \hat{\epsilon}, t)$ and $\bar{\Delta}_n(c, \epsilon, t)$. “Generalization error” in this table refers to the averaged zero-one loss on the test dataset. $R_{\text{train}}(S_n, t)$ refers to the training loss on dataset $S_n$. For the last two columns, $\bar{\sigma}(K_n)$ refers to the mean of the eigenvalues of the effective kernel, and $\lVert\vec{r}_n(0)\rVert_2^2$ refers to the squared norm of the initial residual. In this table, we evaluate all the quantities (except for $\lVert\vec{r}_n(0)\rVert_2^2$) at the end of training.

B.4 Estimation of the propagator $\Omega_n(t_0, t)$

The propagator $\Omega_n(t_0, t)$ plays a central role in the evolution of the residual $\vec{r}_n(t)$ in Eq. 13 and in the effective kernel $K_n$ in Eq. 18. We next introduce two ways of approximating $\Omega_n(t_0, t)$.

Product approximation

By a discrete approximation of the evolution of $\vec{r}_n(t)$ for a small $\eta = o(1)$, we have $\frac{\vec{r}_n(t+\eta) - \vec{r}_n(t)}{\eta} = -\frac{1}{n} P_n(t)\, \vec{r}_n(t)$. We can derive a discrete approximation of $\Omega_n(t_0, t)$ for $t = t_0 + T\eta$ to be

$$
\Omega_n(t_0, t) \approx \prod_{k=0}^{T-1} \left(I - \frac{\eta}{n} P_n(t_0 + k\eta)\right), \tag{21}
$$

with the products taken from the right, i.e., the $k = 0$ factor is the rightmost one and each later factor multiplies on the left.
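The product approximation can be sketched in a few lines of NumPy. The sketch assumes that the matrices $P_n(t_0 + k\eta)$ for $k = 0, \dots, T-1$ have already been computed and are passed as a sequence `P_seq`; the function name and this interface are illustrative assumptions rather than the paper's implementation.

```python
import numpy as np

def product_propagator(P_seq, eta: float, n: int) -> np.ndarray:
    """Approximate Omega_n(t0, t0 + T*eta) from P_seq = [P(t0), P(t0 + eta), ...].

    Each Euler step maps r to (I - eta/n * P_k) r, so newer factors are
    multiplied on the left and the k = 0 factor ends up rightmost.
    """
    dim = P_seq[0].shape[0]
    omega = np.eye(dim)
    for P_k in P_seq:
        omega = (np.eye(dim) - (eta / n) * P_k) @ omega
    return omega
```

Applying the returned matrix to $\vec{r}_n(t_0)$ gives the same vector as running the Euler recursion on the residual step by step, which is a convenient sanity check on the ordering of the factors.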

Figure S.1: Approximation results of averaged perturbation and averaged loss difference. This plot shows the statistics of FC trained on MNIST with all 10 classes, with $n = 100$ and $m = 10$. Left: Approximations of $\bar{\epsilon}(t)$, where $\bar{\epsilon}(t)$ is the actual averaged perturbation defined by Eq. 9. Right: Approximations of $\bar{\Delta}(t)$, where $\bar{\Delta}(c, \epsilon, t)$ is evaluated using the actual expressions of the contraction (Eq. 8) and the perturbation (Eq. 9), and $\bar{\Delta}(t)$ is evaluated by the actual expression of the averaged loss difference (Eq. 5).

Magnus expansion (Magnus, 1954)

We may write the propagator $\Omega_n(t_0, t)$ using its Lie algebra as $\Omega_n(t_0, t) = \exp\left(\omega_n(t_0, t)\right)$. When $P_n(t)$ does not commute with itself at different times, the Magnus expansion provides a way to write this time-ordered exponential in terms of an infinite series $\omega_n(t_0, t) = \sum_{k=1}^{\infty} \omega_n^k(t_0, t)$, where the first two terms are

$$
\begin{aligned}
\omega_n^1(t_0, t) &= -\frac{1}{n} \int_{t_0}^{t} P_n(t_1)\, \mathrm{d}t_1, \\
\omega_n^2(t_0, t) &= \frac{1}{2n^2} \int_{t_0}^{t} \int_{t_0}^{t_1} [P_n(t_1), P_n(t_2)]\, \mathrm{d}t_2\, \mathrm{d}t_1,
\end{aligned} \tag{22}
$$

where $[A, B] \equiv AB - BA$ is the commutator of matrices $A$ and $B$. Note that $\omega_n^2(t_0, t) = 0$ if $P_n(t_1)$ and $P_n(t_2)$ commute $\forall\, t_1, t_2 > t_0$. The Magnus expansion can be approximated by numerical integration.
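A minimal sketch of such a numerical approximation is given below. It makes several assumptions that are not in the paper: the matrices $P_n$ are sampled on a uniform grid of step $\eta$ and passed as `P_seq`, the integrals are replaced by left Riemann sums, and `expm_taylor` is a simple stand-in for a library matrix exponential such as `scipy.linalg.expm`.

```python
import numpy as np

def expm_taylor(A: np.ndarray, terms: int = 20) -> np.ndarray:
    """Matrix exponential via scaling and squaring with a truncated Taylor series."""
    norm = np.linalg.norm(A, ord=1)
    squarings = max(0, int(np.ceil(np.log2(norm))) + 1) if norm > 0 else 0
    A_scaled = A / (2 ** squarings)
    result = np.eye(A.shape[0])
    term = np.eye(A.shape[0])
    for k in range(1, terms + 1):
        term = term @ A_scaled / k
        result = result + term
    for _ in range(squarings):
        result = result @ result
    return result

def magnus_propagator(P_seq, eta: float, n: int, order: int = 2) -> np.ndarray:
    """exp(omega^1 [+ omega^2]) with the integrals replaced by left Riemann sums."""
    P = np.asarray(P_seq, dtype=float)        # shape (T, d, d)
    omega = -(eta / n) * P.sum(axis=0)        # first-order term omega^1
    if order >= 2:
        # prefix[i] = sum_{j < i} P_j approximates the inner integral up to t_1.
        prefix = np.cumsum(P, axis=0) - P
        comm = sum(P[i] @ prefix[i] - prefix[i] @ P[i] for i in range(len(P)))
        omega = omega + (eta ** 2) / (2 * n ** 2) * comm   # second-order term omega^2
    return expm_taylor(omega)
```

When all the sampled matrices commute, the commutator sum vanishes and the first- and second-order results coincide, in line with the remark above.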

Fig. S.1 shows the approximation results for $\bar{\epsilon}_n$ and $\bar{\Delta}_n$ when different approximations of $\Omega_n(t_0, t)$ are used. The Magnus expansion gives a good approximation when $t$ is small, but the approximations diverge from the true values of $\bar{\epsilon}_n$ and $\bar{\Delta}_n$ for large $t$. The second-order Magnus expansion is even worse than the first-order one, which could result from overshooting of $\omega_n^2(t_0, t)$. Possible reasons are that the term $[P_n(t_1), P_n(t_2)]$ is highly oscillatory and that the step size is too large. In comparison, the product approximation performs well until the end of training. Hence, for all experiments in Section 4, we calculate the effective Gram matrix through the product approximation.