Saturday, December 3, 2011

Basics of regression and model fitting

In regression, we build a model that uses one or more variables to predict some other variable. To understand regression, it is useful to play with simple two-dimensional data (where one variable is used to predict a second variable). An important aspect of regression is the use of cross-validation to evaluate the quality of different models.

CODE

% Let's generate some data in two dimensions.
x = randn(1,200);
y = x.^2 + 3*x + 2 + 3*randn(1,200);
figure(999); clf; hold on;
h1 = scatter(x,y,'k.');
xlabel('x'); ylabel('y');



% Through inspection of the scatterplot, we see that there
% appears to be a nonlinear relationship between x and y.
% We would like to build a model that quantitatively characterizes
% this relationship.

% Let's consider two different models.  One model is a purely linear
% model, y = a*x + b, where a and b are free parameters. The
% second model is a quadratic model, y = a*x^2 + b*x + c where
% a, b, and c are free parameters.  We will assume that we want
% to minimize the squared error between the model and the data.
model1 = polyfit(x,y,1);
model2 = polyfit(x,y,2);
ax = axis;
xx = linspace(ax(1),ax(2),100);
h2 = plot(xx,polyval(model1,xx),'r-','LineWidth',2);
h3 = plot(xx,polyval(model2,xx),'g-','LineWidth',2);
axis(ax);
legend([h2 h3],{'Linear model' 'Quadratic model'});
title('Direct fit (no cross-validation)');



% Although the linear model captures the basic trend in the data, the
% quadratic model seems to characterize the data better. In particular, the
% linear model seems to overestimate the data at middle values of x and
% underestimate the data at low and high values of x.

% How can we formally establish which model is best? A simple approach
% is to quantify the fit quality of each model using a metric like R^2
% (see the blog post on R^2) and then determine which model has
% the higher fit quality.  The problem with this approach is that the
% quadratic model will always outperform the linear model in
% terms of fit quality, even when the true underlying relationship
% is purely linear. The reason is that the quadratic model subsumes
% the linear model and includes one additional parameter (the
% weight on the x^2 term).  Thus, on the training data, the quadratic model
% will always fit at least as well as the linear model, and will fit strictly
% better unless the estimated weight on the x^2 term happens to be exactly
% zero.
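
% As a quick sanity check of this point (a minimal sketch using only standard
% MATLAB functions; the exact numbers depend on the random data), we can
% compute the training-set R^2 of each model directly.  Here we define a
% simple helper that returns R^2 in percent (variance explained relative to
% the mean); the quadratic model's training R^2 will always be at least as
% high as the linear model's.
computeR2 = @(pred,data) 100*(1 - sum((data-pred).^2)/sum((data-mean(data)).^2));
trainR2linear = computeR2(polyval(model1,x),y)
trainR2quadratic = computeR2(polyval(model2,x),y)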

% A better approach is to quantify the prediction quality of each model
% using cross-validation.  This approach is exactly the same as the
% first approach, except that the quality of fit is evaluated on new
% data points that are not used to fit the parameters of the model. The
% intuition is that the fit quality on the data points used for training
% will, on average, be an overestimate of the true fit quality (i.e. the
% expected fit quality of the estimated model for data points drawn from
% the underlying data distribution).
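
% As a simple illustration of this point (a sketch; the exact numbers depend
% on the random data): fit the quadratic model to half of the data points and
% evaluate R^2 (using the computeR2 helper defined above) on both the
% training half and the held-out half.  On average, the R^2 on the held-out
% half will be lower than the R^2 on the training half.
trainix = 1:2:length(x);   % half of the data points (used for fitting)
testix = 2:2:length(x);    % the other half (held out)
modelhalf = polyfit(x(trainix),y(trainix),2);
trainR2 = computeR2(polyval(modelhalf,x(trainix)),y(trainix))
testR2 = computeR2(polyval(modelhalf,x(testix)),y(testix))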

% With cross-validation, there is no guarantee that the more complex
% quadratic model will outperform the linear model.  The best performing
% model will be the one that, after parameter estimation, most closely
% characterizes the underlying relationship between x and y.

% Let's use cross-validation in fitting the linear and
% quadratic models to our example dataset. Specifically,
% let's use leave-one-out cross-validation, a method in which we
% leave a single data point out, fit the model on the remaining data
% points, predict the left-out data point, and then repeat this whole
% process for every single data point. We will use the metric of R^2
% to quantify how closely the model predictions match the data.
model1 = fitprfstatic([x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);
model2 = fitprfstatic([x.^2; x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);
figure(998); setfigurepos([100 100 600 300]); clf;
subplot(1,2,1); hold on;
h1 = scatter(x,y,'k.');
ax = axis;
[d,ii] = sort(x);
h2 = plot(x(ii),model1.modelfit(ii),'r-','LineWidth',2);
axis(ax);
xlabel('x'); ylabel('y');
title(sprintf('Linear model; cross-validated R^2 = %.1f',model1.r));
subplot(1,2,2); hold on;
h3 = scatter(x,y,'k.');
ax = axis;
[d,ii] = sort(x);
h4 = plot(x(ii),model2.modelfit(ii),'g-','LineWidth',2);
axis(ax);
xlabel('x'); ylabel('y');
title(sprintf('Quadratic model; cross-validated R^2 = %.1f',model2.r));



% In these plots, the red and green lines indicate the predictions
% of the linear and quadratic models, respectively.  Notice that
% the lines are slightly wiggly; this reflects the fact that each
% predicted data point comes from a different model estimate.
% We find that the quadratic model achieves a higher
% cross-validated R^2 value than the linear model.
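
% For reference, here is the leave-one-out procedure written out using only
% polyfit and polyval (a sketch of the same idea that the fitprfstatic calls
% above implement; the internal details and R^2 conventions of
% fitprfstatic/calccod may differ slightly).
predlin = zeros(1,length(x));
predquad = zeros(1,length(x));
for p=1:length(x)
  ix = setdiff(1:length(x),p);   % all data points except the pth one
  predlin(p) = polyval(polyfit(x(ix),y(ix),1),x(p));
  predquad(p) = polyval(polyfit(x(ix),y(ix),2),x(p));
end
looR2linear = computeR2(predlin,y)      % cross-validated R^2, linear model
looR2quadratic = computeR2(predquad,y)  % cross-validated R^2, quadratic model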

% Let's repeat these simulations for another dataset.
x = randn(1,50);
y = .1*x.^2 + x + 2 + randn(1,50);
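% (We re-run the same fitting and plotting code as above on this new dataset;
%  for brevity, only the fitting calls are repeated here.)
model1 = fitprfstatic([x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);
model2 = fitprfstatic([x.^2; x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);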





% For this dataset, even though the underlying relationship between
% x and y is quadratic, we find that the quadratic model produces
% a lower cross-validated R^2 than the linear model.  This indicates
% that we do not have enough data to reliably estimate the parameters
% of the quadratic model.  Thus, we are better off estimating the linear
% model and using the linear model estimate as a description of the data.
% (For comparison purposes, the R^2 between the true model and the
% observed y-values, calculated as calccod(.1*x.^2 + x + 2,y), is 47.4.
% The linear model's R^2 is not as good as this, but is substantially
% better than the quadratic model's R^2.)
