Cross validation is a widely used model validation technique for estimating how accurately a predictive model will generalize to an independent data set. There are two main uses of cross validation: hyperparameter tuning and model assessment. This post briefly discusses the use of cross validation in hyperparameter tuning, then focuses on using cross validation for model assessment and shows how to compare models based on their cross validation errors using the Start Groups and End Groups nodes in SAS Enterprise Miner.
In hyperparameter tuning, cross validation is used to select the appropriate flexibility of a model during model building. For example, when building a neural network model, cross validation can be used to find optimal hyperparameter values (e.g., number of hidden layers, number of neurons, learning rate, momentum of the stochastic gradient descent algorithm, etc.). Hyperparameter tuning based on cross validation can be done automatically using the new Autotune statement available in a number of SAS® Visual Data Mining and Machine Learning procedures (PROCs FOREST, GRADBOOST, NNET, and TREESPLIT).
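As a rough illustration of the tuning idea (not of how the Autotune statement itself searches), the Python sketch below picks a ridge penalty from a small grid by 5-fold cross validation. The one-feature ridge model through the origin, the round-robin fold assignment, the grid, and the data are all assumptions made for the example; a real tuner would run the same loop over hyperparameters such as the number of neurons or the learning rate.

```python
"""Minimal sketch: choose a ridge penalty by k-fold cross validation.

Assumptions (illustrative only): a one-feature ridge regression through the
origin stands in for a real learner, folds are assigned round-robin, and the
candidate grid and data are made up for the example.
"""


def fit_ridge(xs, ys, lam):
    # Closed-form ridge slope for y ~ beta * x (no intercept).
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + lam)


def cv_error(xs, ys, lam, k=5):
    # Round-robin fold labels 0..k-1 (a random split is more usual).
    folds = [i % k for i in range(len(xs))]
    sq_err = 0.0
    for f in range(k):
        train = [i for i in range(len(xs)) if folds[i] != f]
        beta = fit_ridge([xs[i] for i in train], [ys[i] for i in train], lam)
        # Score only the held-out observations of fold f.
        sq_err += sum((ys[i] - beta * xs[i]) ** 2
                      for i in range(len(xs)) if folds[i] == f)
    return sq_err / len(xs)


def tune(xs, ys, grid, k=5):
    # Return the penalty with the smallest cross validation error.
    errors = {lam: cv_error(xs, ys, lam, k) for lam in grid}
    best = min(errors, key=errors.get)
    return best, errors


if __name__ == "__main__":
    xs = [float(i) for i in range(1, 21)]
    noise = [0.3, -0.2, 0.1, -0.4] * 5
    ys = [2.0 * x + e for x, e in zip(xs, noise)]
    best, errors = tune(xs, ys, grid=[0.0, 1.0, 10.0, 100.0])
    print("best penalty:", best)
```

The same pattern extends to any hyperparameter: only the fitting function and the grid change.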
In model assessment, cross validation is used to compare different models that have already been built using the full training data. Suppose you built several models using various algorithms and hyperparameter settings, and now you want to compare these models by estimating their predictive power. The basic idea in calculating the cross validation error is to divide the training data into k folds (e.g., k=5 or k=10). Each fold is then held out one at a time, the model is trained on the remaining data, and that model is used to predict the target for the holdout observations. When you finish fitting and scoring for all k versions of the training and validation data sets, you will have holdout predictions for all of the observations in your original training data. The average squared error between these predictions and the true observed response is the cross validation error.
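Written as a formula (notation introduced here for illustration: $n$ training observations, $\kappa(i)$ the fold that contains observation $i$, and $\hat{f}^{(-\kappa(i))}$ the model trained with that fold held out), the cross validation error described above is:

$$
\mathrm{CV}_{(k)} = \frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{f}^{(-\kappa(i))}(x_i)\right)^2
$$

Each observation contributes exactly once, through the one model that never saw it during training.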
In SAS Enterprise Miner, the Start/End Groups nodes were originally implemented to stratify an analysis based on a stratification variable. However, with a couple of simple tricks, these nodes can be used along with the Model Import node to obtain the cross validation error of a model. You can even compare several models based on their cross validation errors using the Model Comparison node.
Suppose you fit a model on your full training data using the Gradient Boosting node in SAS Enterprise Miner for the following set of hyperparameters (Niterations=50, Shrinkage=0.2, Train proportion=60, etc.):
Now you can calculate the cross validation error of this model by running the following flow:
1. Use the Transform Variables node to create a k-fold cross validation indicator as a new input variable (_fold_) that randomly divides your data set into k folds. Make sure to save this new variable as a segment variable. For example, for 5-fold cross validation, the Formulas in the Transform Variables node should look like this:
2. In the Start Groups node, specify the “Mode” as “Cross-validation”, and in the Gradient Boosting node, make sure to use the same parameter settings that you used in your original boosted trees model. Run the flow through the End Groups node.
While the Start/End Groups nodes create k versions of the training data and calculate fit statistics on the training data, they do not actually calculate the cross validation error from scoring the holdout observations with these fitted models. However, if you check the score code generated by the End Groups node, you can see that it contains the correct score code to calculate the cross validation error. You can view this score code by first opening the Results of the End Groups node and then clicking View >> SAS Results >> Flow Code on the top menu. This readily available score code can then be used by another node (such as the Model Import node or the SAS Code node) to obtain the cross validation error.
3. Attach the Model Import node and run the whole path. The Train: Average Squared Error column in the Results of the Model Import node is the k-fold cross validation error of your original boosted trees model, which you trained on the full training data.
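The Python sketch below is a conceptual model of steps 1 to 3, not SAS code. It assumes balanced random fold labels, uses a simple linear regression in place of the Gradient Boosting node, and scores each observation with the model that was trained without its fold, which is what the score code has to do for the resulting average squared error to be a cross validation error. All function names are hypothetical.

```python
"""Conceptual sketch of the Start Groups / End Groups cross validation flow.

Hypothetical stand-ins (none are SAS Enterprise Miner APIs): make_folds()
plays the role of the _fold_ formula, fit_line() plays the role of the
modeling node, and score_holdout() plays the role of the End Groups score code.
"""
import random


def make_folds(n, k, seed=12345):
    # Step 1: balanced random fold labels 1..k, like a _fold_ segment variable.
    labels = [i % k + 1 for i in range(n)]
    random.Random(seed).shuffle(labels)
    return labels


def fit_line(xs, ys):
    # Least-squares simple linear regression; returns (intercept, slope).
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    sxx = sum((x - mx) ** 2 for x in xs)
    sxy = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return my - slope * mx, slope


def fit_per_fold(xs, ys, folds, k):
    # Step 2: one model per fold, each trained with that fold held out.
    models = {}
    for f in range(1, k + 1):
        keep = [i for i, g in enumerate(folds) if g != f]
        models[f] = fit_line([xs[i] for i in keep], [ys[i] for i in keep])
    return models


def score_holdout(xs, folds, models):
    # Score each observation with the model that never saw its fold.
    preds = []
    for x, f in zip(xs, folds):
        b0, b1 = models[f]
        preds.append(b0 + b1 * x)
    return preds


def average_squared_error(ys, preds):
    # Step 3: corresponds to the value shown as Train: Average Squared Error.
    return sum((y - p) ** 2 for y, p in zip(ys, preds)) / len(ys)


if __name__ == "__main__":
    rng = random.Random(1)
    xs = [float(i) for i in range(50)]
    ys = [1.0 + 0.5 * x + rng.gauss(0, 1) for x in xs]
    k = 5
    folds = make_folds(len(xs), k)
    models = fit_per_fold(xs, ys, folds, k)
    preds = score_holdout(xs, folds, models)
    print("5-fold CV error:", round(average_squared_error(ys, preds), 4))
```

Shuffling a repeated 1..k pattern keeps the folds equal in size; drawing each label independently at random is a common alternative that gives only approximately equal folds.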
If you are comparing multiple models based on their cross validation errors, your flow (attached as a zip file) should look like this:
The following table shows part of the output table produced by the Model Comparison node:
Note that because the Model Import node is used, the cross validation error is listed as Train: Average Squared Error. Do not let the ‘Train:’ part confuse you: the Model Import node uses the score code generated by the Start/End Groups nodes as configured in step 2, so the value is actually the cross validation error. The output table above shows that the cross validation error of the gradient boosting model is the smallest. If you choose this model to make predictions for a new data set, make sure to use the score code generated by your initial modeling node, which builds the model on the full training set, instead of the k models that are built by the Start and End Groups nodes for the purpose of calculating the cross validation error.
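The selection rule described above can be sketched in Python as follows. Two toy candidates (an intercept-only model and a simple linear regression, both hypothetical stand-ins for Enterprise Miner modeling nodes) are compared by 5-fold cross validation error with round-robin folds; the winner is then refit on all of the training data, and that refit model, not any of the k fold models, is the one used to score new data.

```python
"""Minimal sketch: compare models by k-fold CV error, then refit the winner.

Assumptions for illustration: the two candidate models, the round-robin
fold labels, and the data are made up; they stand in for real modeling nodes.
"""


def fit_mean(xs, ys):
    # Intercept-only model: always predicts the training mean.
    m = sum(ys) / len(ys)
    return lambda x: m


def fit_linear(xs, ys):
    # Least-squares line y = b0 + b1 * x.
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    b1 = (sum((x - mx) * (y - my) for x, y in zip(xs, ys))
          / sum((x - mx) ** 2 for x in xs))
    b0 = my - b1 * mx
    return lambda x: b0 + b1 * x


def cv_error(fit, xs, ys, k=5):
    # Hold out each fold in turn and accumulate squared holdout errors.
    folds = [i % k for i in range(len(xs))]
    total = 0.0
    for f in range(k):
        tr = [i for i in range(len(xs)) if folds[i] != f]
        model = fit([xs[i] for i in tr], [ys[i] for i in tr])
        total += sum((ys[i] - model(xs[i])) ** 2
                     for i in range(len(xs)) if folds[i] == f)
    return total / len(xs)


def choose_and_refit(candidates, xs, ys, k=5):
    # Pick the candidate with the smallest CV error...
    errors = {name: cv_error(fit, xs, ys, k) for name, fit in candidates.items()}
    best = min(errors, key=errors.get)
    # ...but score new data with a model refit on ALL training data,
    # not with any of the k fold models.
    return best, errors, candidates[best](xs, ys)


if __name__ == "__main__":
    xs = [float(i) for i in range(30)]
    ys = [2.0 + 0.8 * x + (0.5 if i % 2 else -0.5) for i, x in enumerate(xs)]
    best, errors, final_model = choose_and_refit(
        {"mean": fit_mean, "linear": fit_linear}, xs, ys)
    print(best, errors)
    print("prediction for x=40:", final_model(40.0))
```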
I build on this diagram in another tip, Assessing Models by using k-fold Cross Validation in SAS Enterprise Miner, which shows how to obtain a 5-fold cross validation testing error, providing a more complete SAS Enterprise Miner flow.