Understanding Overfitting and Regularization in Machine Learning
In machine learning, overfitting is a common problem. Today we try to understand it with an example from Bishop's book Pattern Recognition and Machine Learning. Take the task of recognizing handwritten digits from zero to nine. A naive way to recognize these digits is to write rules. If the images are three by three pixels, we can flatten the pixels into a nine-dimensional vector and then write rules based on the amount of grey in each pixel, zero for white and one for fully black. This way, we can classify the ten digits. But you can imagine that the number of rules quickly becomes huge if the images grow larger, or if we have to write rules for text.
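As a rough sketch of that idea (the pixel values and the rule below are made up purely for illustration and are not from the video):

```python
import numpy as np

# A hypothetical 3x3 grayscale image: 0.0 = white, 1.0 = fully black.
image = np.array([
    [0.0, 0.9, 0.0],
    [0.0, 0.8, 0.0],
    [0.0, 1.0, 0.0],
])

# Flatten ("straighten") the 3x3 pixels into a nine-dimensional vector.
x = image.flatten()  # shape (9,)

# A toy hand-written rule: if the middle column is dark and the side
# columns are light, guess the digit "1". A real rule set would need
# many such conditions and quickly becomes unmanageable.
if x[[1, 4, 7]].mean() > 0.5 and x[[0, 3, 6, 2, 5, 8]].mean() < 0.2:
    print("Looks like a 1")
else:
    print("Some other digit")
```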
Let us try to solve the same problem with linear-regression curve fitting. We can represent a digit with nine points (the blue dots), and instead of writing rules to guess the digit, we can approximate it with a polynomial, whose equation is shown below and plotted as the green line. If we can estimate the parameters of this polynomial, then the polynomial serves as a model for this digit, and any unseen digit can be classified using this model. The challenge is finding the parameters of the polynomial.
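As a minimal sketch of this fitting step (the nine grey values and the polynomial order are invented for illustration, and numpy's polyfit stands in for whatever estimation method is used in the video):

```python
import numpy as np

# Nine hypothetical grey values for one digit (the "blue dots");
# x is just the pixel index, t is the observed grey level.
x = np.arange(9)
t = np.array([0.1, 0.4, 0.8, 0.9, 0.7, 0.6, 0.8, 0.5, 0.2])

# Estimate the weights of a third-order polynomial
#   y(x, w) = w0 + w1*x + w2*x^2 + w3*x^3
w = np.polyfit(x, t, deg=3)

# The fitted polynomial (the "green line") can now be evaluated at any x.
y = np.polyval(w, x)
print("estimated weights:", w)
```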
Let us generalize this by approximating the red curve and summing up the errors on both sides of the curve from the actual data points (the blue dots): the total error is the sum of the (squared) differences between the observed data and the values the polynomial predicts for a given set of weights. If we assume a polynomial of order zero (the flat red line on the right), we can see there will be a huge error, and this is not a good fit for the data.
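As a minimal sketch of that error function (written in the conventional sum-of-squares form with a factor of one half, which the transcript does not spell out):

```python
import numpy as np

def sum_of_squares_error(w, x, t):
    """Error for weights w: half the sum of squared differences between
    the observed targets t and the polynomial's predictions at x."""
    y = np.polyval(w, x)  # predictions of the polynomial with weights w
    return 0.5 * np.sum((y - t) ** 2)
```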
Note that the error is actually a function of the weights of the polynomial: changing the weights can increase or decrease the error. A first-order polynomial is still a straight line, but with a slope, and we can see it still leaves a lot of error. A third-order polynomial, on the other hand, looks like a very good fit. If we keep increasing the order of the polynomial, the error keeps decreasing, and at order nine the error is zero because the polynomial passes through all the blue data points. But is this a good, generalized model? Definitely not. This zero-error curve represents the current digit exactly but cannot classify any new digit that differs even slightly in its pixel values, so it does not generalize. You can also see that at order nine the polynomial weights become huge: the training error goes to zero while the testing error shoots up to 100% (bottom-left diagram). This is called overfitting.
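The overfitting behaviour can be reproduced with a small experiment. The sine-plus-noise data below is only a stand-in for the curve-fitting example (the values are made up, not taken from the video):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(n):
    # Hypothetical noisy observations of a smooth underlying curve.
    x = np.linspace(0, 1, n)
    return x, np.sin(2 * np.pi * x) + rng.normal(scale=0.2, size=n)

x_train, t_train = sample(10)
x_test, t_test = sample(50)

# As the order grows, training error falls towards zero while test error grows.
# (numpy may warn that the order-nine fit is poorly conditioned; that is expected.)
for order in (0, 1, 3, 9):
    w = np.polyfit(x_train, t_train, deg=order)
    train_err = np.mean((np.polyval(w, x_train) - t_train) ** 2)
    test_err = np.mean((np.polyval(w, x_test) - t_test) ** 2)
    print(f"order {order}: train error {train_err:.4f}, test error {test_err:.4f}")
```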
The model has learned the current dataset but cannot generalize. One potential way to generalize is to increase the number of observed data points: the bottom curve shows a polynomial of order nine fitted to 15 data points, which seems a very good fit and may generalize better, but the question is how to find this number. One option is simply to have many data points, as in the bottom-right diagram. But if we do not have the luxury of unlimited data points, then what can we do?
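To illustrate the effect of more data, the same toy setup can be reused with a ninth-order polynomial and increasing sample sizes (again, the data is illustrative only):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample(n):
    # Hypothetical noisy observations of a smooth underlying curve.
    x = np.linspace(0, 1, n)
    return x, np.sin(2 * np.pi * x) + rng.normal(scale=0.2, size=n)

x_test, t_test = sample(200)

# A ninth-order polynomial fitted to progressively more data points:
# with only 10 points it overfits; with more points it generalizes better.
for n in (10, 15, 100):
    x_tr, t_tr = sample(n)
    w = np.polyfit(x_tr, t_tr, deg=9)
    test_err = np.mean((np.polyval(w, x_test) - t_test) ** 2)
    print(f"N = {n:3d}: test error {test_err:.4f}")
```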
One way to generalize is to penalize the weights of the polynomial, as given in the equation in the top left. This is called regularization of the model. Choosing a lambda (λ) that keeps the variation in the polynomial weights small is one way to generalize the model with a limited number of data points.
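A minimal sketch of this kind of weight penalty (an L2, ridge-style penalty on the same toy data; the data and the value of lambda are illustrative, not from the video):

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 10)
t = np.sin(2 * np.pi * x) + rng.normal(scale=0.2, size=10)

# Design matrix of a ninth-order polynomial: columns are x^0 ... x^9.
order = 9
Phi = np.vander(x, order + 1, increasing=True)

# Regularized least squares: minimize
#   0.5 * sum((Phi @ w - t)**2) + 0.5 * lam * sum(w**2)
# which has the closed-form solution w = (Phi^T Phi + lam * I)^-1 Phi^T t.
lam = 1e-3
w_reg = np.linalg.solve(Phi.T @ Phi + lam * np.eye(order + 1), Phi.T @ t)

# Unregularized fit for comparison: the weights typically blow up.
w_plain = np.linalg.lstsq(Phi, t, rcond=None)[0]

print("max |w| without regularization:", np.abs(w_plain).max())
print("max |w| with regularization:   ", np.abs(w_reg).max())
```

The penalty shrinks the huge weights that the unregularized ninth-order fit produces, which is what keeps the curve from oscillating wildly between the data points.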
You can see that with regularization we can generate a good fit with lower-order polynomials and a smaller number of data points. So regularization helps avoid overfitting and improves the generalization capabilities of models. There are other ways to generalize models; stay tuned for these in the next videos. As a side note, why do we call it linear regression even though the polynomials are of higher order? The answer is that the model is linear in the weights: the prediction is a weighted sum of fixed powers of x, so it is nonlinear in x but linear in each weight. Thank you, and see you next time.