Use the ValidationSet option of NetTrain to measure performance on data held out from training, which helps keep the trained net from overfitting the training data. Such a held-out set is commonly referred to as a validation or hold-out set.

Create synthetic training data based on a Gaussian curve.

In[1]:=

data = Table[ x -> Exp[-x^2] + RandomVariate[NormalDistribution[0, .15]], {x, -3, 3, .2}];

In[2]:=

plot = ListPlot[List @@@ data, PlotStyle -> Red]

Out[2]=

Train a net with a large number of parameters relative to the amount of training data.

In[3]:=

net = NetChain[{150, Tanh, 150, Tanh, 1}, "Input" -> "Scalar", "Output" -> "Scalar"];
net1 = NetTrain[net, data, Method -> "ADAM"]

The resulting net overfits the data, learning the noise in addition to the underlying function.

In[5]:=

Show[Plot[net1[x], {x, -3, 3}], plot]

Out[5]=

Subdivide the data into a training set and a hold-out validation set.

In[6]:=

data = RandomSample[data];
{train, test} = TakeDrop[data, 24];

Use the ValidationSet option to have NetTrain select the net that achieved the lowest validation loss during training.

In[7]:=

net2 = NetTrain[net, train, ValidationSet -> test]
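
NetTrain can also hold out a randomly chosen fraction of the data automatically, using ValidationSet -> Scaled[frac]; the fraction below (20%) is an arbitrary choice for illustration.

```
(* Let NetTrain split off a random 20% of the data as the validation set *)
NetTrain[net, data, ValidationSet -> Scaled[0.2]]
```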

The result returned by NetTrain is the net that generalized best to points in the validation set, as measured by validation loss. This penalizes overfitting, because the noise in the training data is uncorrelated with the noise in the validation set.

In[9]:=

Show[Plot[net2[x], {x, -3, 3}], plot]

Out[9]=

Another way to tackle overfitting is L2 regularization, which adds a penalty proportional to the sum of the squared parameters of the net to the training loss. This can be specified with a Method option to NetTrain.
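
Schematically (up to conventional constant factors), with "L2Regularization" -> λ, the quantity minimized during training is

$$ \mathcal{L}_{\text{total}} = \mathcal{L}_{\text{data}} + \lambda \sum_i \theta_i^2, $$

where the $\theta_i$ range over the learnable parameters of the net. Larger values of λ drive parameters toward zero, favoring smoother fits.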

In[10]:=

net3 = NetTrain[net, data, Method -> {"ADAM", "L2Regularization" -> 5}]

Out[10]=

L2 regularization penalizes “complex” nets, as measured by the magnitudes of their parameters, which tends to reduce overfitting.

In[11]:=

Show[Plot[net3[x], {x, -3, 3}], plot]

Out[11]=


Content retrieved from: https://www.wolfram.com/language/11/neural-networks/avoid-overfitting-using-a-hold-out-set.html?product=mathematica.