The goal of this notebook is to train a classifier that predicts the “target” variable.
For this, the three essential steps are exploratory data analysis, model fitting and model evaluation, which also give this notebook its structure.
The goal of the exploratory data analysis is to pre-identify key metrics that help us predict the target. We will start with some basics on data structure and column values, then move on to looking at the distribution of the numeric columns. We do this first on a higher level, and then on a more granular level. Afterwards we will investigate the time component of the data. These steps are necessary to decide which model approach to use. The appendix contains an additional exploration of the time series component.
Before we dive into the analysis, let’s load necessary packages, fix some settings, add functions we’ll use and load the data.
#### install and load packages ####
libraries = c("knitr","lubridate", "rms", "markdown","mime","rmarkdown","tinytex","data.table","lattice","latticeExtra","Hmisc","DT","scales","ggplot2","forecast","rpart","rpart.plot","randomForest", "recipes", "caret", "mlbench", "themis", "ROSE", "mltools", "MLeval")
lapply(libraries, function(x) if (!(x %in% rownames(installed.packages()))) {install.packages(x, dependencies = TRUE, repos = "http://cran.us.r-project.org")} )
invisible(lapply(libraries, library, quietly = TRUE, character.only = TRUE))
## ##
#### settings ####
Sys.setenv(LANG = "en") # set environment language to English
Sys.setlocale("LC_TIME", "en_US.UTF-8") # set timestamp language to English
Sys.setlocale("LC_TIME", "English") # set timestamp language to English
## ##
#### load data ####
data <- fread("data.csv")
## ##
#### functions ####
plot_theme <- theme(panel.border = element_blank(),
                    axis.text = element_text(size = 20, face = "bold"),
                    axis.title = element_text(size = 24, face = "bold"),
                    strip.text = element_text(size = 20, face = "bold"),
                    plot.title = element_text(size = 24, face = "bold", hjust = .5),
                    panel.background = element_blank(), # bg of the panel
                    plot.background = element_blank(), # bg of the plot
                    panel.grid.major = element_blank(), # get rid of major grid
                    panel.grid.minor = element_blank(), # get rid of minor grid
                    legend.position = "none")
plot_theme_legend <- theme(panel.border = element_blank(),
                           axis.text = element_text(size = 20, face = "bold"),
                           axis.title = element_text(size = 24, face = "bold"),
                           strip.text = element_text(size = 20, face = "bold"),
                           plot.title = element_text(size = 24, face = "bold", hjust = .5),
                           panel.background = element_blank(), # bg of the panel
                           plot.background = element_blank(), # bg of the plot
                           panel.grid.major = element_blank(), # get rid of major grid
                           panel.grid.minor = element_blank(), # get rid of minor grid
                           legend.background = element_rect(fill = "transparent"), # get rid of legend bg
                           legend.box.background = element_rect(fill = "transparent")) # get rid of legend panel bg
## To get histograms for different variables and columns, let's wrap this up in a small function for the sake of readability ## used in {r histogram plots} ##
get_hist <- function(DATA, CATEGORY = "all", HIST_COL = NA){
  ## separate the cases of all categories vs. a specific one (this is mainly for aesthetic reasons)
  if (CATEGORY == "all"){
    # keep only the column of which we want to plot the histogram, plus the target
    dist_data <- DATA[, c(..HIST_COL, "target")]
    dist_data[, target := as.factor(target)] # make sure the target is a factor
    names(dist_data) <- c("hist_col", "Target") # rename to pass it to ggplot2 (avoids writing functions for passing the externally provided variable name)
    # plot the histogram
    plot_histo <- ggplot(dist_data, aes(x = hist_col, group = Target, fill = Target)) +
      geom_histogram(alpha = 0.7, position = "identity", color = "black") +
      plot_theme_legend +
      labs(colour = "Legend", x = "Value", y = "Count", title = paste("Category '", CATEGORY, "' | column '", HIST_COL, "'", sep = ""))
  } else if (CATEGORY %in% c("a","b","c")) {
    # subset the data to rows of the chosen category and keep the histogram column plus the target
    dist_data <- DATA[categorical0 %in% CATEGORY, c(..HIST_COL, "target")]
    dist_data[, target := as.factor(target)] # make sure the target is a factor
    names(dist_data) <- c("hist_col", "Target") # rename to pass it to ggplot2
    # plot the histogram
    plot_histo <- ggplot(dist_data, aes(x = hist_col, group = Target, fill = Target)) +
      geom_histogram(alpha = 0.7, position = "identity", color = "black") +
      plot_theme_legend +
      labs(colour = "Legend", x = "Value", y = "Count", title = paste("Category '", CATEGORY, "' | column '", HIST_COL, "'", sep = ""))
  } else {
    # just an error handler
    return("Error: invalid category, maybe a typo?")
  }
  return(plot_histo)
}
# For the sake of readability, let's write a short function to count the number of observations with target = 1 per time variable and print the output # used in {r create features} #
get_N_per_VAR <- function(VAR){
  N_per_var <- data[, .N, by = get(VAR)] # count number of observations per value of the passed variable VAR
  N_pos_per_var <- data[target == 1, .N, by = get(VAR)] # count target = 1 observations per value
  setkey(N_per_var, get); setkey(N_pos_per_var, get) # key both tables so their rows align (assumes every level has at least one target = 1 observation)
  N_per_var[, pct_target := round((N_pos_per_var[, N] / N), digits = 2)]
  names(N_per_var) <- c(VAR, "N", "pct_target") # make sure names are correct
  setkeyv(N_per_var, VAR) # sort by the passed variable VAR
  return(N_per_var)
}
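For illustration, a hypothetical call (the real calls live in the feature-creation chunk referenced above):
get_N_per_VAR("categorical0") # N and share of target = 1 observations per category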
create_train_test <- function(data, size = 0.8, train = TRUE) {
  n_row = nrow(data)
  total_row = floor(size * n_row) # index of the last training row
  train_sample <- 1:total_row
  if (train == TRUE) {
    return(data[train_sample, ])
  } else {
    return(data[-train_sample, ])
  }
}
As a first step, let’s just print out the data table to understand what we’re dealing with.
nrow(data) # show number of rows
## [1] 10000
datatable(head(data,(nrow(data)*0.01))) # print the first 1% of rows
Remember that we want to predict the variable “target”. This seems to be a time series problem, but the observations are unordered. In total we have 10,000 observations. Good to know, but it probably makes sense to order the data by date and time.
Also, we would probably like to plot the data, but as of now, we can’t judge what would ideally be plotted. Therefore, let’s have a look at its structure. This might look a bit overwhelming at first, but we want to understand which columns the data has, of which type they are, and what kind of information they contain.
setkey(data, date, time) # order by date & time
str(data) # print out structure of the data
## Classes 'data.table' and 'data.frame': 10000 obs. of 6 variables:
## $ date : IDate, format: "1970-01-01" "1970-01-02" ...
## $ numeric0 : int 7382 192 4291 4226 9735 5560 4821 762 3942 8473 ...
## $ numeric1 : int NA 6 NA 5 NA NA 1 7 6 7 ...
## $ categorical0: chr "c" "b" "b" "a" ...
## $ time : chr "13:13:57" "23:37:45" "22:39:45" "02:13:50" ...
## $ target : int 0 0 0 0 0 0 1 1 0 1 ...
## - attr(*, ".internal.selfref")=<externalptr>
## - attr(*, "sorted")= chr [1:2] "date" "time"
This just outputs each column’s type and prints the head of it. Also, we can see that the ordering by date & time is correct, from earliest to latest. Moreover, the time variable should be converted to a timestamp format if we want to plot the data.
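Since later code refers to a combined ‘date_time’ column (e.g. in the appendix plots), here is a minimal sketch of how it could be built; the actual construction lives in a chunk not shown here, and the format string is an assumption:
# combine 'date' and 'time' into a single POSIXct timestamp (assumes "HH:MM:SS" time strings) #
data[, date_time := as.POSIXct(paste(date, time), format = "%Y-%m-%d %H:%M:%S", tz = "UTC")]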
html(describe(data)) # print out a description of the columns of the data
**date**

n | missing | distinct | Info | Mean | Gmd | .05 | .10 | .25 | .50 | .75 | .90 | .95
---|---|---|---|---|---|---|---|---|---|---|---|---
10000 | 0 | 7664 | 1 | 1994-09-30 | 5994 | 1972-06-13 | 1974-12-22 | 1982-09-16 | 1994-10-04 | 2006-12-01 | 2014-06-14 | 2017-05-02

lowest: 1970-01-01, 1970-01-02, 1970-01-05, 1970-01-06, 1970-01-08
highest: 2019-08-13, 2019-08-17, 2019-08-18, 2019-08-22, 2019-08-24

**numeric0**

n | missing | distinct | Info | Mean | Gmd | .05 | .10 | .25 | .50 | .75 | .90 | .95
---|---|---|---|---|---|---|---|---|---|---|---|---
10000 | 0 | 6287 | 1 | 4982 | 3333 | 478 | 1012 | 2500 | 4940 | 7487 | 8996 | 9539

**numeric1**

n | missing | distinct | Info | Mean | Gmd | .05 | .10 | .25 | .50 | .75 | .90 | .95
---|---|---|---|---|---|---|---|---|---|---|---|---
5083 | 4917 | 10 | 0.99 | 4.54 | 3.29 | 0 | 1 | 2 | 5 | 7 | 9 | 9

Value | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9
---|---|---|---|---|---|---|---|---|---|---
Frequency | 484 | 484 | 541 | 490 | 522 | 516 | 526 | 468 | 527 | 525
Proportion | 0.095 | 0.095 | 0.106 | 0.096 | 0.103 | 0.102 | 0.103 | 0.092 | 0.104 | 0.103

**categorical0**

n | missing | distinct
---|---|---
10000 | 0 | 3

Value | a | b | c
---|---|---|---
Frequency | 3324 | 3341 | 3335
Proportion | 0.332 | 0.334 | 0.334

**time**

n | missing | distinct
---|---|---
10000 | 0 | 9455

**target**

n | missing | distinct | Info | Sum | Mean | Gmd
---|---|---|---|---|---|---
10000 | 0 | 2 | 0.346 | 1330 | 0.133 | 0.2306
The above outputs a description of each column along with some summary statistics.
Combined with the structure of the data there are a few things we can already observe:
- ‘numeric1’ is missing for roughly half of the observations (4,917 of 10,000),
- the three categories a, b and c are almost perfectly balanced,
- the target is heavily imbalanced: only about 13% of observations (1,330 of 10,000) are in state 1.
To get closer to identifying key features for predicting the target variable, we can plot the distributions of ‘numeric0’ and ‘numeric1’, separated by target values. It probably makes sense to also plot the data separated by the categories a, b, c.
The goal of plotting now is to understand the hierarchy of the data, because that’s going to help us come up with a good, but also explainable approach. At this point we could probably already set up a black-box model and make predictions, but that won’t help our business stakeholders understand what’s going on.
Now that we’ve set up everything to loop over the different combinations, we can print out the histograms. Note that the way we plot, splitting first by category and then by numeric column, already resembles a decision tree. In that sense, a tree-based approach could do the job, as we’re separating the data by categories and numeric columns.
The reason why we’re doing this separation as a first step is that there are some dangers in the data, most notably the heavy class imbalance and the many missing values in ‘numeric1’.
To have more control over what we’re going to model later, we will thus get a quick overview and see if we can already find some patterns.
## specify values of interest so that we can loop over them ##
unique_category <- c(data[,sort(unique(categorical0))]) # unique values of "categorical0" column
hist_cols <- c("numeric0", "numeric1") # the numeric columns
plots <- lapply(hist_cols, function(i) {
  get_hist(data, CATEGORY = "all", HIST_COL = i)
})
for (i in seq_along(plots)) {
  cat("#### Histogram", i, "\n")
  print(plots[[i]])
  cat('\n\n')
}
The plots show the distribution of values by ‘target’ state and per numerical column.
To illustrate the difference between categories, we will look at the number of values with state = 1 per category. Recall that every category has ~3300 observations.
datatable(data[target==1,.N, by = "categorical0"])
Clearly, if an observation is of category a, it is much more likely to be in state 1 than one of category b or c.
But are their characteristics significantly different?
hists <- lapply(hist_cols, function(i) {
  lapply(unique_category, function(cat) {
    get_hist(data, CATEGORY = cat, HIST_COL = i)
  })
})
for (i in seq_along(hists)) {
  for (j in seq_along(hists[[i]])) {
    cat("### Histogram of column", hist_cols[i], "|", unique_category[j], "\n")
    print(hists[[i]][[j]])
    cat('\n\n')
  }
}
The plots show the distribution of values by ‘target’ state and per numerical column, separated by categories.
The large class imbalance is a problem we will have to address. We could e.g. address it by:
- upsampling the minority class,
- downsampling the majority class, or
- generating synthetic observations with algorithms such as SMOTE or ROSE.
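As a minimal sketch of the first two options (caret is loaded above; the modelling sections below rely on caret’s built-in sampling = "up" instead):
# illustrative only: balance the classes by resampling #
up_data <- upSample(x = data[, !"target"], y = as.factor(data$target)) # duplicate minority rows
down_data <- downSample(x = data[, !"target"], y = as.factor(data$target)) # drop majority rows
table(up_data$Class); table(down_data$Class) # both are now balanced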
We will now turn to investigating the time component to see how relevant it is. For this, we will have to create time series objects. First, let’s check how many gaps the data has.
# create a series from start date to end date to capture all possible dates #
start <- data[,min(date)]
end <- data[,max(date)]
start_to_end <- data.table("date" = seq(start,end,by = "1 day"))
# match it with our data and preallocate a column for missing dates #
start_to_end[data, target := i.target, on = "date"][, date_missing := 0]
start_to_end[is.na(target), date_missing := 1] # set the missing date column to '1' where the target value is NA
# tabulate the days where target is NA (i.e. days without an observation) #
datatable(start_to_end[,.N,by = "date_missing"])
The above table shows how many dates are missing: date_missing = 1 indicates a day without any observation.
More often than not, we have gaps in the observations. There are two ways to approach this problem:
- treat the data as a genuine time series and fill the gaps (e.g. by interpolation or aggregation) so that classical time series methods apply, or
- treat each observation as an independent row and extract time-based features from date and time.
In our case, it probably makes more sense to go with the second approach, as our goal is to predict the state of any row given the information we have about different features (see the sketch below). We do this because we’re hoping to find some time-related patterns.
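The later modelling code expects time-based feature columns (year, quarter, month, week, weekday, hour), which are built in the feature-creation chunk referenced earlier. A minimal sketch of how they could be derived (data.table and lubridate are loaded above; the exact implementation may differ):
# derive time-based features from the date and time columns #
data[, `:=`(year = year(date),
            quarter = quarter(date),
            month = month(date),
            week = isoweek(date),
            weekday = wday(date),
            hour = hour(as.ITime(time)))]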
We want to predict the target state of each observation based on date, time, category and two numerical columns.
A simple and often useful approach is to classify observations based on a decision tree. It is a great starting point, because it’s a simple, intuitive and transparent method.
With the insights gained, fitting a Decision Tree is now straightforward. We need to perform some data wrangling, account for the large difference in target vs non-target values, and split the data into a train and test set before we can fit a tree. We impute missing values with the mean of the data. When imputing, one should generally be careful, as this technique is only consistent when missing values are not informative (Josse et al. 2020).
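If missingness were suspected to be informative, one lightweight safeguard is an explicit indicator column; a sketch (on a copy, not used in the pipeline below):
# illustrative sketch: flag missing 'numeric1' values before imputing #
imp_demo <- copy(data)
imp_demo[, numeric1_missing := as.integer(is.na(numeric1))]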
## For fitting the tree remove date / time columns as they yield no information ##
tree_data <- data[,!c("date", "time", "date_time")]
# Also, we impute missing values by the mean of the data #
mean_numeric_1 <- tree_data[,mean(numeric1, na.rm = TRUE)]
tree_data[is.na(numeric1), numeric1 := mean_numeric_1]
## we have already seen that we have way more 0 than 1 target values ##
# Therefore, we count the number of values by state #
DT_balance <- tree_data[,.N,by = "target"]
# Next, we will calculate the ratio of target vs non-target values, which can then be passed as priors to the tree #
N_target_0 <- DT_balance[target == 0, N] / nrow(tree_data)
N_target_1 <- 1 - N_target_0
weights <- c(N_target_1, N_target_0) # inverted priors: the minority class gets the larger prior to counter the imbalance
## Before we fit the model, we need to do a split between train and test data ##
## When there is no time component involved, it makes sense to take random samples ##
## However, in our case, we will split at 80% of the data and predict the last 20% ##
tree_data_train <- create_train_test(tree_data, 0.8, train = TRUE)
tree_data_test <- create_train_test(tree_data, 0.8, train = FALSE)
## Finally, we can fit the tree ##
tree_fit <- rpart(target ~ .,
                  data = tree_data_train,
                  parms = list(prior = weights), # priors due to class imbalance
                  method = "class") # method = "class" because target is either 0 or 1
tree_plot <- rpart.plot(tree_fit,
                        type = 4,
                        extra = 106)
The above plot shows the Decision Tree fitted on our training data. Given our exploratory analysis, this holds few surprises. The tree separates first by category, and then by ‘numeric0’ value. The tree did not consider other variables such as ‘numeric1’ or the time variables important enough.
Intuitively, the tree can be interpreted as a small set of if-then rules: observations in category a are sent towards state 1, while the remaining observations are split on their ‘numeric0’ value.
Based on these rules we will evaluate the predictive power of this model. Since our focus is on a balance between predictive accuracy and simplicity, the few rules should not be an issue for us.
To measure the accuracy, we will evaluate how many True Positives (TP), True Negatives (TN), False Positives (FP) and False Negatives (FN) we have. This can be depicted in a so-called confusion matrix.
Based on the ratio \(\frac{TP + TN}{TP + TN + FP + FN}\) we can then calculate the accuracy.
predict_unseen <- predict(tree_fit, tree_data_test, type = 'class')
# note: confusionMatrix(data, reference) treats its first argument as the predictions; #
# since the observed targets are passed first here, the printed "Prediction" rows correspond to the actual classes #
confusionMatrix(tree_data_test[,as.factor(target)], predict_unseen)
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 1243 505
## 1 6 246
##
## Accuracy : 0.7445
## 95% CI : (0.7248, 0.7635)
## No Information Rate : 0.6245
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.372
##
## Mcnemar's Test P-Value : < 2.2e-16
##
## Sensitivity : 0.9952
## Specificity : 0.3276
## Pos Pred Value : 0.7111
## Neg Pred Value : 0.9762
## Prevalence : 0.6245
## Detection Rate : 0.6215
## Detection Prevalence : 0.8740
## Balanced Accuracy : 0.6614
##
## 'Positive' Class : 0
##
The output shows the confusion matrix and the accuracy calculation, e.g. accuracy = (1243 + 246) / 2000 ≈ 0.74. Each row of the table corresponds to one class of the first argument (the observed targets) and each column to one class of the second (the predictions).
Again, this holds few surprises:
We learned from the distribution of values that observations in state 0 are mostly large and mostly in categories b & c. Therefore, it is easy to be accurate here: less than 1% of the values predicted as state 0 are misclassified.
We also saw that it’s much harder to identify state 1 observations, since they’re mainly small, but not always. Equally, state 0 observations are mainly large, but not always. The tree decided that if an observation is in category a, its correct label is most likely state = 1. However, only 1/3 of observations in category a are actually in state 1. Consequently, only roughly 1/3 of the values predicted as state 1 are correctly classified.
While we have an overall satisfactory accuracy due to the large number of correct classifications in the majority group state 0, we can try to improve accuracy and reduce the number of misclassifications by increasing model complexity. A typical extension of Decision Tree models is the Random Forest algorithm. While it is more of a black-box approach than a Decision Tree, we hope that it offers some additional insights through feature importance and that the additional complexity increases our model accuracy.
The basic idea of a Random Forest is to generate many Decision Trees by randomly drawing from the original data (bootstrapping) and using only a subset of the available features at each split. A committee of many classifiers then makes a majority vote on the class of each observation. This reduces the variance of the predictions, by averaging over many trees and decorrelating them through the feature subsetting.
# set random seed for reproducing results #
set.seed(1234)
rf_data <- copy(tree_data)
# recode the target levels: caret requires class levels that are valid R variable names, so plain 0 / 1 values do not work
rf_data[,target := as.character(target)]
rf_data[target %in% "0",target := "A"]
rf_data[target %in% "1",target := "B"]
rf_data[,target := as.factor((target))]
# create folds #
cv_folds <- createFolds(rf_data[,target], k = 5, returnTrain = TRUE)
# create tune control with upsampling for handling class imbalance and 5-fold cross validation during training #
tuneGrid <- expand.grid(.mtry = c(1 : 10))
ctrl <- trainControl(method = "cv",
                     number = 5,
                     search = 'grid',
                     classProbs = TRUE,
                     savePredictions = "final",
                     index = cv_folds,
                     summaryFunction = twoClassSummary,
                     sampling = "up")
# specify tuning parameters #
ntrees <- c(100,500,1000)
nodesize <- c(1,5,10)
params <- expand.grid(ntrees = ntrees,
                      nodesize = nodesize)
# train the model in a grid search #
# this may take a while; in a real-world scenario #
# we would ideally want to move this into a cloud environment #
# additionally, this could be sped up using parallelization #
store_maxnode <- vector("list", nrow(params))
for(i in 1:nrow(params)){
  nodesize <- params[i, 2]
  ntree <- params[i, 1]
  set.seed(123)
  rf_model <- train(target ~ .,
                    data = rf_data,
                    method = "rf",
                    importance = TRUE,
                    metric = "ROC",
                    tuneGrid = tuneGrid,
                    trControl = ctrl,
                    ntree = ntree,
                    nodesize = nodesize)
  store_maxnode[[i]] <- rf_model
}
# get unique names for experiments #
names(store_maxnode) <- paste("ntrees:", params$ntrees,
"nodesize:", params$nodesize)
# combine results and print output #
results_mtry <- resamples(store_maxnode)
summary(results_mtry)
##
## Call:
## summary.resamples(object = results_mtry)
##
## Models: ntrees: 100 nodesize: 1, ntrees: 500 nodesize: 1, ntrees: 1000 nodesize: 1, ntrees: 100 nodesize: 5, ntrees: 500 nodesize: 5, ntrees: 1000 nodesize: 5, ntrees: 100 nodesize: 10, ntrees: 500 nodesize: 10, ntrees: 1000 nodesize: 10
## Number of resamples: 5
##
## ROC
## Min. 1st Qu. Median Mean 3rd Qu.
## ntrees: 100 nodesize: 1 0.8632676 0.8678769 0.8726693 0.8703615 0.8731864
## ntrees: 500 nodesize: 1 0.8612893 0.8704373 0.8730542 0.8708334 0.8734314
## ntrees: 1000 nodesize: 1 0.8614714 0.8708807 0.8734184 0.8710650 0.8734975
## ntrees: 100 nodesize: 5 0.8643159 0.8670346 0.8707896 0.8697045 0.8715962
## ntrees: 500 nodesize: 5 0.8638855 0.8702552 0.8703734 0.8705386 0.8732829
## ntrees: 1000 nodesize: 5 0.8632351 0.8698617 0.8719149 0.8705557 0.8719355
## ntrees: 100 nodesize: 10 0.8593944 0.8692341 0.8716103 0.8703070 0.8752016
## ntrees: 500 nodesize: 10 0.8637229 0.8678173 0.8719788 0.8704668 0.8720991
## ntrees: 1000 nodesize: 10 0.8614594 0.8677273 0.8730921 0.8701071 0.8733165
## Max. NA's
## ntrees: 100 nodesize: 1 0.8748070 0
## ntrees: 500 nodesize: 1 0.8759550 0
## ntrees: 1000 nodesize: 1 0.8760569 0
## ntrees: 100 nodesize: 5 0.8747864 0
## ntrees: 500 nodesize: 5 0.8748959 0
## ntrees: 1000 nodesize: 5 0.8758314 0
## ntrees: 100 nodesize: 10 0.8760949 0
## ntrees: 500 nodesize: 10 0.8767160 0
## ntrees: 1000 nodesize: 10 0.8749404 0
##
## Sens
## Min. 1st Qu. Median Mean 3rd Qu.
## ntrees: 100 nodesize: 1 0.9377163 0.9411765 0.9417532 0.9429066 0.9434833
## ntrees: 500 nodesize: 1 0.9106113 0.9215686 0.9227220 0.9235294 0.9250288
## ntrees: 1000 nodesize: 1 0.9106113 0.9163783 0.9215686 0.9228374 0.9250288
## ntrees: 100 nodesize: 5 0.8477509 0.8587082 0.8615917 0.8638985 0.8731257
## ntrees: 500 nodesize: 5 0.8615917 0.8633218 0.8644752 0.8658593 0.8696655
## ntrees: 1000 nodesize: 5 0.9354095 0.9365629 0.9371396 0.9395617 0.9423299
## ntrees: 100 nodesize: 10 0.8748558 0.8794694 0.8823529 0.8844291 0.8875433
## ntrees: 500 nodesize: 10 0.9129181 0.9140715 0.9158016 0.9166090 0.9181084
## ntrees: 1000 nodesize: 10 0.9025375 0.9059977 0.9083045 0.9077278 0.9088812
## Max. NA's
## ntrees: 100 nodesize: 1 0.9504037 0
## ntrees: 500 nodesize: 1 0.9377163 0
## ntrees: 1000 nodesize: 1 0.9405998 0
## ntrees: 100 nodesize: 5 0.8783160 0
## ntrees: 500 nodesize: 5 0.8702422 0
## ntrees: 1000 nodesize: 5 0.9463668 0
## ntrees: 100 nodesize: 10 0.8979239 0
## ntrees: 500 nodesize: 10 0.9221453 0
## ntrees: 1000 nodesize: 10 0.9129181 0
##
## Spec
## Min. 1st Qu. Median Mean 3rd Qu.
## ntrees: 100 nodesize: 1 0.3308271 0.3421053 0.3571429 0.3631579 0.3834586
## ntrees: 500 nodesize: 1 0.3721805 0.3759398 0.4097744 0.4105263 0.4323308
## ntrees: 1000 nodesize: 1 0.3721805 0.3834586 0.4210526 0.4165414 0.4360902
## ntrees: 100 nodesize: 5 0.5225564 0.5413534 0.5714286 0.5616541 0.5789474
## ntrees: 500 nodesize: 5 0.5413534 0.5451128 0.5526316 0.5556391 0.5639098
## ntrees: 1000 nodesize: 5 0.3458647 0.3533835 0.3721805 0.3744361 0.3947368
## ntrees: 100 nodesize: 10 0.4699248 0.4924812 0.4924812 0.5105263 0.5112782
## ntrees: 500 nodesize: 10 0.3947368 0.4097744 0.4285714 0.4270677 0.4323308
## ntrees: 1000 nodesize: 10 0.4248120 0.4398496 0.4398496 0.4571429 0.4661654
## Max. NA's
## ntrees: 100 nodesize: 1 0.4022556 0
## ntrees: 500 nodesize: 1 0.4624060 0
## ntrees: 1000 nodesize: 1 0.4699248 0
## ntrees: 100 nodesize: 5 0.5939850 0
## ntrees: 500 nodesize: 5 0.5751880 0
## ntrees: 1000 nodesize: 5 0.4060150 0
## ntrees: 100 nodesize: 10 0.5864662 0
## ntrees: 500 nodesize: 10 0.4699248 0
## ntrees: 1000 nodesize: 10 0.5150376 0
This code chunk shows the tuning of the RF model and some performance metrics. We addressed the class imbalance by upsampling the minority class. Other approaches would be e.g. downsampling the majority class or generating synthetic observations with algorithms such as “SMOTE” or “ROSE”. In general, there is no consensus on which method is best practice; which approach should be preferred varies from use case to use case. We use 5-fold cross-validation to evaluate the models. When using cross-validation, we repeatedly divide the data into folds, where part of the observations is withheld during training and then used as test data. Note that in this case, we do not additionally make a train-test split. Diebold (2015) argues that (pseudo-)out-of-sample approaches are consistent only if the withheld data is asymptotically irrelevant; i.e. in small-data cases, full-sample fitting is preferable. Instead of using accuracy as a performance metric, we turned to the area under the Receiver Operating Characteristic (ROC) curve, which is often a better way to evaluate the predictive power of a model: it gives low scores both to random classifiers and to classifiers that always predict a single class. Additionally, we measure Sensitivity (True Positive Rate) and Specificity (True Negative Rate), which indicate the ability of a classifier to detect positive (negative) examples:
\(Sensitivity = \frac{TP}{AP}, \qquad Specificity = \frac{TN}{AN}\), where AP = All Positives and AN = All Negatives.
The results show that the model with ntrees=500 and nodesize=5 maintains a balance between sensitivity and specificity while having a good ROC value. In general, the ROC values of the models do not differ much, but sensitivity and specificity vary considerably. As expected, these values improve with a higher number of trees. The biggest model with ntrees=1000, however, does not perform significantly better than the one with ntrees=500.
Note that Random Forests are not prone to overfitting from adding trees: the accuracy converges after a certain number of trees, and additional trees provide little to no additional predictive power.
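To see this convergence, we can inspect the out-of-bag error of the underlying randomForest fit, which caret stores in the finalModel slot (a quick sketch; the error curve flattens as trees are added):
# plot the OOB error against the number of trees for the chosen model #
plot(store_maxnode$`ntrees: 500 nodesize: 5`$finalModel, main = "OOB error vs. number of trees")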
# get variable importance #
plot(varImp(store_maxnode$`ntrees: 500 nodesize: 5`))
The feature importance plot confirms that the category and the ‘numeric0’ column are our main features of interest. Other variables only play a minor role. However, the meaning of this plot should not be overestimated, as little is known about its theoretical properties (Scornet, Biau, and Vert 2015).
# get performance metrics #
fit_eval <- evalm(store_maxnode$`ntrees: 500 nodesize: 5`, silent = TRUE, showplots = FALSE)
confusionMatrix(store_maxnode$`ntrees: 500 nodesize: 5`$pred$pred,store_maxnode$`ntrees: 500 nodesize: 5`$pred$obs)
## Confusion Matrix and Statistics
##
## Reference
## Prediction A B
## A 7507 591
## B 1163 739
##
## Accuracy : 0.8246
## 95% CI : (0.817, 0.832)
## No Information Rate : 0.867
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.3566
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.8659
## Specificity : 0.5556
## Pos Pred Value : 0.9270
## Neg Pred Value : 0.3885
## Prevalence : 0.8670
## Detection Rate : 0.7507
## Detection Prevalence : 0.8098
## Balanced Accuracy : 0.7107
##
## 'Positive' Class : A
##
## get roc curve plotted in ggplot2
fit_eval$roc
The above plot shows the ROC curve, as well as the area under the ROC curve (AUC). It visualizes the trade-off in any classifier between True Positives and False Positives; a perfect classifier would sit in the upper left corner. The \(AUC\) score ranges from 0 to 1, where 1 is again a perfect classifier. The obtained score of 0.87 indicates that the classifier’s predictive power is satisfactory. Finally, as work on the consistency and theoretical properties of the RF estimator is still at an early stage (Scornet, Biau, and Vert 2015), any performance metric can only be an indicator of a good model fit; the goal of using the ROC curve is not to obtain a perfect model, but one that serves our purpose well enough. We left the probability threshold at the default of 0.5. This value can be varied to tune the algorithm towards either sensitivity or specificity and thus also helps us tackle the class imbalance.
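As a minimal sketch of varying this threshold, using the cross-validated class probabilities that caret stored via savePredictions (the threshold values are arbitrary examples):
# recompute sensitivity / specificity for a few candidate thresholds #
rf_preds <- store_maxnode$`ntrees: 500 nodesize: 5`$pred
for (thresh in c(0.3, 0.5, 0.7)) {
  pred_class <- factor(ifelse(rf_preds$B > thresh, "B", "A"), levels = c("A", "B"))
  cm <- confusionMatrix(pred_class, rf_preds$obs)
  cat("threshold:", thresh,
      "| Sens:", round(cm$byClass["Sensitivity"], 3),
      "| Spec:", round(cm$byClass["Specificity"], 3), "\n")
}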
While we do not see a great improvement in terms of accuracy compared with the baseline Decision Tree, we have become better at predicting values in the minority class. The Random Forest model clearly performs better, but does not yet give us the desired outcome.
Finally, we will use XGBoost for making predictions. The concept of boosting is at first glance very similar to that of Random Forests (or bagging in general): a committee of “weak” classifiers (barely better than chance) is combined to make a majority vote on the predicted class of each observation. In contrast to Random Forests, this is however an iterative procedure. While iterating, the algorithm emphasizes misclassified observations in order to learn difficult patterns. Boosting has proven to be one of the most powerful classifiers of the last decade (Hastie, Tibshirani, and Friedman 2009). XGBoost is one of the most recent implementations, and we will use it in this example.
# make the train data set with one-hot encoding (xgboost only accepts numerical values) #
train <- tree_data[, !"target"]
cols <- c("categorical0", "year", "quarter", "month", "week", "weekday", "hour")
train[, (cols) := lapply(.SD, factor), .SDcols = cols]
train <- one_hot(train)
colnames_train <- names(train)
train <- matrix(as.numeric(unlist(train)), nrow = nrow(train))
colnames(train) <- colnames_train
# control parameters #
ctrl_xgb <- trainControl(method = "cv",
                         number = 5,
                         search = 'grid',
                         classProbs = TRUE,
                         savePredictions = "final",
                         index = cv_folds,
                         summaryFunction = twoClassSummary)
# calculate weights for observations #
xgb_weights <- ifelse(rf_data[,target] == "B",
                      table(rf_data[,target])[1] / nrow(rf_data),
                      table(rf_data[,target])[2] / nrow(rf_data))
# get tuning grid #
tuneGrid_xgb <- expand.grid(.nrounds = c(100, 250, 500),
                            .max_depth = c(1, 3, 6),
                            .eta = c(0.01, 0.025, 0.1, 0.3),
                            .gamma = c(3),
                            .colsample_bytree = c(0.6, 0.8, 1),
                            .subsample = c(0.75),
                            .min_child_weight = c(1))
# fit the model #
# this again may take a while #
xgb_model <- train(x = train,
                   y = rf_data[,target],
                   method = "xgbTree",
                   trControl = ctrl_xgb,
                   tuneGrid = tuneGrid_xgb,
                   weights = xgb_weights,
                   verbose = TRUE,
                   metric = "ROC",
                   verbosity = 0,
                   allowParallel = TRUE)
xgb_model
## eXtreme Gradient Boosting
##
## 10000 samples
## 150 predictor
## 2 classes: 'A', 'B'
##
## No pre-processing
## Resampling: Cross-Validated (5 fold)
## Summary of sample sizes: 8000, 8000, 8000, 8000, 8000
## Resampling results across tuning parameters:
##
## eta max_depth colsample_bytree nrounds ROC Sens Spec
## 0.010 1 0.6 100 0.8704152 0.7462514 0.8451128
## 0.010 1 0.6 250 0.8709397 0.7305652 0.8804511
## 0.010 1 0.6 500 0.8715398 0.6936563 0.9789474
## 0.010 1 0.8 100 0.8695285 0.7462514 0.8451128
## 0.010 1 0.8 250 0.8693574 0.7162630 0.9240602
## 0.010 1 0.8 500 0.8694910 0.6936563 0.9789474
## 0.010 1 1.0 100 0.8696196 0.7462514 0.8451128
## 0.010 1 1.0 250 0.8698491 0.7056517 0.9458647
## 0.010 1 1.0 500 0.8695222 0.6936563 0.9789474
## 0.010 3 0.6 100 0.8695970 0.6931949 0.9812030
## 0.010 3 0.6 250 0.8692070 0.6931949 0.9812030
## 0.010 3 0.6 500 0.8660559 0.6933103 0.9812030
## 0.010 3 0.8 100 0.8688544 0.6930796 0.9819549
## 0.010 3 0.8 250 0.8679259 0.6933103 0.9819549
## 0.010 3 0.8 500 0.8700150 0.6935409 0.9812030
## 0.010 3 1.0 100 0.8690654 0.6930796 0.9812030
## 0.010 3 1.0 250 0.8694591 0.6933103 0.9812030
## 0.010 3 1.0 500 0.8686901 0.6933103 0.9812030
## 0.010 6 0.6 100 0.8703654 0.6940023 0.9789474
## 0.010 6 0.6 250 0.8684646 0.6933103 0.9812030
## 0.010 6 0.6 500 0.8657613 0.6936563 0.9812030
## 0.010 6 0.8 100 0.8694049 0.6931949 0.9812030
## 0.010 6 0.8 250 0.8674959 0.6931949 0.9812030
## 0.010 6 0.8 500 0.8688137 0.6936563 0.9804511
## 0.010 6 1.0 100 0.8685780 0.6930796 0.9812030
## 0.010 6 1.0 250 0.8687718 0.6933103 0.9812030
## 0.010 6 1.0 500 0.8687350 0.6933103 0.9812030
## 0.025 1 0.6 100 0.8690888 0.7200692 0.9127820
## 0.025 1 0.6 250 0.8697169 0.6936563 0.9789474
## 0.025 1 0.6 500 0.8692694 0.6936563 0.9781955
## 0.025 1 0.8 100 0.8701934 0.7095732 0.9300752
## 0.025 1 0.8 250 0.8687575 0.6936563 0.9789474
## 0.025 1 0.8 500 0.8686849 0.6937716 0.9781955
## 0.025 1 1.0 100 0.8698509 0.7159170 0.9157895
## 0.025 1 1.0 250 0.8691610 0.6936563 0.9781955
## 0.025 1 1.0 500 0.8690472 0.6936563 0.9781955
## 0.025 3 0.6 100 0.8703328 0.6933103 0.9812030
## 0.025 3 0.6 250 0.8695667 0.6936563 0.9804511
## 0.025 3 0.6 500 0.8702314 0.6936563 0.9804511
## 0.025 3 0.8 100 0.8664986 0.6933103 0.9812030
## 0.025 3 0.8 250 0.8678270 0.6933103 0.9812030
## 0.025 3 0.8 500 0.8690804 0.6936563 0.9812030
## 0.025 3 1.0 100 0.8722219 0.6933103 0.9812030
## 0.025 3 1.0 250 0.8694999 0.6933103 0.9812030
## 0.025 3 1.0 500 0.8693635 0.6933103 0.9812030
## 0.025 6 0.6 100 0.8694747 0.6933103 0.9827068
## 0.025 6 0.6 250 0.8672995 0.6936563 0.9796992
## 0.025 6 0.6 500 0.8688499 0.6936563 0.9789474
## 0.025 6 0.8 100 0.8677550 0.6933103 0.9812030
## 0.025 6 0.8 250 0.8686014 0.6936563 0.9812030
## 0.025 6 0.8 500 0.8693858 0.6936563 0.9796992
## 0.025 6 1.0 100 0.8710372 0.6931949 0.9812030
## 0.025 6 1.0 250 0.8689791 0.6933103 0.9812030
## 0.025 6 1.0 500 0.8682220 0.6933103 0.9812030
## 0.100 1 0.6 100 0.8689509 0.6941176 0.9789474
## 0.100 1 0.6 250 0.8701243 0.6937716 0.9789474
## 0.100 1 0.6 500 0.8709928 0.6937716 0.9789474
## 0.100 1 0.8 100 0.8695465 0.6936563 0.9789474
## 0.100 1 0.8 250 0.8701065 0.6936563 0.9789474
## 0.100 1 0.8 500 0.8701876 0.6937716 0.9781955
## 0.100 1 1.0 100 0.8689936 0.6936563 0.9781955
## 0.100 1 1.0 250 0.8699547 0.6937716 0.9781955
## 0.100 1 1.0 500 0.8701167 0.6937716 0.9781955
## 0.100 3 0.6 100 0.8693401 0.6936563 0.9804511
## 0.100 3 0.6 250 0.8696302 0.6936563 0.9804511
## 0.100 3 0.6 500 0.8695521 0.6937716 0.9804511
## 0.100 3 0.8 100 0.8686034 0.6936563 0.9804511
## 0.100 3 0.8 250 0.8712959 0.6936563 0.9789474
## 0.100 3 0.8 500 0.8714173 0.6936563 0.9812030
## 0.100 3 1.0 100 0.8697301 0.6936563 0.9804511
## 0.100 3 1.0 250 0.8699417 0.6936563 0.9804511
## 0.100 3 1.0 500 0.8705210 0.6936563 0.9804511
## 0.100 6 0.6 100 0.8702051 0.6940023 0.9804511
## 0.100 6 0.6 250 0.8708211 0.6940023 0.9789474
## 0.100 6 0.6 500 0.8707918 0.6942330 0.9789474
## 0.100 6 0.8 100 0.8687333 0.6935409 0.9804511
## 0.100 6 0.8 250 0.8705113 0.6936563 0.9804511
## 0.100 6 0.8 500 0.8716896 0.6938870 0.9796992
## 0.100 6 1.0 100 0.8687950 0.6935409 0.9796992
## 0.100 6 1.0 250 0.8686741 0.6936563 0.9796992
## 0.100 6 1.0 500 0.8688312 0.6941176 0.9774436
## 0.300 1 0.6 100 0.8693366 0.6938870 0.9781955
## 0.300 1 0.6 250 0.8679428 0.6940023 0.9766917
## 0.300 1 0.6 500 0.8680460 0.6940023 0.9781955
## 0.300 1 0.8 100 0.8710008 0.6941176 0.9774436
## 0.300 1 0.8 250 0.8702179 0.6937716 0.9781955
## 0.300 1 0.8 500 0.8705726 0.6943483 0.9751880
## 0.300 1 1.0 100 0.8706130 0.6935409 0.9789474
## 0.300 1 1.0 250 0.8717074 0.6935409 0.9789474
## 0.300 1 1.0 500 0.8690201 0.6943483 0.9759398
## 0.300 3 0.6 100 0.8714459 0.6936563 0.9789474
## 0.300 3 0.6 250 0.8722904 0.6942330 0.9789474
## 0.300 3 0.6 500 0.8720484 0.6945790 0.9774436
## 0.300 3 0.8 100 0.8698567 0.6943483 0.9774436
## 0.300 3 0.8 250 0.8707717 0.6943483 0.9789474
## 0.300 3 0.8 500 0.8718388 0.6948097 0.9759398
## 0.300 3 1.0 100 0.8682116 0.6933103 0.9819549
## 0.300 3 1.0 250 0.8696488 0.6940023 0.9796992
## 0.300 3 1.0 500 0.8711456 0.6948097 0.9796992
## 0.300 6 0.6 100 0.8709646 0.6942330 0.9804511
## 0.300 6 0.6 250 0.8709930 0.6948097 0.9751880
## 0.300 6 0.6 500 0.8694216 0.6968858 0.9669173
## 0.300 6 0.8 100 0.8714205 0.6946943 0.9759398
## 0.300 6 0.8 250 0.8700263 0.6963091 0.9691729
## 0.300 6 0.8 500 0.8687794 0.7002307 0.9548872
## 0.300 6 1.0 100 0.8710843 0.6953864 0.9766917
## 0.300 6 1.0 250 0.8711346 0.6967705 0.9744361
## 0.300 6 1.0 500 0.8689589 0.7000000 0.9646617
##
## Tuning parameter 'gamma' was held constant at a value of 3
## Tuning
## parameter 'min_child_weight' was held constant at a value of 1
##
## Tuning parameter 'subsample' was held constant at a value of 0.75
## ROC was used to select the optimal model using the largest value.
## The final values used for the model were nrounds = 250, max_depth = 3, eta
## = 0.3, gamma = 3, colsample_bytree = 0.6, min_child_weight = 1 and subsample
## = 0.75.
Looking at the results, we see that the AUC value is roughly the same as that of the Random Forest model. We picked some hyperparameters to tune that work well in practice and help us avoid overfitting: the number of boosting rounds (nrounds), the tree depth (max_depth), the learning rate (eta) and the column subsampling ratio (colsample_bytree), while gamma, subsample and min_child_weight were held constant.
We identify the optimal parameter combination by performing a grid search, as we did for the Random Forest algorithm. The most influential parameter in XGBoost is the learning rate (eta). It is usually chosen in the region of 0.1-0.3, but smaller or larger values can be used depending on the use case.
# get performance metrics #
xgb_eval <- evalm(xgb_model, silent = TRUE, showplots = FALSE)
confusionMatrix(xgb_model$pred$pred,xgb_model$pred$obs)
## Confusion Matrix and Statistics
##
## Reference
## Prediction A B
## A 6019 28
## B 2651 1302
##
## Accuracy : 0.7321
## 95% CI : (0.7233, 0.7408)
## No Information Rate : 0.867
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.3669
##
## Mcnemar's Test P-Value : <2e-16
##
## Sensitivity : 0.6942
## Specificity : 0.9789
## Pos Pred Value : 0.9954
## Neg Pred Value : 0.3294
## Prevalence : 0.8670
## Detection Rate : 0.6019
## Detection Prevalence : 0.6047
## Balanced Accuracy : 0.8366
##
## 'Positive' Class : A
##
## get roc curve plotted in ggplot2
xgb_eval$roc
The confusion matrix shows that the boosting algorithm is capable of making powerful predictions. We managed to strongly increase Specificity while losing some Sensitivity. This means that the classifier has become much better at predicting values from the minority class, at the cost of losing predictive power in the majority class. Looking at the ROC curve, we find that it is essentially unchanged. The results indicate that XGBoost performs best among the tested algorithms in predicting minority class values, but isn’t necessarily the best classifier in case we’re interested in predicting the majority class.
The goal of this analysis was to predict whether a value is target / non-target based on numeric values, category, date & time. While this seemed like a time series problem at first, any time series method would be overly complicated in comparison to extracting time-based features.
The exploratory analysis revealed that the category and the column ‘numeric0’ were expected to be our main predictors. We expected limited information from the time-based variables and the column ‘numeric1’, due to limited variation between categories (time-based variables & ‘numeric1’) and the large number of missing values (‘numeric1’). Additionally, the analysis revealed a major class imbalance: only ca. 13% of observations were of target state = 1. This is a problem for classification, as we would ideally want classes to be balanced.
Tree-based models are an intuitive choice in many business problems, as they are explainable and simple. Therefore, we fitted a Decision Tree, which classified observations based on category and ‘numeric0’ value. While it was moderately accurate, it was heavily biased towards the majority state. To address this issue, we increased model complexity and tested whether the Random Forest algorithm could fix it. The feature importance plot confirmed that mainly the categories and ‘numeric0’ are relevant for classifying the target value. The area under the ROC curve suggests a satisfactory model fit; however, we had to find a balance between sensitivity and specificity. While better than a Decision Tree, the predictive power for the minority class was still barely higher than a random guess. XGBoost yielded the best performance in predicting the minority class, but not necessarily when predicting the majority class. While it did not increase the Accuracy or the AUC by much, it had a significantly higher Specificity while also maintaining a satisfactory Sensitivity value.
The tested models can only serve as a baseline for more sophisticated approaches; their predictive power can still be improved through further hyperparameter tuning, adjusting the probability threshold for classification, comparing methods for missing-data imputation, and comparing different sampling approaches to address the class imbalance. Furthermore, parameter / model changes usually represent a trade-off between false positives and false negatives. In any real-world use case it should be decided which of these is less costly. Accordingly, a suitable performance metric can then be chosen to compare models.
Time series experts might ask why we are not focusing on the time component.
## Plot column numeric0 when in state 0
ggplot(data[target == 0], aes(x = date_time, y = numeric0)) +
  geom_point(size = 0.5, color = "darkblue") +
  geom_line(size = 0.1, color = "darkblue") +
  labs(x = "Date", y = "Value", title = "Evolution 'numeric0' | state = 0") +
  plot_theme
## Plot column numeric0 when in state 1
ggplot(data[target == 1], aes(x = date_time, y = numeric0)) +
  geom_point(size = 0.5, color = "darkblue") +
  geom_line(size = 0.1, color = "darkblue") +
  labs(x = "Date", y = "Value", title = "Evolution 'numeric0' | state = 1") +
  plot_theme
acf(data[target == 1, numeric0], main = "Autocorrelation function | 'numeric0' | state = 1")
The above plots show the evolution of the column ‘numeric0’ when in state 0 and 1, plus the autocorrelation function of values in state 1.
These are the major observations:
As interesting as these points are, they do not explain why the target switches from 0 to 1. ***