## Neural nets

The following was not written by me, but by a friend now retired from Bell Labs. It is so good that it’s worth sharing.

I asked him to explain the following paper, which I found incomprehensible despite having read about neural nets for years. The paper tries to figure out why neural nets work so well. The authors note that we lack a theoretical foundation for how neural nets work (or why they should!).

https://www.pnas.org/content/pnas/117/44/27070.full.pdf

Here’s what I got back:

Interesting paper. Thanks.

I’ve had some exposure to these ideas and this particular issue, but I’m hardly an expert.

I’m not sure what aspect of the paper you find puzzling. I’ll just say a few things about what I gleaned out of the paper, which may overlap with what you’ve already figured out.

The paper, which is really a commentary on someone else’s work, focuses on the classification problem. Basically, classification is just curve fitting. The curve you want defines a function f that takes a random example x from some specified domain D and gives you the classification c of x, that is, c = f(x).

Neural networks (NNs) provide a technique for realizing this function f by way of a complex network with many parameters that can be freely adjusted. You take a (“small”) subset T of examples from D where you know the classification and you use those to “train” the NN, which means you adjust the parameters to minimize the errors that the NN makes when classifying the elements of T. You then cross your fingers and hope that the NN will show useful accuracy when classifying examples from D that it has not seen before (i.e., examples that were not in the training set T). There is lots of empirical hocus-pocus and rules of thumb concerning what techniques work better than others in designing and training neural networks. Research to place these issues on a firmer theoretical basis continues.

You might think that the best way to train a NN doing the classification task is simply to monitor the classifications it makes on the training set vectors and adjust the NN parameters (weights) to minimize those errors. The problem here is that classification output is very granular (discontinuous): cat/dog, good/bad, etc. You need to have a more nuanced (“gray”) view of things to get the hints you need to gradually adjust the NN weights and home in on their “best” setting. The solution is a so-called “loss” function, a continuous function that operates on the output data before it’s classified (while it is still very analog, as opposed to the digital-like classification output). The loss function should be chosen so that lower loss will generally correspond to lower classification error. Choosing it, of course, is not a trivial thing. I’ll have more to say about that later.
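
To make the contrast concrete, here is a minimal sketch; it is not from the paper. It assumes a tiny logistic model with one weight and one bias, a cross-entropy loss, and four made-up data points, all chosen only for illustration. Every weight setting tried below classifies the toy data perfectly, so the classification error is flat and gives no hint about which way to move, while the loss keeps changing smoothly.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def zero_one_error(w, b, data):
    """Fraction of examples misclassified: the granular, all-or-nothing measure."""
    wrong = sum(1 for x, c in data if (1 if w * x + b > 0 else 0) != c)
    return wrong / len(data)

def cross_entropy_loss(w, b, data):
    """Average cross-entropy, computed on the analog output before thresholding."""
    total = 0.0
    for x, c in data:
        p = sigmoid(w * x + b)
        total += -(c * math.log(p) + (1 - c) * math.log(1 - p))
    return total / len(data)

# Hypothetical 1-D toy data as (x, class) pairs; invented for this example.
DATA = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1)]

if __name__ == "__main__":
    for w in (0.5, 1.0, 2.0):
        err = zero_one_error(w, 0.0, DATA)
        loss = cross_entropy_loss(w, 0.0, DATA)
        print(f"w={w}: classification error={err}, loss={loss:.4f}")
```

In this toy case the error column never changes, whereas the loss column shrinks as w grows. That smooth change is the kind of “gray” signal a training procedure can follow.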

One of the supposed truisms of NNs in the “old days” was that you shouldn’t overtrain the network. Overtraining means beating the parameters to death until you get 100% perfect classification on the training set T. Empirically, it was found that overtraining degrades performance: Your goal should be to get “good” performance on T, but not “too good.” Ex post facto, this finding was rationalized as follows: When you overtrain, you are teaching the NN to do an exact thing for an exact set T, so the moment it sees something that differs even a little from the examples in set T, the NN is confused about what to do. That explanation never made much sense to me, but a lot of workers in the field seemed to find it persuasive.

Perhaps a better analogy is the non-attentive college student who skipped lectures all semester and has gained no understanding of the course material. Facing a failing grade, he manages by chicanery to steal a copy of the final exam a week before it’s given. He cracks open the textbook (for the first time!) and, by sheer willpower, manages to ferret out of that wretched tome what he guesses are the correct, exact answers to all the questions in the exam. He doesn’t really understand any of the answers, but he commits them to memory and is now glowing with confidence that he will ace the test and get a good grade in the course.

But a few days before the final exam date the professor decides to completely rewrite the exam, throwing out all the old questions and replacing them with new ones. The non-attentive student, faced with exam questions he’s never seen before, has no clue how to answer these unfamiliar questions because he has no understanding of the underlying principles. He fails the exam badly and gets an F in the course.

Relating the analogy of the previous two paragraphs to the concept of overtraining NNs, the belief was that if you train a NN to do a “good” job on the training set T but not “too good” a job, it will incorporate (in its parameter settings) some of the background knowledge of “why” examples are classified the way they are, which will help it do a better job when it encounters “unfamiliar” examples (i.e., examples not in the training set). However, if you push the training beyond that point, the NN starts to enter the regime where its learning (embodied in its parameter settings) becomes more like the rote memorization of the non-attentive student, devoid of understanding of the underlying principles and ill prepared to answer questions it has not seen before. Like I said, I was never sure this explanation made a lot of sense, but workers in the field seemed to like it.
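
The rote-memorization story can be caricatured in a few lines. This is only a sketch of the analogy, not a neural network and not anything from the paper: the ground-truth rule, the data points, and both “models” are hypothetical and hand-written for illustration. One model is a lookup table of the training answers; the other stands in for a model that has picked up the underlying principle.

```python
def label(x):
    """Hypothetical ground truth: class 1 when x exceeds 0.2, else class 0."""
    return 1 if x > 0.2 else 0

# Made-up training examples and examples never seen during "training".
TRAIN_X = [-0.9, -0.5, -0.1, 0.3, 0.7]
UNSEEN_X = [-0.8, -0.4, 0.0, 0.4, 0.8]

# Rote memorization: a lookup table of the training answers.
memorized = {x: label(x) for x in TRAIN_X}

def memorizer(x):
    # Knows only the exact questions it has seen; guesses class 0 otherwise.
    return memorized.get(x, 0)

def rule_learner(x):
    # Stand-in for a model that captured the principle: a threshold placed
    # midway between the closest training examples of opposite class
    # (-0.1 and 0.3).
    return 1 if x > 0.1 else 0

def accuracy(model, xs):
    return sum(model(x) == label(x) for x in xs) / len(xs)

if __name__ == "__main__":
    for name, model in (("memorizer", memorizer), ("rule_learner", rule_learner)):
        print(name, accuracy(model, TRAIN_X), accuracy(model, UNSEEN_X))
```

Both models are perfect on the training examples, but only the second holds up on the unseen ones. Whether a heavily trained NN really behaves like the lookup table is exactly the point in dispute.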

That brings us to “deep learning” NNs, which are really just old-fashioned NNs but with lots more layers and, therefore, lots more complexity. So instead of having just “many” parameters, you have millions. For brevity in what follows, I’ll often refer to a “deep learning NN” as simply a “NN.”

Now let’s refer to Figure 1 in the paper. It illustrates some of the things I said above. The vertical axis measures error, while the horizontal axis measures training iterations. Training involves processing a training vector from T, looking at the resulting value of the loss function, and adjusting the NN’s weights (from how you set them in the previous iteration) in a manner that’s designed to reduce the loss. You do this with each training vector in sequence, which causes the NN’s weights to gradually change to values that (you hope) will result in better overall performance. After a certain predetermined number of training iterations, you stop and measure the overall performance of the NN: the overall error on the training vectors, the overall loss, and the overall error on the test vectors. The last are vectors from D that were not part of the training set.
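
The procedure just described can be sketched as follows. This is an illustrative toy, not the setup in the paper: the model is a two-parameter logistic unit rather than a deep network, the loss is cross-entropy, and the data, learning rate, and function names are assumptions made up for the example.

```python
import math

def predict_prob(w, b, x):
    """Analog output of the toy model, before it is thresholded into a class."""
    return 1.0 / (1.0 + math.exp(-(w * x + b)))

def evaluate(w, b, data):
    """Return (classification error, average cross-entropy loss) on data."""
    wrong, loss = 0, 0.0
    for x, c in data:
        p = predict_prob(w, b, x)
        p = min(max(p, 1e-12), 1 - 1e-12)  # keep log() finite
        wrong += int((p > 0.5) != (c == 1))
        loss += -(c * math.log(p) + (1 - c) * math.log(1 - p))
    return wrong / len(data), loss / len(data)

def train(train_set, epochs, lr=0.5):
    """Process each training vector in sequence, nudging the weights to reduce the loss."""
    w, b = 0.0, 0.0
    for _ in range(epochs):
        for x, c in train_set:
            p = predict_prob(w, b, x)
            # Gradient of the cross-entropy with respect to (w, b) is (p - c) * (x, 1).
            w -= lr * (p - c) * x
            b -= lr * (p - c)
    return w, b

def label(x):
    """Hypothetical ground truth for the toy domain D = [-1, 1]."""
    return 1 if x > 0.2 else 0

# Training set T and a disjoint test set, both drawn from D (made-up grids).
TRAIN = [(x, label(x)) for x in (-0.9 + 0.2 * i for i in range(10))]
TEST = [(x, label(x)) for x in (-0.95 + 0.1 * i for i in range(20))]

if __name__ == "__main__":
    for epochs in (1, 10, 200):
        w, b = train(TRAIN, epochs)
        train_err, train_loss = evaluate(w, b, TRAIN)
        test_err, _ = evaluate(w, b, TEST)
        print(f"epochs={epochs}: train error={train_err:.2f}, "
              f"loss={train_loss:.4f}, test error={test_err:.2f}")
```

Plotting these three numbers against the amount of training gives curves of the kind the figure shows, although a model this small should not be expected to reproduce the behavior of a deep network.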

Figure 1 illustrates the overtraining phenomenon. Initially, more training gives lower error on the test vectors. But then you hit a minimum, with more training after that resulting in higher error on the test set. In old-style NNs, that was the end of the story. With deep-learning NNs, it was discovered that continuing the training well beyond what was previously thought wise, even into the regime where the training error is at or near zero (the so-called Terminal Phase of Training, or TPT), can produce a dramatic reduction in test error. This is the great mystery that researchers are trying to understand.

You can read the four points in the paper on page 27071, which are posited as “explanations” of this unexpected lowering of test error, or at least as observations of interesting phenomena that accompany it. I read points 1 and 2 as simply saying that the pre-classification portion of the NN [which executes z = h(x, theta), in their terminology] gets so fine-tuned by the training that it is basically doing the classification all by itself, with the classifier per se being left to do only a fairly trivial job (points 3 and 4).

To me, this “explanation” misses the point. Here is my two cents’ worth: I think the whole success of this method is critically dependent on the loss function. The latter has to embody, with good fidelity, the “wisdom” of what constitutes a good answer. If it does, then overtraining the deep-learning NN like crazy on that loss function will cause its millions of weights to “learn” that wisdom. That is, the NN is not just learning what the right answer is on a limited set of training vectors, but it is learning the “wisdom” of what constitutes a right answer from the loss function itself. Because of the subtlety and complexity of that latent loss function wisdom, this kind of learning became possible only with the availability of modern deep-learning NNs with their great complexity and huge number of parameters.