**CODE**

% Let's generate some data in two dimensions.

x = randn(1,200);

y = x.^2 + 3*x + 2 + 3*randn(1,200);

figure(999); clf; hold on;

h1 = scatter(x,y,'k.');

xlabel('x'); ylabel('y');

% Through inspection of the scatterplot, we see that there

% appears to be a nonlinear relationship between x and y.

% We would like to build a model that quantitatively characterizes

% this relationship.

% Let's consider two different models. One model is a purely linear

% model, y = a*x + b, where a and b are free parameters. The

% second model is a quadratic model, y = a*x^2 + b*x + c where

% a, b, and c are free parameters. We will assume that we want

% to minimize the squared error between the model and the data.
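
% (Aside: both models are linear in their free parameters, so minimizing
% squared error is an ordinary linear least-squares problem. As a rough
% sketch of what the polyfit calls below compute, we could build a design
% matrix ourselves and solve with MATLAB's backslash operator; the
% variable names here are just for illustration.)

X2 = [x.^2; x; ones(1,length(x))]';  % design matrix for the quadratic model (200 x 3)

params2 = X2 \ y';                   % least-squares estimates of [a b c]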

model1 = polyfit(x,y,1);

model2 = polyfit(x,y,2);

ax = axis;

xx = linspace(ax(1),ax(2),100);

h2 = plot(xx,polyval(model1,xx),'r-','LineWidth',2);

h3 = plot(xx,polyval(model2,xx),'g-','LineWidth',2);

axis(ax);

legend([h2 h3],{'Linear model' 'Quadratic model'});

title('Direct fit (no cross-validation)');

% Although the linear model captures the basic trend in the data, the

% quadratic model seems to characterize the data better. In particular, the

% linear model seems to overestimate the data at middle values of x and

% underestimate the data at low and high values of x.

% How can we formally establish which model is better? A simple approach

% is to quantify the fit quality of each model using a metric like R^2

% (see the blog post on R^2) and then determine which model has

% the higher fit quality. The problem with this approach is that the

% quadratic model will always outperform the linear model in

% terms of fit quality, even when the true underlying relationship

% is purely linear. The reason is that the quadratic model subsumes

% the linear model and includes one additional parameter (the

% weight on the x^2 term). Thus, the quadratic model will always do at

% least as well as the linear model and will do even better given the

% extra parameter (unless the weight on the x^2 term is estimated to be

% exactly zero).
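
% As a quick sanity check of this point, the following sketch fits both

% models to data whose true relationship is purely linear and computes

% training-set R^2 directly (using the standard percent-variance-explained

% formula; the variable names are for illustration only). The quadratic

% model's training R^2 is always at least as high as the linear model's.

x0 = randn(1,200);

y0 = 3*x0 + 2 + randn(1,200);                             % truly linear data

fit1 = polyval(polyfit(x0,y0,1),x0);                      % linear fit

fit2 = polyval(polyfit(x0,y0,2),x0);                      % quadratic fit

r2 = @(f,d) 100*(1 - sum((d-f).^2)/sum((d-mean(d)).^2));  % percent variance explained

fprintf('linear: %.2f, quadratic: %.2f\n',r2(fit1,y0),r2(fit2,y0));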

% A better approach is to quantify the prediction quality of each model

% using cross-validation. This approach is exactly the same as the

% first approach, except that the quality of fit is evaluated on new

% data points that are not used to fit the parameters of the model. The

% intuition is that the fit quality on the data points used for training

% will, on average, be an overestimate of the true fit quality (i.e. the

% expected fit quality of the estimated model for data points drawn from

% the underlying data distribution).
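
% To illustrate this overestimation, the following sketch fits a quadratic

% model to one sample of data and evaluates R^2 both on that sample and on

% a fresh sample drawn from the same distribution. On average (across

% repeated simulations), the training R^2 exceeds the new-data R^2.

xA = randn(1,20); yA = 3*xA + 2 + randn(1,20);            % data used for fitting

xB = randn(1,20); yB = 3*xB + 2 + randn(1,20);            % fresh data, same distribution

pp = polyfit(xA,yA,2);

r2 = @(f,d) 100*(1 - sum((d-f).^2)/sum((d-mean(d)).^2));  % percent variance explained

fprintf('training R^2: %.2f, new-data R^2: %.2f\n', ...
  r2(polyval(pp,xA),yA),r2(polyval(pp,xB),yB));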

% With cross-validation, there is no guarantee that the more complex

% quadratic model will outperform the linear model. The best performing

% model will be the one that, after parameter estimation, most closely

% characterizes the underlying relationship between x and y.

% Let's use cross-validation in fitting the linear and

% quadratic models to our example dataset. Specifically,

% let's use leave-one-out cross-validation, a method in which we

% leave a single data point out, fit the model on the remaining data

% points, predict the left-out data point, and then repeat this whole

% process for every single data point. We will use the metric of R^2

% to quantify how closely the model predictions match the data.
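
% The fits below rely on the utility functions fitprfstatic and calccod.

% Conceptually, the leave-one-out procedure amounts to the following

% plain-MATLAB sketch (shown for the linear model; the R^2 formula here

% is the standard percent-variance-explained version and may differ in

% minor details from calccod):

n = length(x);

pred = zeros(1,n);

for p=1:n
  train = setdiff(1:n,p);                 % all data points except the pth
  coeffs = polyfit(x(train),y(train),1);  % fit on the remaining points
  pred(p) = polyval(coeffs,x(p));         % predict the left-out point
end

r2loo = 100*(1 - sum((y-pred).^2)/sum((y-mean(y)).^2));  % cross-validated R^2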

model1 = fitprfstatic([x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);

model2 = fitprfstatic([x.^2; x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);

figure(998); setfigurepos([100 100 600 300]); clf;

subplot(1,2,1); hold on;

h1 = scatter(x,y,'k.');

ax = axis;

[d,ii] = sort(x);

h2 = plot(x(ii),model1.modelfit(ii),'r-','LineWidth',2);

axis(ax);

xlabel('x'); ylabel('y');

title(sprintf('Linear model; cross-validated R^2 = %.1f',model1.r));

subplot(1,2,2); hold on;

h3 = scatter(x,y,'k.');

ax = axis;

[d,ii] = sort(x);

h4 = plot(x(ii),model2.modelfit(ii),'g-','LineWidth',2);

axis(ax);

xlabel('x'); ylabel('y');

title(sprintf('Quadratic model; cross-validated R^2 = %.1f',model2.r));

% In these plots, the red and green lines indicate the predictions

% of the linear and quadratic models, respectively. Notice that

% the lines are slightly wiggly; this reflects the fact that each

% predicted data point comes from a different model estimate.

% We find that the quadratic model achieves a higher

% cross-validated R^2 value than the linear model.

% Let's repeat these simulations for another dataset.

x = randn(1,50);

y = .1*x.^2 + x + 2 + randn(1,50);
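
% We then re-run the same leave-one-out cross-validation fits as above:

model1 = fitprfstatic([x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);

model2 = fitprfstatic([x.^2; x; ones(1,length(x))]',y',0,0,[],-1,[],[],[],@calccod);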

% For this dataset, even though the underlying relationship between

% x and y is quadratic, we find that the quadratic model produces

% a lower cross-validated R^2 than the linear model. This indicates

% that we do not have enough data to reliably estimate the parameters

% of the quadratic model. Thus, we are better off estimating the linear

% model and using the linear model estimate as a description of the data.

% (For comparison purposes, the R^2 between the true model and the

% observed y-values, calculated as calccod(.1*x.^2 + x + 2,y), is 47.4.

% The linear model's R^2 is not as good as this, but is substantially

% better than the quadratic model's R^2.)
