Some statistical theory for deep neural networks...


Transcript of Some statistical theory for deep neural networks...

Page 1:

Some statistical theory for deep neural networks

Johannes Schmidt-Hieber

Page 2:

Deep neural networks

A network architecture (L, p) consists of

- a positive integer L, the number of hidden layers (depth)

- a width vector p = (p_0, ..., p_{L+1}) ∈ N^{L+2}.

A neural network with architecture (L, p) is a function

      f : R^{p_0} → R^{p_{L+1}},   x ↦ f(x) = W_{L+1} σ_{v_L} W_L σ_{v_{L-1}} ··· W_2 σ_{v_1} W_1 x.

Network parameters:

- W_i is a p_i × p_{i-1} weight matrix

- v_i ∈ R^{p_i} is a shift vector

Activation function:

- We study the ReLU activation function σ(x) = max(x, 0).
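To make the definition concrete, here is a minimal NumPy sketch of this forward pass. The toy architecture, the random parameters, and the convention that σ_v subtracts the shift componentwise before the ReLU are illustrative assumptions, not taken from the slides.

```python
import numpy as np

def relu_network(x, weights, shifts):
    """Evaluate f(x) = W_{L+1} sigma_{v_L} W_L ... sigma_{v_1} W_1 x,
    assuming sigma_v(y) = (max(y_j - v_j, 0))_j componentwise."""
    h = x
    for W, v in zip(weights[:-1], shifts):
        h = np.maximum(W @ h - v, 0.0)   # hidden layer: affine map followed by ReLU
    return weights[-1] @ h               # output layer: linear, no activation

# Toy architecture (L, p) = (2, (3, 4, 4, 1)) with parameters drawn from [-1, 1]
rng = np.random.default_rng(0)
p = (3, 4, 4, 1)
weights = [rng.uniform(-1, 1, size=(p[i + 1], p[i])) for i in range(len(p) - 1)]
shifts = [rng.uniform(-1, 1, size=p[i + 1]) for i in range(len(p) - 2)]
print(relu_network(np.ones(3), weights, shifts))
```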

Page 3:

Depth

Source: Kaiming He, Deep Residual Networks

- Networks are deep

- this version of ResNet has 152 hidden layers

- networks keep getting deeper

Page 4:

High-dimensionality

Source: arxiv.org/pdf/1605.07678.pdf

- The number of network parameters is larger than the sample size

- AlexNet uses 60 million parameters for 1.2 million training samples

Page 5:

Network sparsity

[Figure: graph of a ReLU unit x ↦ (ax + b)_+ with kink at x = −b/a]

- trained networks exhibit some form of sparsity in the parameters

- units get deactivated, either by initialization or during learning

- we instead assume sparsely connected networks

Page 6:

The large parameter trick

- with arbitrarily large network parameters we can approximate the indicator function via

      x ↦ σ(ax) − σ(ax − 1),   a → ∞

- it is common in approximation theory to use networks with network parameters tending to infinity

- in practice, network parameters are typically small

⇒ we restrict all network parameters to be at most one in absolute value
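A quick numerical check of this trick; the grid of inputs and the values of a below are chosen only for illustration.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def indicator_approx(x, a):
    """sigma(a x) - sigma(a x - 1): equals 0 for x <= 0 and 1 for x >= 1/a."""
    return relu(a * x) - relu(a * x - 1)

xs = np.array([-0.5, 0.0, 0.01, 0.1, 0.5])
for a in (10, 100, 1000):
    print(a, indicator_approx(xs, a))
# As a grows, the output approaches the indicator 1(x > 0),
# but this requires the weight a to be arbitrarily large.
```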

Page 7:

Statistical analysis

- we observe n i.i.d. copies (X_1, Y_1), ..., (X_n, Y_n),

      Y_i = f(X_i) + ε_i,   ε_i ∼ N(0, 1)

- X_i ∈ R^d, Y_i ∈ R

- goal is to reconstruct the function f : R^d → R

- this problem has been studied extensively (kernel smoothing, wavelets, splines, ...)
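A minimal simulation of this model; the regression function and the uniform design below are made-up choices for illustration only.

```python
import numpy as np

rng = np.random.default_rng(1)
n, d = 500, 3

def f_true(x):
    # Illustrative regression function; any f: R^d -> R could be used here.
    return np.sin(np.pi * x[:, 0]) + x[:, 1] * x[:, 2]

X = rng.uniform(0, 1, size=(n, d))        # design points X_i in R^d (assumed uniform)
Y = f_true(X) + rng.standard_normal(n)    # Y_i = f(X_i) + eps_i,  eps_i ~ N(0, 1)
```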

Page 8:

The estimator

- choose a network architecture (L, p) and a sparsity s

- denote by F(L, p, s) the class of all networks with
  - architecture (L, p)
  - number of active (i.e. non-zero) parameters equal to s

- our theory applies to any estimator f̂_n taking values in F(L, p, s)

- prediction error

      R(f̂_n, f) := E_f[(f̂_n(X) − f(X))²],

  where X has the same distribution as X_1 and is independent of the sample

- we study how R(f̂_n, f) depends on n
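Two quantities on this slide can be made concrete in a few lines: a Monte Carlo approximation of the prediction error R(f̂_n, f) and a count of active parameters, which determines membership in F(L, p, s). The arguments f_hat, f_true and sample_X are placeholders, not part of the original material.

```python
import numpy as np

def prediction_error(f_hat, f_true, sample_X, n_mc=100_000, rng=None):
    """Monte Carlo approximation of R(f_hat, f) = E_f[(f_hat(X) - f(X))^2],
    with X drawn from the design distribution, independently of the sample."""
    rng = rng or np.random.default_rng(0)
    X = sample_X(n_mc, rng)                      # fresh draws with the law of X_1
    return float(np.mean((f_hat(X) - f_true(X)) ** 2))

def count_active_parameters(weights, shifts):
    """The sparsity s of a network: number of non-zero entries over all W_i and v_i."""
    return sum(int(np.count_nonzero(a)) for a in (*weights, *shifts))
```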

Page 9:

Function class

- classical idea: assume that the regression function is β-smooth

- the optimal nonparametric estimation rate is n^{−2β/(2β+d)}

- this rate suffers from the curse of dimensionality (illustrated numerically below)

- to understand deep learning, this setting is therefore useless

- instead, make a good structural assumption on f
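A back-of-the-envelope illustration of the curse of dimensionality in this rate; β = 2 and n = 10^6 are arbitrary choices made only to show the effect of d.

```python
# Classical nonparametric rate n^(-2*beta/(2*beta + d)) for a few input dimensions.
beta, n = 2.0, 10**6
for d in (1, 10, 100):
    print(d, n ** (-2 * beta / (2 * beta + d)))
# d = 1   -> ~1.6e-05
# d = 10  -> ~1.9e-02
# d = 100 -> ~5.9e-01  (essentially no accuracy gain at realistic sample sizes)
```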

Page 10:

Hierarchical structure

- Important: only a few objects are combined on deeper abstraction levels
  - few letters in one word
  - few words in one sentence

Page 11:

Function class

- We assume that

      f = g_q ∘ ... ∘ g_0

  with
  - g_i : R^{d_i} → R^{d_{i+1}}
  - each of the d_{i+1} components of g_i is β_i-smooth and depends on only t_i of its variables
  - t_i can be much smaller than d_i

- effective smoothness

      β*_i := β_i ∏_{ℓ=i+1}^{q} (β_ℓ ∧ 1)

- we show that the rate depends on the pairs (t_i, β*_i), i = 0, ..., q  (a small sketch follows below)
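As a sanity check of the definition of β*_i, a few lines of code evaluating it for a made-up composition; the smoothness indices are illustrative.

```python
import numpy as np

def effective_smoothness(betas):
    """beta*_i = beta_i * prod_{l=i+1}^{q} min(beta_l, 1) for a composition g_q o ... o g_0."""
    q = len(betas) - 1
    return [betas[i] * float(np.prod([min(betas[l], 1.0) for l in range(i + 1, q + 1)]))
            for i in range(q + 1)]

# Illustrative composition with q = 2; the smoothness indices are made up.
print(effective_smoothness([0.5, 3.0, 2.0]))   # [0.5, 3.0, 2.0], since beta_1, beta_2 >= 1
```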

Page 12:

Example: Additive models

- In an additive model,

      f(x) = ∑_{i=1}^{d} f_i(x_i)

- This can be written as f = g_1 ∘ g_0 with

      g_0(x) = (f_i(x_i))_{i=1,...,d},   g_1(y) = ∑_{i=1}^{d} y_i.

  Hence, t_0 = 1 and d_1 = t_1 = d  (a short sketch of this decomposition follows below).

- This decomposes an additive function into
  - one function that can be non-smooth, but every component is one-dimensional
  - one function that has a high-dimensional input, but is smooth
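A short sketch of this decomposition in code; the component functions f_i and the choice d = 3 are made up for illustration.

```python
import numpy as np

# Additive model f(x) = sum_i f_i(x_i), written as the composition f = g_1 o g_0.
f_components = [np.sin, np.cos, np.abs]                      # f_1, f_2, f_3

def g0(x):
    """g_0(x) = (f_1(x_1), ..., f_d(x_d)); each component depends on t_0 = 1 variable."""
    return np.array([fi(xi) for fi, xi in zip(f_components, x)])

def g1(y):
    """g_1(y) = sum_i y_i; smooth, but with a t_1 = d-dimensional input."""
    return float(np.sum(y))

x = np.array([0.2, 1.0, -0.7])
assert np.isclose(g1(g0(x)), np.sin(0.2) + np.cos(1.0) + abs(-0.7))
```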

Page 13:

Main result

Theorem: If

(i) depth ≍ log n,

(ii) width ≥ network sparsity ≍ max_{i=0,...,q} n^{t_i/(2β*_i + t_i)} log n,

then, for any network reconstruction method f̂_n,

      prediction error ≲ φ_n + Δ_n

(up to log n-factors), with

      Δ_n := E[ (1/n) ∑_{i=1}^{n} (Y_i − f̂_n(X_i))² − inf_{f ∈ F(L,p,s)} (1/n) ∑_{i=1}^{n} (Y_i − f(X_i))² ]

and

      φ_n := max_{i=0,...,q} n^{−2β*_i/(2β*_i + t_i)}.
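The rate φ_n is easy to evaluate for concrete pairs (t_i, β*_i). The pairs below are illustrative: they mimic the additive-model example with β = 2-smooth components and d = 3, treating the linear aggregation step as very smooth (β = 100 stands in for an arbitrarily large smoothness index).

```python
def phi_n(n, pairs):
    """phi_n = max_i n^(-2*beta_i / (2*beta_i + t_i)) over the pairs (t_i, beta*_i)."""
    return max(n ** (-2.0 * b / (2.0 * b + t)) for t, b in pairs)

# Additive-model-like example: (t_0, beta*_0) = (1, 2), (t_1, beta*_1) = (3, 100).
print(phi_n(10**6, [(1, 2.0), (3, 100.0)]))   # ~ n^(-4/5): the dimension d does not enter
```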

Page 14:

Consequences

- the assumption that depth ≍ log n appears naturally

- in particular, the depth scales with the sample size

- the networks can have many more parameters than the sample size

⇒ what matters for statistical performance is not the size of the network but the amount of regularization

Page 15:

Consequences (ctd.)

paradox:

- good rate for all smoothness indices

- existing piecewise linear methods only give good rates up to smoothness two

- here the non-linearity of the function class helps

non-linearity is essential!!!

Page 16:

Suboptimality of wavelet estimators

- f(x) = h(x_1 + ... + x_d) for some α-smooth function h

- rate for DNNs ≲ n^{−α/(2α+1)} (up to logarithmic factors)

- rate for the best wavelet thresholding estimator ≳ n^{−α/(2α+d)}

- reason: the low-dimensional structure does not affect the decay of the wavelet coefficients
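Plugging numbers into the two rates on this slide makes the gap visible; α = 2, d = 10, and n = 10^6 are arbitrary illustrative choices.

```python
alpha, d, n = 2.0, 10, 10**6
print("DNN rate       ~", n ** (-alpha / (2 * alpha + 1)))   # ~ 4.0e-03
print("wavelet rate  >=", n ** (-alpha / (2 * alpha + d)))   # ~ 1.4e-01
```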

Page 17:

MARS

- consider products of ramp functions

      h_{I,t}(x_1, ..., x_d) = ∏_{j∈I} (±(x_j − t_j))_+

- piecewise linear in each component

- MARS (multivariate adaptive regression splines) fits linear combinations of such functions to data

- greedy algorithm

- has depth- and width-type parameters
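A direct implementation of one such basis function h_{I,t}; the index set, knots, and signs in the example call are made up.

```python
import numpy as np

def mars_basis(x, index_set, knots, signs):
    """h_{I,t}(x) = prod_{j in I} (s_j * (x_j - t_j))_+ , a product of ramp functions."""
    factors = [max(s * (x[j] - t), 0.0) for j, t, s in zip(index_set, knots, signs)]
    return float(np.prod(factors))

# One basis function on R^3 using coordinates 0 and 2 (knots and signs chosen arbitrarily).
print(mars_basis(np.array([0.8, 0.1, 0.5]), index_set=[0, 2], knots=[0.3, 0.2], signs=[+1, +1]))
# -> 0.5 * 0.3 = 0.15
```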

Page 18:

Comparison with MARS

- how does MARS compare to ReLU networks?

- functions that can be represented by s parameters in the MARS function system can be represented by s log(1/ε)-sparse DNNs up to sup-norm error ε

Page 19:

Comparison with MARS (ctd.)

Figure: Reconstruction using MARS (left) and networks (right)

- the opposite is not true; one counterexample is

      f(x_1, x_2) = (x_1 + x_2 − 1)_+

- we need ≳ ε^{−1/2} parameters to get ε-close with MARS functions (see the sketch below)

- conclusion: DNNs work better for correlated design
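To see why this is a natural counterexample, note that f is itself a single ReLU unit, so a network represents it exactly with a handful of non-zero parameters, whereas MARS must build the diagonal kink out of products of univariate ramps. A minimal sketch; the parameter count in the comment is my reading of the slide's claim.

```python
import numpy as np

def f_counterexample(x1, x2):
    """f(x1, x2) = (x1 + x2 - 1)_+ : exactly one ReLU unit."""
    return np.maximum(x1 + x2 - 1.0, 0.0)

# As a ReLU network: W_1 = [[1, 1]], v_1 = [1], W_2 = [[1]], i.e. 4 non-zero parameters.
# MARS basis functions are products of univariate ramps, so getting eps-close to this
# diagonal kink requires on the order of eps^(-1/2) parameters (the slide's lower bound).
```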

Page 20:

References:

- Schmidt-Hieber (2017). Nonparametric regression using deep neural networks with ReLU activation function. arXiv:1708.06633

- Eckle, Schmidt-Hieber (2018). A comparison of deep networks with ReLU activation function and linear spline-type methods. arXiv:1804.02253
