14 Árboles de decisión

Los árboles de decisión, también conocidos como modelos de árbol de clasificación y regresión (CART), son métodos basados en árboles para el aprendizaje automático supervisado. Los árboles de clasificación y de regresión simples son fáciles de usar e interpretar, pero no son competitivos con los mejores métodos de aprendizaje automático. Sin embargo, forman la base para el conjunto de modelos de ensamblaje como “bagged trees”, “random forest” y “boosted trees”, que aunque son menos interpretables, son muy precisos.

Los modelos CART se puede definir en dos tipos de problemas

  • Árboles de clasificación: la variable resultado es categórica y el métodos se utiliza para identificar la “clase” dentro de la cual es más probable que caiga nuestra variable resultado. Un ejemplo de un problema de tipo clasificación sería determinar quién se suscribirá o no a una plataforma digital; o quién se graduará o no de la escuela secundaria; o si una persona tiene cáncer o no.

  • Árboles de regressión: la variable resultado es continua y el métodos se utiliza para predecir su valor. Un ejemplo de un problema de tipo regresión sería predecir los precios de venta de una casa residencial o el nivel de colesterol de una persona.

Los modelos CART segmentan el espacio predictor en \(K\) nodos terminales no superpuestos (hojas). Cada nodo se describe mediante un conjunto de reglas que se pueden utilizar para predecir nuevas respuestas. El valor predicho \(\hat{y}\) para cada nodo es la moda (clasificación) o la media (regresión).

Los modelos CART definen los nodos a través de un proceso top-down greedy llamado división binaria recursiva (recursive binary splitting). El proceso es de arriba hacia abajo porque comienza en la parte superior del árbol con todas las observaciones en una sola región y divide sucesivamente el espacio de predicción. Es greedy porque en cada paso de división, la mejor división se realiza en ese paso en particular sin tener en cuenta las divisiones posteriores. La siguiente figura muestra la idea general de esta metodología:

Diagrama árboles de decisión

Como vemos en el ejemplo una de las ventajas de los modelos CART es que consideran interacciones. En este curso no vamos a ver la regresión lógica pero es una metodología muy interesante que extiende CART cuando las variables predictoras son binarias y las interacciones que buscamos son del tipo AND y OR. Esta metodología se ha empleado con éxito para analizar datos genéticos donde el interés radica en saber cuál es el riesgo de desarrolar una enfermeda si te tiene por ejemplo: “una mutación en un punto A del genoma (SNP) y otra mutación en el punto B ó si se tiene una mutación en el punto C pero no se tiene en el punto D.

También son interesantes porque permiten valores faltantes sin la necesidad de hacer imputaciones previas.

La mejor división es la variable predictora y el punto de corte que minimiza una función de costo. La función de costo más común para los árboles de regresión es la suma de los residuos al cuadrado,

\[RSS = \sum_{k=1}^K\sum_{i \in A_k}{\left(y_i - \hat{y}_{A_k} \right)^2}.\] Para árboles de clasificación, es el índice de Gini,

\[G = \sum_{c=1}^C{\hat{p}_{kc}(1 - \hat{p}_{kc})},\]

y la entropía (aka información estadística)

\[D = - \sum_{c=1}^C{\hat{p}_{kc} \log \hat{p}_{kc}}\]

dónde \(\hat{p}_{kc}\) es la proporción de observaciones de entrenamiento en el nodo \(k\) que son de clase \(c\). Un nodo completamente puro en un árbol binario tendría \(\hat{p} \in \{ 0, 1 \}\) y \(G=D=0\). Un nodo completamente impuro en un árbol binario tendría \(\hat{p}=0.5\) y \(G=0.5^2 \cdot 2 = 0.25\) y \(D = -(0.5 \log(0.5)) \cdot 2 = 0.69\).

CART repite el proceso de división para cada nodo hijo hasta que se satisface un criterio de detención, generalmente cuando ningún tamaño de nodo supera un máximo predefinido o la división no mejora el modelo de manera significativa. CART también puede imponer un número mínimo de observaciones en cada nodo.

Es probable que el árbol resultante esté sobre-entrenado (over-fitting) y, por lo tanto, no se generalice bien para los datos de prueba. Para evitar este problema CART poda el árbol, minimizando el error de predicción de validación cruzada. En este caso, el hiperparámetro que debermos seleccionar en este modelo es la profundidad del arbol (e.g. número de nodos).

En lugar de realizar una validación cruzada de todos los subárboles posibles para encontrar el que tenga el mínimo de error, CART utiliza la poda de complejidad de costos (cost-complexity pruning). Costo-complejidad es la compensación entre error (costo) y tamaño del árbol (complejidad) donde la compensación se cuantifica con el parámetro costo-complejidad \(c_p\). El costo-complejidad del árbol, \(R_{c_p}(T)\), es la suma de su riesgo (error) más un factor de “complejidad de costos” \(c_p\) multiplicado pro el tamaño del arbol \(|T|\).

\[R_{c_p}(T) = R(T) + c_p|T|\]

\(c_p\) puede tomar cualquier valor de \([0..\infty]\), pero resulta que hay un árbol óptimo para rangos de \(c_p\), por lo que solo hay un conjunto finito de valores interesantes para \(c_p\) (ver Therneau y Atkinson 2019. CART utiliza validación cruzada para determinar qué \(c_p\) es óptimo.

14.1 Árboles de clasificación

Veamos cómo crear árboles de clasificación usando el conjunto de datos ISLR::OJ que se usaron para predecir qué marca de zumo de naranja, Citrus Hill (CH) o Minute Maid = (MM) toman los clientes (variable `Purchase) a partir de 17 variables predictoras.

Vamos a introducir la librería skimr que es interesante para hacer descriptivas. Con ella podremos saber, por ejemplo, cuántos tipos de variables tenemos o ver qué distribuciones tienen las variables continuas

library(tidyverse)
library(caret)
library(rpart)  # classification and regression trees 
library(rpart.plot)  # better formatted plots than the ones in rpart

oj_dat <- ISLR::OJ
skimr::skim(oj_dat)
Table 14.1: Data summary
Name oj_dat
Number of rows 1070
Number of columns 18
_______________________
Column type frequency:
factor 2
numeric 16
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
Purchase 0 1 FALSE 2 CH: 653, MM: 417
Store7 0 1 FALSE 2 No: 714, Yes: 356

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
WeekofPurchase 0 1 254.38 15.56 227.00 240.00 257.00 268.00 278.00 ▆▅▅▇▇
StoreID 0 1 3.96 2.31 1.00 2.00 3.00 7.00 7.00 ▇▅▃▁▇
PriceCH 0 1 1.87 0.10 1.69 1.79 1.86 1.99 2.09 ▅▂▇▆▁
PriceMM 0 1 2.09 0.13 1.69 1.99 2.09 2.18 2.29 ▂▁▃▇▆
DiscCH 0 1 0.05 0.12 0.00 0.00 0.00 0.00 0.50 ▇▁▁▁▁
DiscMM 0 1 0.12 0.21 0.00 0.00 0.00 0.23 0.80 ▇▁▂▁▁
SpecialCH 0 1 0.15 0.35 0.00 0.00 0.00 0.00 1.00 ▇▁▁▁▂
SpecialMM 0 1 0.16 0.37 0.00 0.00 0.00 0.00 1.00 ▇▁▁▁▂
LoyalCH 0 1 0.57 0.31 0.00 0.33 0.60 0.85 1.00 ▅▃▆▆▇
SalePriceMM 0 1 1.96 0.25 1.19 1.69 2.09 2.13 2.29 ▁▂▂▂▇
SalePriceCH 0 1 1.82 0.14 1.39 1.75 1.86 1.89 2.09 ▂▁▇▇▅
PriceDiff 0 1 0.15 0.27 -0.67 0.00 0.23 0.32 0.64 ▁▂▃▇▂
PctDiscMM 0 1 0.06 0.10 0.00 0.00 0.00 0.11 0.40 ▇▁▂▁▁
PctDiscCH 0 1 0.03 0.06 0.00 0.00 0.00 0.00 0.25 ▇▁▁▁▁
ListPriceDiff 0 1 0.22 0.11 0.00 0.14 0.24 0.30 0.44 ▂▃▆▇▁
STORE 0 1 1.63 1.43 0.00 0.00 2.00 3.00 4.00 ▇▃▅▅▃

Dividiremos nuestra base de datos oj_dat (n = 1070) en oj_train (80%, n = 857) para estimar varios modelos, y oj_test (20%, n = 213) para comparar su rendimiento con datos nuevos.

set.seed(12345)
partition <- createDataPartition(y = oj_dat$Purchase, p = 0.8, list = FALSE)
oj_train <- oj_dat[partition, ]
oj_test <- oj_dat[-partition, ]

La función rpart::rpart() construye un árbol completo, minimizando el índice de Gini \(G\) por defecto (parms = list (split = "gini")), hasta que se cumpla el criterio de parada. El criterio de parada predeterminado es:

  • solo intenta una división si el nodo actual tiene al menos minsplit = 20 observaciones, y
  • solo acepta una división si
    • los nodos resultantes tienen al menos minbucket = round (minsplit / 3) observaciones, y
    • el ajuste general resultante mejora en cp = 0.01 (es decir, \(\Delta G <= 0.01\)).
# Usar method = "class" para clasificación y method = "anova" para regresión
set.seed(123)
oj_mdl_cart_full <- rpart(formula = Purchase ~ ., data = oj_train, 
                          method = "class")
oj_mdl_cart_full
n= 857 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 857 334 CH (0.61026838 0.38973162)  
   2) LoyalCH>=0.48285 537  94 CH (0.82495345 0.17504655)  
     4) LoyalCH>=0.7648795 271  13 CH (0.95202952 0.04797048) *
     5) LoyalCH< 0.7648795 266  81 CH (0.69548872 0.30451128)  
      10) PriceDiff>=-0.165 226  50 CH (0.77876106 0.22123894) *
      11) PriceDiff< -0.165 40   9 MM (0.22500000 0.77500000) *
   3) LoyalCH< 0.48285 320  80 MM (0.25000000 0.75000000)  
     6) LoyalCH>=0.2761415 146  58 MM (0.39726027 0.60273973)  
      12) SalePriceMM>=2.04 71  31 CH (0.56338028 0.43661972) *
      13) SalePriceMM< 2.04 75  18 MM (0.24000000 0.76000000) *
     7) LoyalCH< 0.2761415 174  22 MM (0.12643678 0.87356322) *

La salida comienza con el nodo raíz. La clase predicha en la raíz es CH y esta predicción produce 334 errores en las 857 observaciones para una tasa de éxito (precisión) del 61% y una tasa de error del 39%. Los nodos secundarios del nodo “x” están etiquetados como 2x) y 2x + 1), por lo que los nodos secundarios de 1) son 2) y 3), y los nodos secundarios de 2) son 4) y 5). Los nodos terminales están etiquetados con un asterisco (*).

Sorprendentemente, solo 3 de las 17 variables se utilizaron en el árbol completo: LoyalCH (lealtad de marca del cliente para CH), PriceDiff (precio relativo de MM sobre CH) y SalePriceMM (precio absoluto de MM). La primera división está en LoyalCH = 0.48285. Aquí hay un diagrama del árbol completo (sin podar).

rpart.plot(oj_mdl_cart_full, yesno = TRUE)

Las cajas muestran la clasificación del nodo (según la moda), la proporción de observaciones que no son CH y la proporción de observaciones incluidas en el nodo.

rpart () no solo hizo crecer el árbol completo, sino que identificó el conjunto de parámetros de complejidad de costos y midió el rendimiento del modelo de cada árbol correspondiente mediante validación cruzada. printcp () muestra los posibles valores de \(c_p\). La siguiente tabla se puede utilizar esta tabla para decidir cómo podar el árbol.

printcp(oj_mdl_cart_full)

Classification tree:
rpart(formula = Purchase ~ ., data = oj_train, method = "class")

Variables actually used in tree construction:
[1] LoyalCH     PriceDiff   SalePriceMM

Root node error: 334/857 = 0.38973

n= 857 

        CP nsplit rel error  xerror     xstd
1 0.479042      0   1.00000 1.00000 0.042745
2 0.032934      1   0.52096 0.54192 0.035775
3 0.013473      3   0.45509 0.47006 0.033905
4 0.010000      5   0.42814 0.46407 0.033736

Hay 4 valores de \(c_p\) en este modelo. El modelo con el parámetro de complejidad más pequeño permite la mayoría de las divisiones (nsplit). El parámetro de mayor complejidad corresponde a un árbol con solo un nodo raíz. rel error es la tasa de error relativa al nodo raíz. El error absoluto del nodo raíz es 0.38973162 (la proporción de MM), por lo que su rel error es 0.38973162 / 0.38973162 = 1.0. Eso significa que el error absoluto del árbol completo (en CP = 0.01) es 0.42814 * 0.38973162 = 0.1669. Podemos verificarlo calculando la tasa de error de los valores predichos:

pred <- predict(oj_mdl_cart_full, newdata = oj_train, type = "class")
mean(oj_train$Purchase != pred)
[1] 0.1668611

Para acaber de explicar toda la salida de la table CP, xerror es la tasa de error relativa con validación cruzada y xstd es su error estándar. Si se desea el error más bajo posible, deberíamos podar el árbol con el error de CV relativo más pequeño, \(c_p=0.01\). Si deseamos equilibrar el poder predictivo con la simplicidad, podaremos al árbol más pequeño dentro de 1 SE del que tiene el error relativo más pequeño. La tabla CP no es muy útil para encontrar ese árbol, así que añadiremos una columna para encontrarlo.

oj_mdl_cart_full$cptable %>%
   data.frame() %>%
   mutate(
      min_idx = which.min(oj_mdl_cart_full$cptable[, "xerror"]),
      rownum = row_number(),
      xerror_cap = oj_mdl_cart_full$cptable[min_idx, "xerror"] + 
                   oj_mdl_cart_full$cptable[min_idx, "xstd"],
      eval = case_when(rownum == min_idx ~ "min xerror",
                       xerror < xerror_cap ~ "under cap",
                       TRUE ~ "")
   ) %>%
   dplyr::select(-rownum, -min_idx) 
          CP nsplit rel.error    xerror       xstd xerror_cap       eval
1 0.47904192      0 1.0000000 1.0000000 0.04274518  0.4978082           
2 0.03293413      1 0.5209581 0.5419162 0.03577468  0.4978082           
3 0.01347305      3 0.4550898 0.4700599 0.03390486  0.4978082  under cap
4 0.01000000      5 0.4281437 0.4640719 0.03373631  0.4978082 min xerror

El árbol más simple que usa la regla 1-SE es \(c_p = 0.01347305\) (error CV = 0.18). Afortunadamente, plotcp () nos da una representación gráfica de la relación entre xerror y cp.

plotcp(oj_mdl_cart_full, upper = "splits")

La línea discontinua se establece en el mínimo xerror + xstd. El eje superior muestra el número de divisiones en el árbol. NOTA: No estoy seguro de por qué los valores de CP no son los mismos que en la tabla (están cerca, pero no son los mismos). La figura sugiere que debería podar a 5 o 3 divisiones. Vemos que esta curva nunca llega al mínimo, sigue disminuyendo en 5 divisiones. El valor del parámetro de ajuste predeterminado cp = 0.01 puede ser demasiado grande, así que lo cambiaremos a cp = 0.001 y empezaremos de nuevo.

set.seed(123)
oj_mdl_cart_full <- rpart(
   formula = Purchase ~ .,
   data = oj_train,
   method = "class",
   cp = 0.001
   )
print(oj_mdl_cart_full)
n= 857 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 857 334 CH (0.61026838 0.38973162)  
    2) LoyalCH>=0.48285 537  94 CH (0.82495345 0.17504655)  
      4) LoyalCH>=0.7648795 271  13 CH (0.95202952 0.04797048) *
      5) LoyalCH< 0.7648795 266  81 CH (0.69548872 0.30451128)  
       10) PriceDiff>=-0.165 226  50 CH (0.77876106 0.22123894)  
         20) ListPriceDiff>=0.255 115  11 CH (0.90434783 0.09565217) *
         21) ListPriceDiff< 0.255 111  39 CH (0.64864865 0.35135135)  
           42) PriceMM>=2.155 19   2 CH (0.89473684 0.10526316) *
           43) PriceMM< 2.155 92  37 CH (0.59782609 0.40217391)  
             86) DiscCH>=0.115 7   0 CH (1.00000000 0.00000000) *
             87) DiscCH< 0.115 85  37 CH (0.56470588 0.43529412)  
              174) ListPriceDiff>=0.215 45  15 CH (0.66666667 0.33333333) *
              175) ListPriceDiff< 0.215 40  18 MM (0.45000000 0.55000000)  
                350) LoyalCH>=0.527571 28  13 CH (0.53571429 0.46428571)  
                  700) WeekofPurchase< 266.5 21   8 CH (0.61904762 0.38095238) *
                  701) WeekofPurchase>=266.5 7   2 MM (0.28571429 0.71428571) *
                351) LoyalCH< 0.527571 12   3 MM (0.25000000 0.75000000) *
       11) PriceDiff< -0.165 40   9 MM (0.22500000 0.77500000) *
    3) LoyalCH< 0.48285 320  80 MM (0.25000000 0.75000000)  
      6) LoyalCH>=0.2761415 146  58 MM (0.39726027 0.60273973)  
       12) SalePriceMM>=2.04 71  31 CH (0.56338028 0.43661972)  
         24) LoyalCH< 0.303104 7   0 CH (1.00000000 0.00000000) *
         25) LoyalCH>=0.303104 64  31 CH (0.51562500 0.48437500)  
           50) WeekofPurchase>=246.5 52  22 CH (0.57692308 0.42307692)  
            100) PriceCH< 1.94 35  11 CH (0.68571429 0.31428571)  
              200) StoreID< 1.5 9   1 CH (0.88888889 0.11111111) *
              201) StoreID>=1.5 26  10 CH (0.61538462 0.38461538)  
                402) LoyalCH< 0.410969 17   4 CH (0.76470588 0.23529412) *
                403) LoyalCH>=0.410969 9   3 MM (0.33333333 0.66666667) *
            101) PriceCH>=1.94 17   6 MM (0.35294118 0.64705882) *
           51) WeekofPurchase< 246.5 12   3 MM (0.25000000 0.75000000) *
       13) SalePriceMM< 2.04 75  18 MM (0.24000000 0.76000000)  
         26) SpecialCH>=0.5 14   6 CH (0.57142857 0.42857143) *
         27) SpecialCH< 0.5 61  10 MM (0.16393443 0.83606557) *
      7) LoyalCH< 0.2761415 174  22 MM (0.12643678 0.87356322)  
       14) LoyalCH>=0.035047 117  21 MM (0.17948718 0.82051282)  
         28) WeekofPurchase< 273.5 104  21 MM (0.20192308 0.79807692)  
           56) PriceCH>=1.875 20   9 MM (0.45000000 0.55000000)  
            112) WeekofPurchase>=252.5 12   5 CH (0.58333333 0.41666667) *
            113) WeekofPurchase< 252.5 8   2 MM (0.25000000 0.75000000) *
           57) PriceCH< 1.875 84  12 MM (0.14285714 0.85714286) *
         29) WeekofPurchase>=273.5 13   0 MM (0.00000000 1.00000000) *
       15) LoyalCH< 0.035047 57   1 MM (0.01754386 0.98245614) *

Este es un árbol mucho más grande. ¿Encontramo un valor cp que produce un mínimo?

plotcp(oj_mdl_cart_full, upper = "splits")

Sí, el mínimo está en CP = 0.011 con 5 divisiones. El mínimo + 1 SE está en CP = 0.021 con 3 divisiones. Podaremos entonces el árbol en 3.

oj_mdl_cart <- prune(
   oj_mdl_cart_full,
   cp = oj_mdl_cart_full$cptable[oj_mdl_cart_full$cptable[, 2] == 3, "CP"]
)
rpart.plot(oj_mdl_cart, yesno = TRUE)

El indicador de compra más “importante” parece ser LoyalCH. De la vignette de rpart (página 12) tenemos que:

An overall measure of variable importance is the sum of the goodness of split measures for each split for which it was the primary variable, plus goodness (adjusted agreement) for all splits in which it was a surrogate.

“Surrogate” (variable subrogada[^La FDA define una variable subrogada como “una medida de laboratorio o signo físico que se usa en ensayos terapéuticos como sustituto de una variable clínicamente significativa que es una medida directa sobre lo que siente un paciente, sus funciones o su supervivencia y que se espera que prediga el efecto de la terapia]) se refieren a características alternativas para que un nodo maneje los datos faltantes. Para cada división, CART evalúa una variedad de divisiones alternativas”sustitutas” para usar cuando el valor de la característica para la división principal es NA. Las divisiones sustitutas son divisiones que producen resultados similares a la división original.

La importancia de una variable es la suma de la mejora en la medida general de Gini (o RMSE) producida por los nodos en los que aparece. En el siguiente gráfico podemos ver la importancia de cada variable para este modelo.

oj_mdl_cart$variable.importance %>% 
   data.frame() %>%
   rownames_to_column(var = "Feature") %>%
   rename(Overall = '.') %>%
   ggplot(aes(x = fct_reorder(Feature, Overall), y = Overall)) +
   geom_pointrange(aes(ymin = 0, ymax = Overall), color = "cadetblue", size = .3) +
   theme_minimal() +
   coord_flip() +
   labs(x = "", y = "", title = "Importancia mediante clasificación simple")

LoyalCH es, con mucho, la variable más importante, como se esperaba de su posición en la parte superior del árbol así como en el siguiente nivel abajo.

Podemos ver cómo aparecen los variables subrogadas en el modelo con la función summary().

summary(oj_mdl_cart)
Call:
rpart(formula = Purchase ~ ., data = oj_train, method = "class", 
    cp = 0.001)
  n= 857 

          CP nsplit rel error    xerror       xstd
1 0.47904192      0 1.0000000 1.0000000 0.04274518
2 0.03293413      1 0.5209581 0.5419162 0.03577468
3 0.01347305      3 0.4550898 0.4700599 0.03390486

Variable importance
       LoyalCH      PriceDiff    SalePriceMM        StoreID WeekofPurchase         DiscMM 
            67              9              5              4              4              3 
       PriceMM      PctDiscMM        PriceCH 
             3              3              1 

Node number 1: 857 observations,    complexity param=0.4790419
  predicted class=CH  expected loss=0.3897316  P(node) =1
    class counts:   523   334
   probabilities: 0.610 0.390 
  left son=2 (537 obs) right son=3 (320 obs)
  Primary splits:
      LoyalCH       < 0.48285   to the right, improve=132.56800, (0 missing)
      StoreID       < 3.5       to the right, improve= 40.12097, (0 missing)
      PriceDiff     < 0.015     to the right, improve= 24.26552, (0 missing)
      ListPriceDiff < 0.255     to the right, improve= 22.79117, (0 missing)
      SalePriceMM   < 1.84      to the right, improve= 20.16447, (0 missing)
  Surrogate splits:
      StoreID        < 3.5       to the right, agree=0.646, adj=0.053, (0 split)
      PriceMM        < 1.89      to the right, agree=0.638, adj=0.031, (0 split)
      WeekofPurchase < 229.5     to the right, agree=0.632, adj=0.016, (0 split)
      DiscMM         < 0.77      to the left,  agree=0.629, adj=0.006, (0 split)
      SalePriceMM    < 1.385     to the right, agree=0.629, adj=0.006, (0 split)

Node number 2: 537 observations,    complexity param=0.03293413
  predicted class=CH  expected loss=0.1750466  P(node) =0.6266044
    class counts:   443    94
   probabilities: 0.825 0.175 
  left son=4 (271 obs) right son=5 (266 obs)
  Primary splits:
      LoyalCH       < 0.7648795 to the right, improve=17.669310, (0 missing)
      PriceDiff     < 0.015     to the right, improve=15.475200, (0 missing)
      SalePriceMM   < 1.84      to the right, improve=13.951730, (0 missing)
      ListPriceDiff < 0.255     to the right, improve=11.407560, (0 missing)
      DiscMM        < 0.15      to the left,  improve= 7.795122, (0 missing)
  Surrogate splits:
      WeekofPurchase < 257.5     to the right, agree=0.594, adj=0.180, (0 split)
      PriceCH        < 1.775     to the right, agree=0.590, adj=0.173, (0 split)
      StoreID        < 3.5       to the right, agree=0.587, adj=0.165, (0 split)
      PriceMM        < 2.04      to the right, agree=0.587, adj=0.165, (0 split)
      SalePriceMM    < 2.04      to the right, agree=0.587, adj=0.165, (0 split)

Node number 3: 320 observations
  predicted class=MM  expected loss=0.25  P(node) =0.3733956
    class counts:    80   240
   probabilities: 0.250 0.750 

Node number 4: 271 observations
  predicted class=CH  expected loss=0.04797048  P(node) =0.3162194
    class counts:   258    13
   probabilities: 0.952 0.048 

Node number 5: 266 observations,    complexity param=0.03293413
  predicted class=CH  expected loss=0.3045113  P(node) =0.3103851
    class counts:   185    81
   probabilities: 0.695 0.305 
  left son=10 (226 obs) right son=11 (40 obs)
  Primary splits:
      PriceDiff     < -0.165    to the right, improve=20.84307, (0 missing)
      ListPriceDiff < 0.235     to the right, improve=20.82404, (0 missing)
      SalePriceMM   < 1.84      to the right, improve=16.80587, (0 missing)
      DiscMM        < 0.15      to the left,  improve=10.05120, (0 missing)
      PctDiscMM     < 0.0729725 to the left,  improve=10.05120, (0 missing)
  Surrogate splits:
      SalePriceMM    < 1.585     to the right, agree=0.906, adj=0.375, (0 split)
      DiscMM         < 0.57      to the left,  agree=0.895, adj=0.300, (0 split)
      PctDiscMM      < 0.264375  to the left,  agree=0.895, adj=0.300, (0 split)
      WeekofPurchase < 274.5     to the left,  agree=0.872, adj=0.150, (0 split)
      SalePriceCH    < 2.075     to the left,  agree=0.857, adj=0.050, (0 split)

Node number 10: 226 observations
  predicted class=CH  expected loss=0.2212389  P(node) =0.2637106
    class counts:   176    50
   probabilities: 0.779 0.221 

Node number 11: 40 observations
  predicted class=MM  expected loss=0.225  P(node) =0.04667445
    class counts:     9    31
   probabilities: 0.225 0.775 

Una vez tenemos un modelo (o varios) los podemos evaluar en la muestra test con las medidas estándard

pred <- predict(oj_mdl_cart, newdata = oj_test, type = "class") 
oj_cm_cart <- confusionMatrix(pred,  oj_test$Purchase)
oj_cm_cart
Confusion Matrix and Statistics

          Reference
Prediction  CH  MM
        CH 113  13
        MM  17  70
                                          
               Accuracy : 0.8592          
                 95% CI : (0.8051, 0.9029)
    No Information Rate : 0.6103          
    P-Value [Acc > NIR] : 1.265e-15       
                                          
                  Kappa : 0.7064          
                                          
 Mcnemar's Test P-Value : 0.5839          
                                          
            Sensitivity : 0.8692          
            Specificity : 0.8434          
         Pos Pred Value : 0.8968          
         Neg Pred Value : 0.8046          
             Prevalence : 0.6103          
         Detection Rate : 0.5305          
   Detection Prevalence : 0.5915          
      Balanced Accuracy : 0.8563          
                                          
       'Positive' Class : CH              
                                          

También podemos representar gráficamente la tabla de confusión

plot(oj_test$Purchase, pred, 
     main = "Clasificación: Predicho vs. Observado",
     xlab = "Observado",
     ylab = "Predicho")

14.2 Área bajo la curva ROC

También podemos calcular el área bajo la curva ROC. La curva ROC (características operativas del receptor) es otra medida de precisión. Corresponde a un gráfico de la tasa de verdaderos positivos (TPR, sensibilidad) versus la tasa de falsos positivos (FPR, 1 - especificidad) para un conjunto de umbrales. De forma predeterminada, el umbral para predecir la clasificación predeterminada es 0.50, pero podría ser cualquier umbral. La función precrec::evalmod () calcula los valores de la matriz de confusión del modelo usando el conjunto de datos test. El AUC en el conjunto de datos test es 0.8848 y podemos calcularlo con varias funciones: pROC::plot.roc (), plotROC::geom_roc (), yardstick::roc_curve () y plotROC para usar ggplot() [geometría geom_roc ()].

Nosotros usaremos pROC. Para ello necesitamos tener las predicciones como probabilidades para la categoría de referencia. NOTA: El AUC es, pues, una medida útil para casos donde el predictor es binario.

library(pROC)
pred2 <-  predict(oj_mdl_cart, newdata = oj_test, type = "prob")[,"CH"]
roc.car <- roc(oj_test$Purchase, pred2, print.auc=TRUE, 
               ci=TRUE,
               plot=TRUE)

14.2.1 Entrenamiento con caret

También podemos ajustar el modelo con la función caret::train (). Recordemo que hay dos formas de ajustar los hiperparámetros cuando usamos train ():

  • establecer el número de valores de parámetros de ajuste a considerar utilizando tuneLength, o
  • establecer ciertos valores para cada parámetro utilizando tuneGrid.

ESTRATEGIA: Construiremos el modelo usando una validación cruzada de 10 veces para optimizar el hiperparámetro CP. Si no tenemos idea de cuál es el parámetro de ajuste óptimo, empezaremos con tuneLength para aproximarnos al valor óptimo y luego ajustaremos el valor con tuneGrid. Crearemos un objeto de control de entrenamiento que puedo reutilizar en otras compilaciones de modelos.

oj_trControl = trainControl (method = "cv",
   number = 10,
   savePredictions = "final",  # guardaremos preds para el valor óptimo del parámetro a tunear
   classProbs = TRUE,  # probs para las clases además de preds
   summaryFunction = twoClassSummary
   )

Ahora estimamos el modelo con

set.seed(1234)
oj_mdl_cart2 <- train(
   Purchase ~ ., 
   data = oj_train, 
   method = "rpart",
   tuneLength = 5,
   metric = "ROC",
   trControl = oj_trControl
   )

caret construye un árbol completo usando los parámetros predeterminados de rpart que son: índice de división de Gini, al menos 20 observaciones en un nodo para considerar dividirlo, y al menos 6 observaciones en cada nodo. Luego, caret calcula la precisión para cada valor candidato del hiperparámetro (CP). Estos son los resultados:

oj_mdl_cart2
CART 

857 samples
 17 predictor
  2 classes: 'CH', 'MM' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 772, 772, 771, 770, 771, 771, ... 
Resampling results across tuning parameters:

  cp           ROC        Sens       Spec     
  0.005988024  0.8539885  0.8605225  0.7274510
  0.008982036  0.8502309  0.8568578  0.7334225
  0.013473054  0.8459290  0.8473149  0.7397504
  0.032934132  0.7776483  0.8509071  0.6796791
  0.479041916  0.5878764  0.9201379  0.2556150

ROC was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.005988024.

El segundo CP (0.008982036) produce la mayor precisión. Podemos profundizar en el mejor valor de CP usando un tuning grid.

set.seed(1234)
oj_mdl_cart2 <- train(
   Purchase ~ ., 
   data = oj_train, 
   method = "rpart",
   tuneGrid = expand.grid(cp = seq(from = 0.001, to = 0.010, length = 11)),  
   metric = "ROC",
   trControl = oj_trControl
   )
print(oj_mdl_cart2)
CART 

857 samples
 17 predictor
  2 classes: 'CH', 'MM' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 772, 772, 771, 770, 771, 771, ... 
Resampling results across tuning parameters:

  cp      ROC        Sens       Spec     
  0.0010  0.8513056  0.8529390  0.7182709
  0.0019  0.8528471  0.8529753  0.7213012
  0.0028  0.8524435  0.8510522  0.7302139
  0.0037  0.8533529  0.8510522  0.7421569
  0.0046  0.8540042  0.8491292  0.7333333
  0.0055  0.8543820  0.8567126  0.7334225
  0.0064  0.8539885  0.8605225  0.7274510
  0.0073  0.8521076  0.8625181  0.7335116
  0.0082  0.8521076  0.8625181  0.7335116
  0.0091  0.8502309  0.8568578  0.7334225
  0.0100  0.8507262  0.8510885  0.7424242

ROC was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.0055.

El mejor modelo se consigue con CP = 0.0082. A continuación podemos ver las precisiones de validación cruzada para los valores de CP candidatos.

plot(oj_mdl_cart2)

Estos son los resultados para el modelo final:

oj_mdl_cart2$finalModel
n= 857 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 857 334 CH (0.61026838 0.38973162)  
    2) LoyalCH>=0.48285 537  94 CH (0.82495345 0.17504655)  
      4) LoyalCH>=0.7648795 271  13 CH (0.95202952 0.04797048) *
      5) LoyalCH< 0.7648795 266  81 CH (0.69548872 0.30451128)  
       10) PriceDiff>=-0.165 226  50 CH (0.77876106 0.22123894) *
       11) PriceDiff< -0.165 40   9 MM (0.22500000 0.77500000) *
    3) LoyalCH< 0.48285 320  80 MM (0.25000000 0.75000000)  
      6) LoyalCH>=0.2761415 146  58 MM (0.39726027 0.60273973)  
       12) SalePriceMM>=2.04 71  31 CH (0.56338028 0.43661972)  
         24) LoyalCH< 0.303104 7   0 CH (1.00000000 0.00000000) *
         25) LoyalCH>=0.303104 64  31 CH (0.51562500 0.48437500)  
           50) WeekofPurchase>=246.5 52  22 CH (0.57692308 0.42307692)  
            100) PriceCH< 1.94 35  11 CH (0.68571429 0.31428571) *
            101) PriceCH>=1.94 17   6 MM (0.35294118 0.64705882) *
           51) WeekofPurchase< 246.5 12   3 MM (0.25000000 0.75000000) *
       13) SalePriceMM< 2.04 75  18 MM (0.24000000 0.76000000)  
         26) SpecialCH>=0.5 14   6 CH (0.57142857 0.42857143) *
         27) SpecialCH< 0.5 61  10 MM (0.16393443 0.83606557) *
      7) LoyalCH< 0.2761415 174  22 MM (0.12643678 0.87356322) *
rpart.plot(oj_mdl_cart2$finalModel)

Veamos el rendimiento en la muestra test:

pred3 <- predict(oj_mdl_cart2, newdata = oj_test, type = "raw") 
oj_cm_cart2 <- confusionMatrix(pred3,  oj_test$Purchase)
oj_cm_cart2
Confusion Matrix and Statistics

          Reference
Prediction  CH  MM
        CH 117  18
        MM  13  65
                                          
               Accuracy : 0.8545          
                 95% CI : (0.7998, 0.8989)
    No Information Rate : 0.6103          
    P-Value [Acc > NIR] : 4.83e-15        
                                          
                  Kappa : 0.6907          
                                          
 Mcnemar's Test P-Value : 0.4725          
                                          
            Sensitivity : 0.9000          
            Specificity : 0.7831          
         Pos Pred Value : 0.8667          
         Neg Pred Value : 0.8333          
             Prevalence : 0.6103          
         Detection Rate : 0.5493          
   Detection Prevalence : 0.6338          
      Balanced Accuracy : 0.8416          
                                          
       'Positive' Class : CH              
                                          

La precisión es 0.8545, un poco peor que la 0.8592 del método directo. El AUC es 0.916 que es mejor que el obtenido con el método directo.

pred4 <- predict(oj_mdl_cart2, newdata = oj_test, type = "prob")[,"CH"] 
roc.car2 <- roc(oj_test$Purchase, pred4, print.auc=TRUE, 
               ci=TRUE,
               plot=TRUE)

Podemos comparar ambas curvas ROC mediante el test de DeLong

roc.test(roc.car, roc.car2)

    DeLong's test for two correlated ROC curves

data:  roc.car and roc.car2
Z = -2.4259, p-value = 0.01527
alternative hypothesis: true difference in AUC is not equal to 0
95 percent confidence interval:
 -0.056801412 -0.006034547
sample estimates:
AUC of roc1 AUC of roc2 
  0.8848471   0.9162651 

Finalmente, podemos crear fácilmente la gráfica de importancia de variables con la función varImp (). La lealtad a la marca es lo más importante, seguida de la diferencia de precio.

plot(varImp(oj_mdl_cart2), main="Importancia de variables con CART (caret)")

Parece que con la estrategia de caret hemos conseguido un mejor modelo predictivo gracias, sobre todo, a la posibilidad de buscar el mejor hiperparámetro haciend fine tuning.

oj_scoreboard <- rbind(
   data.frame(Modelo = "Single Tree", 
              Accuracy = oj_cm_cart$overall["Accuracy"],
              ROC = roc.car$auc),
   data.frame(Modelo = "Single Tree (caret)", 
              Accuracy = oj_cm_cart2$overall["Accuracy"],
              ROC = roc.car2$auc)) %>% 
   arrange(desc(ROC))
knitr::kable(oj_scoreboard, row.names = FALSE)
Modelo Accuracy ROC
Single Tree (caret) 0.8544601 0.9162651
Single Tree 0.8591549 0.8848471

14.3 Árboles de regresión

Un árbol de regresión simple se construye de manera similar a un árbol de clasificación simple y, al igual que el árbol de clasificación, rara vez se usan por sí solo (sobre todo en problemas complejos o de big data). De nuevo, basaremos el aprendizaje de esta metodología partiendo de un ejemplo real. Usaremos el conjunto de datos ISLR::Carseats que pretende predecir las ventas de sillitas de niños para coches (variable Sales) en 400 tiendas usando 10 variables que contienen información de las características de las sillas.

cs_dat <- ISLR::Carseats
skimr::skim(cs_dat)
Table 14.2: Data summary
Name cs_dat
Number of rows 400
Number of columns 11
_______________________
Column type frequency:
factor 3
numeric 8
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
ShelveLoc 0 1 FALSE 3 Med: 219, Bad: 96, Goo: 85
Urban 0 1 FALSE 2 Yes: 282, No: 118
US 0 1 FALSE 2 Yes: 258, No: 142

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
Sales 0 1 7.50 2.82 0 5.39 7.49 9.32 16.27 ▁▆▇▃▁
CompPrice 0 1 124.97 15.33 77 115.00 125.00 135.00 175.00 ▁▅▇▃▁
Income 0 1 68.66 27.99 21 42.75 69.00 91.00 120.00 ▇▆▇▆▅
Advertising 0 1 6.64 6.65 0 0.00 5.00 12.00 29.00 ▇▃▃▁▁
Population 0 1 264.84 147.38 10 139.00 272.00 398.50 509.00 ▇▇▇▇▇
Price 0 1 115.80 23.68 24 100.00 117.00 131.00 191.00 ▁▂▇▆▁
Age 0 1 53.32 16.20 25 39.75 54.50 66.00 80.00 ▇▆▇▇▇
Education 0 1 13.90 2.62 10 12.00 14.00 16.00 18.00 ▇▇▃▇▇

De nuevo, partiremos nuestro conjunto de datos cs_dat (n = 400) en cs_train (80%, n = 321) y cs_test (20%, n = 79).

set.seed(12345)
partition <- createDataPartition(y = cs_dat$Sales, p = 0.8, list = FALSE)
cs_train <- cs_dat[partition, ]
cs_test <- cs_dat[-partition, ]

El primer paso es construir un árbol completo y luego realizar una validación cruzada para ayudar a seleccionar la complejidad de costo óptima (cp). La única diferencia ahora es que usaremos method = "anova" en la función rpart () para poder estimar un árbol de regresión.

set.seed(1234)
cs_mdl_cart_full <- rpart(Sales ~ ., cs_train, method = "anova")
cs_mdl_cart_full
n= 321 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 321 2567.76800  7.535950  
   2) ShelveLoc=Bad,Medium 251 1474.14100  6.770359  
     4) Price>=105.5 168  719.70630  5.987024  
       8) ShelveLoc=Bad 50  165.70160  4.693600  
        16) Population< 201.5 20   48.35505  3.646500 *
        17) Population>=201.5 30   80.79922  5.391667 *
       9) ShelveLoc=Medium 118  434.91370  6.535085  
        18) Advertising< 11.5 88  290.05490  6.113068  
          36) CompPrice< 142 69  193.86340  5.769420  
            72) Price>=132.5 16   50.75440  4.455000 *
            73) Price< 132.5 53  107.12060  6.166226 *
          37) CompPrice>=142 19   58.45118  7.361053 *
        19) Advertising>=11.5 30   83.21323  7.773000 *
     5) Price< 105.5 83  442.68920  8.355904  
      10) Age>=63.5 32  153.42300  6.922500  
        20) Price>=85 25   66.89398  6.160800  
          40) ShelveLoc=Bad 9   18.39396  4.772222 *
          41) ShelveLoc=Medium 16   21.38544  6.941875 *
        21) Price< 85 7   20.22194  9.642857 *
      11) Age< 63.5 51  182.26350  9.255294  
        22) Income< 57.5 12   28.03042  7.707500 *
        23) Income>=57.5 39  116.63950  9.731538  
          46) Age>=50.5 14   21.32597  8.451429 *
          47) Age< 50.5 25   59.52474 10.448400 *
   3) ShelveLoc=Good 70  418.98290 10.281140  
     6) Price>=107.5 49  242.58730  9.441633  
      12) Advertising< 13.5 41  162.47820  8.926098  
        24) Age>=61 17   53.37051  7.757647 *
        25) Age< 61 24   69.45776  9.753750 *
      13) Advertising>=13.5 8   13.36599 12.083750 *
     7) Price< 107.5 21   61.28200 12.240000 *

Las ventas pronosticadas en la raíz son las ventas medias para el conjunto de datos de entrenamiento, 7.5 (los valores corresponden a miles de dolares). La primera división está en ShelveLoc = [Bad, Medium] vs Good (calidad). Aquí está el diagrama de árbol sin podar.

rpart.plot(cs_mdl_cart_full, yesno = TRUE)

Cada caja muestra el valor predicho del nodo (media) y la proporción de observaciones que están en el nodo (o nodos secundarios).

rpart () estima el árbol completo y utiliza validación cruzada para probar el rendimiento de los posibles hiperparámetros de complejidad. Como antes, printcp () muestra los valores de cp candidatos que pueden verse en esta tabla. Estos datos pueden ser utilizados para decidir cómo podar el árbol.

printcp(cs_mdl_cart_full)

Regression tree:
rpart(formula = Sales ~ ., data = cs_train, method = "anova")

Variables actually used in tree construction:
[1] Advertising Age         CompPrice   Income      Population  Price       ShelveLoc  

Root node error: 2567.8/321 = 7.9993

n= 321 

         CP nsplit rel error  xerror     xstd
1  0.262736      0   1.00000 1.00635 0.076664
2  0.121407      1   0.73726 0.74888 0.058981
3  0.046379      2   0.61586 0.65278 0.050839
4  0.044830      3   0.56948 0.67245 0.051638
5  0.041671      4   0.52465 0.66230 0.051065
6  0.025993      5   0.48298 0.62345 0.049368
7  0.025823      6   0.45698 0.61980 0.048026
8  0.024007      7   0.43116 0.62058 0.048213
9  0.015441      8   0.40715 0.58061 0.041738
10 0.014698      9   0.39171 0.56413 0.041368
11 0.014641     10   0.37701 0.56277 0.041271
12 0.014233     11   0.36237 0.56081 0.041097
13 0.014015     12   0.34814 0.55647 0.038308
14 0.013938     13   0.33413 0.55647 0.038308
15 0.010560     14   0.32019 0.57110 0.038872
16 0.010000     15   0.30963 0.56676 0.038090

Hay 16 posibles valores de cp en este modelo. El modelo con el parámetro de complejidad más pequeño permite la mayoría de las divisiones (nsplit). El parámetro de mayor complejidad corresponde a un árbol con solo un nodo raíz. rel error es el SSE relativo al nodo raíz. El SSE del nodo raíz es 2567.76800, por lo que su error rel es 2567.76800 / 2567.76800 = 1.0. Eso significa que el error absoluto del árbol completo (en CP = 0.01) es 0.30963 * 2567.76800 = 795.058. Podemos verificar estos resultados calculando el SSE de los valores predichos del modelo:

data.frame(pred = predict(cs_mdl_cart_full, newdata = cs_train)) %>%
   mutate(obs = cs_train$Sales,
          sq_err = (obs - pred)^2) %>%
   summarise(sse = sum(sq_err))
       sse
1 795.0525

La tabla también muestra, xerror que corresponde al SSE con validación cruzada y xstd a su error estándar. Si deseamos el error más bajo posible, podaremos el árbol con el SSE relativo más pequeño (xerror). Si deseamos equilibrar el poder predictivo con la simplicidad, podaremos al árbol más pequeño que esté dentro de 1 SE para el SSE relativo más pequeño. Al igual que en la sección anterior, la tabla CP no es muy útil para encontrar ese árbol, por lo que debemos añadir una columna para visualizar dicha información:

cs_mdl_cart_full$cptable %>%
   data.frame() %>%
   mutate(min_xerror_idx = which.min(cs_mdl_cart_full$cptable[, "xerror"]),
          rownum = row_number(),
          xerror_cap = cs_mdl_cart_full$cptable[min_xerror_idx, "xerror"] + 
             cs_mdl_cart_full$cptable[min_xerror_idx, "xstd"],
          eval = case_when(rownum == min_xerror_idx ~ "min xerror",
                           xerror < xerror_cap ~ "under cap",
                           TRUE ~ "")) %>%
   dplyr::select(-rownum, -min_xerror_idx) 
           CP nsplit rel.error    xerror       xstd xerror_cap       eval
1  0.26273578      0 1.0000000 1.0063530 0.07666355  0.5947744           
2  0.12140705      1 0.7372642 0.7488767 0.05898146  0.5947744           
3  0.04637919      2 0.6158572 0.6527823 0.05083938  0.5947744           
4  0.04483023      3 0.5694780 0.6724529 0.05163819  0.5947744           
5  0.04167149      4 0.5246478 0.6623028 0.05106530  0.5947744           
6  0.02599265      5 0.4829763 0.6234457 0.04936799  0.5947744           
7  0.02582284      6 0.4569836 0.6198034 0.04802643  0.5947744           
8  0.02400748      7 0.4311608 0.6205756 0.04821332  0.5947744           
9  0.01544139      8 0.4071533 0.5806072 0.04173785  0.5947744  under cap
10 0.01469771      9 0.3917119 0.5641331 0.04136793  0.5947744  under cap
11 0.01464055     10 0.3770142 0.5627713 0.04127139  0.5947744  under cap
12 0.01423309     11 0.3623736 0.5608073 0.04109662  0.5947744  under cap
13 0.01401541     12 0.3481405 0.5564663 0.03830810  0.5947744 min xerror
14 0.01393771     13 0.3341251 0.5564663 0.03830810  0.5947744  under cap
15 0.01055959     14 0.3201874 0.5710951 0.03887227  0.5947744  under cap
16 0.01000000     15 0.3096278 0.5667561 0.03808991  0.5947744  under cap

Bien, entonces el árbol más simple es el que tiene CP = 0.02599265 (5 divisiones). También podemos usar plotcp () para visualizar la relación entrexerrorycp`.

plotcp(cs_mdl_cart_full, upper = "splits")

La línea discontinua se establece en el mínimo xerror + xstd. El eje superior muestra el número de divisiones en el árbol. El error relativo más pequeño está en CP = 0.01 (15 divisiones), pero el CP máximo debajo de la línea discontinua (una desviación estándar por encima del error mínimo) está en CP = 0.02599265 (5 divisiones). Utilizamos entonces la función prune () para podar el árbol especificando el coste-complejidad asociado a este CP.

cs_mdl_cart <- prune(
   cs_mdl_cart_full,
   cp = cs_mdl_cart_full$cptable[cs_mdl_cart_full$cptable[, 2] == 5, "CP"]
)
rpart.plot(cs_mdl_cart, yesno = TRUE)

El indicador más “importante” de ventas es ShelveLoc. Estos son los valores de importancia del modelo:

cs_mdl_cart$variable.importance %>% 
   data.frame() %>%
   rownames_to_column(var = "Feature") %>%
   rename(Overall = '.') %>%
   ggplot(aes(x = fct_reorder(Feature, Overall), y = Overall)) +
   geom_pointrange(aes(ymin = 0, ymax = Overall), color = "cadetblue", size = .3) +
   theme_minimal() +
   coord_flip() +
   labs(x = "", y = "", title = "Variable Importance with Simple Regression")

El indicador más importante de ventas es ShelveLoc, luego Price, luego Age (edad media de la población donde está la tienda). Todas estas variables aparecen en el modelo final. CompPrice (precio del competidor) también es relevante.

El último paso es hacer predicciones sobre el conjunto de datos de validación. Cuando la variable respuesta es continua usamos:

  • la raíz del error cuadrático medio \(RMSE = \sqrt{(1/2) \sum{(actual - pred)^2}})\) y

  • el errr absoluto medio \(MAE = (1/n) \sum{|actual - pred|}\)

La diferencia entre ambos es que RMSE penaliza más los errores grandes. Para un árbol de regresión, basta con indicar type="vector") en la funciónpredict ()` (que es el valor por defecto).

cs_preds_cart <- predict(cs_mdl_cart, cs_test, type = "vector")

cs_rmse_cart <- RMSE(
   pred = cs_preds_cart,
   obs = cs_test$Sales
)
cs_rmse_cart
[1] 2.363202

El proceso de poda conduce a un error de predicción promedio de 2.363 en el conjunto de datos de prueba. No está mal considerando que la desviación estándar de la variable Sales es 2.8. Podemos visualizar la relación entre los datos predichos y los observados mediante:

data.frame(Predichos = cs_preds_cart, Observados = cs_test$Sales) %>%
   ggplot(aes(x = Observados, y = Predichos)) +
   geom_point(alpha = 0.6, color = "cadetblue") +
   geom_smooth() +
   geom_abline(intercept = 0, slope = 1, linetype = 2) +
   labs(title = "Carseats CART, predichos vs observados")

14.3.1 Entrenamiento con caret

También podemos ajustar el modelo con caret::train () especificando method = "rpart". Construirmos el modelo usando 10-fold CV para optimizar el hiperparámetro CP.

cs_trControl = trainControl(
   method = "cv",
   number = 10,
   savePredictions = "final" 
)

Usaremos la misma estrategia que en el caso de los árboles de clasificación, dejaremos que el modelo busque el mejor parámetro de ajuste de CP con tuneLength y luego lo ajustaremos con tuneGrid.

set.seed(1234)
cs_mdl_cart2 = train(
   Sales ~ ., 
   data = cs_train, 
   method = "rpart",
   tuneLength = 5,
   metric = "RMSE",
   trControl = cs_trControl
)
cs_mdl_cart2
CART 

321 samples
 10 predictor

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 289, 289, 289, 289, 289, 289, ... 
Resampling results across tuning parameters:

  cp          RMSE      Rsquared   MAE     
  0.04167149  2.209383  0.4065251  1.778797
  0.04483023  2.243618  0.3849728  1.805027
  0.04637919  2.275563  0.3684309  1.808814
  0.12140705  2.400455  0.2942663  1.936927
  0.26273578  2.692867  0.1898998  2.192774

RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.04167149.

El primer cp (0.04167149) presenta el RMSE más pequeño. Puedemos hacer una búsqueda más fina para mejorar el valor de cp usando un grid:

set.seed(1234)
cs_mdl_cart2 = train(
   Sales ~ ., 
   data = cs_train, 
   method = "rpart",
   tuneGrid = expand.grid(cp = seq(from = 0, to = 0.1, by = 0.01)),
   metric = "RMSE",
   trControl = cs_trControl
)
cs_mdl_cart2
CART 

321 samples
 10 predictor

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 289, 289, 289, 289, 289, 289, ... 
Resampling results across tuning parameters:

  cp    RMSE      Rsquared   MAE     
  0.00  2.055676  0.5027431  1.695453
  0.01  2.135096  0.4642577  1.745937
  0.02  2.095767  0.4733269  1.699235
  0.03  2.131246  0.4534544  1.690453
  0.04  2.146886  0.4411380  1.712705
  0.05  2.284937  0.3614130  1.837782
  0.06  2.265498  0.3709523  1.808319
  0.07  2.282630  0.3597216  1.836227
  0.08  2.282630  0.3597216  1.836227
  0.09  2.282630  0.3597216  1.836227
  0.10  2.282630  0.3597216  1.836227

RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.

En este ejemplo, parece que el árbol con mejor rendimiento es el que no ha sido podado.

plot(cs_mdl_cart2)

Este sería el modelo final

rpart.plot(cs_mdl_cart2$finalModel)

y estas las variables más importantes

plot(varImp(cs_mdl_cart2), main="Importancia de variables para Regresión")

Como siempre, debemos evaluar el modelo en nuestra muestra test:

cs_preds_cart2 <- predict(cs_mdl_cart2, cs_test, type = "raw")
data.frame(Actual = cs_test$Sales, Predicted = cs_preds_cart2) %>%
ggplot(aes(x = Actual, y = Predicted)) +
   geom_point(alpha = 0.6, color = "cadetblue") +
   geom_smooth(method = "loess", formula = "y ~ x") +
   geom_abline(intercept = 0, slope = 1, linetype = 2) +
   labs(title = "Carseats CART, Predicted vs Actual (caret)")

Observamos como el modelo sobreestima en el extremo inferior y subestima en el extremo superior. Podemos calcular el RMSE para estos datos:

(cs_rmse_cart2 <- RMSE(pred = cs_preds_cart2, obs = cs_test$Sales))
[1] 2.298331

Caret mejora las predicciones:

cs_scoreboard <- rbind(
   data.frame(Modelo = "Single Tree", RMSE = cs_rmse_cart),
   data.frame(Modelo = "Single Tree (caret)", RMSE = cs_rmse_cart2)
) %>% arrange(RMSE)
knitr::kable(cs_scoreboard, row.names = FALSE)
Modelo RMSE
Single Tree (caret) 2.298331
Single Tree 2.363202

14.4 Bagged trees

Los CART tiene una capacidad predictiva moderada, es por ello que se han propuesto unos métodos que combinan varios árboles de decisión para producir un mejor rendimiento predictivo que utilizar un solo árbol de decisión. El principio fundamental detrás de estos modelos es que un grupo de predictores débiles puede conseguir un predictor con mejor capacidad predictiva.

Tenemos dos tipos de estrategias:

  • Bagging
  • Boosting

Bagging (Bootstrap Aggregation) se utiliza cuando nuestro objetivo es reducir la varianza de un árbol de decisión. La idea es crear varios subconjuntos de datos a partir de la muestra de entrenamiento elegida al azar con reemplazamiento. Cada subconjunto de datos se utiliza para entrenar un árbol de decisión. Como resultado, terminamos con un conjunto de diferentes modelos. Se utiliza el promedio de todas las predicciones de diferentes árboles, que es más robusto que considerar un solo árbol de decisión.

Bagged trees

En el Boosting se aprende de forma secuencial. Ajustamos árboles consecutivos (muestra aleatoria) y en cada paso, el objetivo es mejorar el error del árbol anterior.

Boosted trees

Como hemos dicho anteriormente, el algoritmo bagged construye B árboles decisión usando conjuntos de entrenamiento obtenidos mediante remuestreo y promedia las predicciones resultantes. Estos árboles crecen profundamente y no se podan. Por tanto, cada árbol individual tiene una alta varianza, pero un bajo sesgo. Promediar los B árboles ayuda a reducir la varianza. El valor predicho para una observación es la moda (clasificación) o la media (regresión) de los árboles. B generalmente es igual a ~ 25.

Proceso para Bagged trees

Para un conjunto de entrenamiento de tamaño \(n\), cada árbol se compone de \(\sim (1 - e^{-1})n = .632n\) observaciones únicas in-bag y \(.368n\) out-of-bag. Las observaciones que no han sido seleccionadas en el re-muestreo se usan para evaluar la precisión del modelo. La capacidad glogal del método se obtiene promediando la capacidad de cada árbol. Esto tiene una desventaja obvia y es que si cada árbol tiene un rendimiento deficiente, el rendimiento promedio de muchos árboles seguirá siendo deficiente. Además, otra desventaja de este método es que no existe un árbol único con un conjunto de reglas para interpretar. En consecuencia, no queda claro qué variables son más importantes que otras y en algunos problemas (sobre todo biomédicos) esto puede ser una limitación importante.

14.4.1 Bagging árboles de clasificación

Veamos de nuevo con un ejemplo cómo trabajar con estos métodos. Usaremos de nuevo los datos de zumos de naranja OJ. Esta vez usaremos un método bagging especificando method="treebag". Caret no tiene hiperparámetros para este modelo, por lo que no es necesario usar tuneLegth ni tuneGrid. El tamaño de conjunto predeterminado es nbagg = 25 (a veces se puede tunear, pero en este caso lo dejaremos fijo).

set.seed(1234)
oj_mdl_bag <- train(
   Purchase ~ ., 
   data = oj_train, 
   method = "treebag",
   trControl = oj_trControl,
   metric = "ROC"
)
oj_mdl_bag$finalModel

Bagging classification trees with 25 bootstrap replications 

Veamos el rendimiento en el conjunto de datos test.

pred_bag <- predict(oj_mdl_bag, newdata = oj_test, type = "raw")
oj_cm_bag <- confusionMatrix(pred_bag, oj_test$Purchase)
oj_cm_bag
Confusion Matrix and Statistics

          Reference
Prediction  CH  MM
        CH 113  16
        MM  17  67
                                          
               Accuracy : 0.8451          
                 95% CI : (0.7894, 0.8909)
    No Information Rate : 0.6103          
    P-Value [Acc > NIR] : 6.311e-14       
                                          
                  Kappa : 0.675           
                                          
 Mcnemar's Test P-Value : 1               
                                          
            Sensitivity : 0.8692          
            Specificity : 0.8072          
         Pos Pred Value : 0.8760          
         Neg Pred Value : 0.7976          
             Prevalence : 0.6103          
         Detection Rate : 0.5305          
   Detection Prevalence : 0.6056          
      Balanced Accuracy : 0.8382          
                                          
       'Positive' Class : CH              
                                          

La precisión es 0.8451, sorprendentemente peor que el 0.85915 del modelo de árbol único, pero esa es una diferencia que corresponde a tres predicciones en un conjunto de 213. Esta sería la curva ROC.

pred_bag2 <- predict(oj_mdl_bag, newdata = oj_test, type = "prob")[,"CH"]
roc.bag <- roc(oj_test$Purchase, pred_bag2, print.auc=TRUE, 
               ci=TRUE,
               plot=TRUE)

Veamos cuáles son las variables más importantes

plot(varImp(oj_mdl_bag), main="Importancia de variables con Bagging")

Esta es la comparación entre métodos

oj_scoreboard <- rbind(oj_scoreboard,
   data.frame(Modelo = "Bagging", 
              Accuracy = oj_cm_bag$overall["Accuracy"],
              ROC = roc.bag$auc)
) %>% arrange(desc(ROC))
knitr::kable(oj_scoreboard, row.names = FALSE)
Modelo Accuracy ROC
Single Tree (caret) 0.8544601 0.9162651
Bagging 0.8450704 0.9099166
Single Tree 0.8591549 0.8848471

14.4.2 Bagging árboles de regresión

Usemos bagging para predecir las ventas en los datos Carseats:

set.seed(1234)
cs_mdl_bag <- train(
   Sales ~ ., 
   data = cs_train, 
   method = "treebag",
   trControl = cs_trControl
)
cs_mdl_bag
Bagged CART 

321 samples
 10 predictor

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 289, 289, 289, 289, 289, 289, ... 
Resampling results:

  RMSE      Rsquared  MAE     
  1.681889  0.675239  1.343427

Veamos el rendimiento en el conjunto de datos test. El RMSE es 1.9185, pero el modelo predice en exceso en el extremo inferior de ventas y tampoco predice bien en el extremo superior (como un árbol simple).

cs_preds_bag <- bind_cols(
   Predicted = predict(cs_mdl_bag, newdata = cs_test),
   Actual = cs_test$Sales
)
(cs_rmse_bag <- RMSE(pred = cs_preds_bag$Predicted, obs = cs_preds_bag$Actual))
[1] 1.918473
cs_preds_bag %>%
   ggplot(aes(x = Actual, y = Predicted)) +
   geom_point(alpha = 0.6, color = "cadetblue") +
   geom_smooth(method = "loess", formula = "y ~ x") +
   geom_abline(intercept = 0, slope = 1, linetype = 2) +
   labs(title = "Carseats Bagging, Predicted vs Actual (caret)")

La importancia de las variables son:

plot(varImp(cs_mdl_bag), main="Importancia de variables con Bagging")

Y la comparación quedaría

cs_scoreboard <- rbind(cs_scoreboard,
   data.frame(Modelo = "Bagging", RMSE = cs_rmse_bag)
) %>% arrange(RMSE)
knitr::kable(cs_scoreboard, row.names = FALSE)
Modelo RMSE
Bagging 1.918473
Single Tree (caret) 2.298331
Single Tree 2.363202
EJERCICIO (Entrega en Moodle: P2-Bagged):
Implementa una función que implemente el método “bagged tree”. Aplícalo a los datos “Carseats” y compara tus resultados con los que se obtienen usando caret.
EJERCICIO (Entrega en Moodle: P-Bagged Breast cancer):
Utiliza los datos de cáncer de mama que hemos trabajado en clase “breast_train_prep” y “breast_test_prep” (que puedes encontrar en el fichero “breast.Rdata” del Moodle) para crear un modelo predictivo usando CART y Bagged Trees. Compara los resultados con los obtenidos mediant KNN y LDA reportados en el bookdown del curso.

14.5 Random Forest

Los Random Forest (bosques aleatorios) también son un conjunto de árboles de decisión (ensambladores) que mejoran los bagged trees mediante la creación de un bosque no correlacionados de árboles que, de nuevo, mejora la capacidad predictiva de un único árbol. Al igual que en el bagged (embolsado), el algoritmo construye varios árboles de decisión sobre muestras de entrenamiento bootstrap. Sin embargo, al construir estos árboles de decisión, cada vez que se considera una división en un árbol, se elige una muestra aleatoria de predictores (hiperparámetro m o mtry) como candidatos de división del conjunto completo de predictores \(p\). En cada división se toma una nueva muestra de predictores. Típicamente \(m \approx \sqrt{p}\). En consecuencia, los árboles bagged son un caso especial de los random forest cuando \(m = p\).

Cada árbol del modelo random forest se construye de la siguiente forma:

  • Si denotamos por \(N\) el número de casos en el conjunto de entrenamiento, seleccionaremos una muestra de esos \(N\) casos se forma aleatoria CON REEMPLAZAMIENTO. Esta muestra será el conjunto de entrenamiento para construir el árbol i-ésimo.

  • Si denotamos por \(M\) el número total de varibles predictoras, seleccionaremos un número \(m < M\) de variables y crearemos un árbol completo con esas variables. El valor \(m\) se mantiene constante durante la generación de todo el bosque.

  • Cada árbol crece hasta su máxima extensión posible y NO hay proceso de poda.

  • La predicción para nuevos individuos se hace a partir de la información obtenida de las predicciones de los \(B\) árboles (mayoría de votos para clasificación, promedio para regresión). La siguiente figura ilustra este proceso

Random Forest

Podemos estimar un random forest con cart indicando el argumento method = "rf". El hiperparámetro mtry (\(m\)) puede tomar cualquier valor de 1 a 17 (el número de predictores) y se espera que el valor óptimo esté cerca de \(\sqrt{17} \approx 4\). En cuanto al número de árboles (segundo hiperparámetro), hay estudios que demuestran que el rendimiento empeora cuando tenemos muchos árboles, sin embargo esto no está muy claro y por lo general se recomienda entrenar modelos con muchos árboles. Por defecto method = "rf" tiene 500 (argumento num.trees).

set.seed(1234)
oj_mdl_rf <- train(
   Purchase ~ ., 
   data = oj_train, 
   method = "rf",
   metric = "ROC",
   tuneGrid = expand.grid(mtry = 3:10),
   trControl = oj_trControl,
   num.trees = 500
)
oj_mdl_rf
Random Forest 

857 samples
 17 predictor
  2 classes: 'CH', 'MM' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 772, 772, 771, 770, 771, 771, ... 
Resampling results across tuning parameters:

  mtry  ROC        Sens       Spec     
   3    0.8655672  0.8565312  0.7185383
   4    0.8685845  0.8641872  0.7122995
   5    0.8682630  0.8470247  0.7183601
   6    0.8672458  0.8412917  0.7124777
   7    0.8695796  0.8412917  0.7183601
   8    0.8668721  0.8393687  0.7213012
   9    0.8652269  0.8432148  0.7153298
  10    0.8671443  0.8413280  0.7152406

ROC was used to select the optimal model using the largest value.
The final value used for the model was mtry = 7.

El valor de ROC más alto se da en \(m = 7\) que es más alto de lo que esperábamos, pero fijémosnos que es un valor de ROC muy similar al que se obtiene con \(m=4\), por lo que por el principio de parsimonia podríamos usar dicho valor

plot(oj_mdl_rf)

También podemos visualizar los resultados con:

plot_rf <- function(model) {
    theme_set(theme_minimal())
    u <- model$results %>%
        dplyr::select(mtry, ROC, Sens, Spec) %>%
        gather(a, b, -mtry)
    
    u %>% ggplot(aes(mtry, b)) + geom_line() + geom_point() + 
        facet_wrap(~ a, scales = "free") + 
        labs(x = "Número de predictores", y = NULL, 
             title = "Relación entre el número de predictores y el comportamiento del modelo")
}

oj_mdl_rf %>% plot_rf()

EJERCICIO (Entrega en Moodle: P2-Plot RF):
Si en el modelo train usamos metric = "prSummary" también podemos visualizar otras métricas para evaluar el comportamiento del modelo. En particular el AUC que puede ser interesante desde un punto de vista práctico. Crea una función que visualice dichas métricas y aplícalo a los datos oj_train (que obviamente tendrás que re-entrenar con esa nueva métrica).

Podemos usar este modelo para hacer predicciones sobre la muestra test

pred_rf <- predict(oj_mdl_rf, newdata = oj_test, type = "raw")
oj_cm_rf <- confusionMatrix(pred_rf, oj_test$Purchase)
oj_cm_rf
Confusion Matrix and Statistics

          Reference
Prediction  CH  MM
        CH 112  16
        MM  18  67
                                          
               Accuracy : 0.8404          
                 95% CI : (0.7841, 0.8869)
    No Information Rate : 0.6103          
    P-Value [Acc > NIR] : 2.164e-13       
                                          
                  Kappa : 0.6659          
                                          
 Mcnemar's Test P-Value : 0.8638          
                                          
            Sensitivity : 0.8615          
            Specificity : 0.8072          
         Pos Pred Value : 0.8750          
         Neg Pred Value : 0.7882          
             Prevalence : 0.6103          
         Detection Rate : 0.5258          
   Detection Prevalence : 0.6009          
      Balanced Accuracy : 0.8344          
                                          
       'Positive' Class : CH              
                                          

Y el área bajo la curva ROC sería:

pred_rf2 <- predict(oj_mdl_rf, newdata = oj_test, type = "prob")[,"CH"]
roc.rf <- roc(oj_test$Purchase, pred_bag2, print.auc=TRUE, 
               ci=TRUE,
               plot=TRUE)

que compara con los modelos anteriores de esta forma:

Y la comparación quedaría

oj_scoreboard <- rbind(oj_scoreboard,
   data.frame(Modelo = "Random Forest", 
              Accuracy = oj_cm_rf$overall["Accuracy"],
              ROC = roc.rf$auc)
) %>% arrange(desc(ROC))
knitr::kable(oj_scoreboard, row.names = FALSE)
Modelo Accuracy ROC
Single Tree (caret) 0.8544601 0.9162651
Bagging 0.8450704 0.9099166
Random Forest 0.8403756 0.9099166
Single Tree 0.8591549 0.8848471

Recordemos que la importancia de las variables se puede ver con la función varImp al igual que cualquier otro modelo basado en àrboles de decisión.

NOTA: El ejemplo para Random Forest con árboles de regresión es igual que lo que vimos en la sección anterior.

NOTA2: En este artículo se hace un “benchmarking” muy interesante para saber qué método y libería de R usar en función de las características de nuestro conjunto de datos.

Speiser JL et al. (2012). A Comparison of Random Forest Variable Selection Methods for Classification Prediction Modeling

Random forest classification is a popular machine learning method for developing prediction models in many research settings. Often in prediction modeling, a goal is to reduce the number of variables needed to obtain a prediction in order to reduce the burden of data collection and improve efficiency. Several variable selection methods exist for the setting of random forest classification; however, there is a paucity of literature to guide users as to which method may be preferable for different types of datasets. Using 311 classification datasets freely available online, we evaluate the prediction error rates, number of variables, computation times and area under the receiver operating curve for many random forest variable selection methods. We compare random forest variable selection methods for different types of datasets (datasets with binary outcomes, datasets with many predictors, and datasets with imbalanced outcomes) and for different types of methods (standard random forest versus conditional random forest methods and test based versus performance based methods). Based on our study, the best variable selection methods for most datasets are Jiang’s method and the method implemented in the VSURF R package. For datasets with many predictors, the methods implemented in the R packages varSelRF and Boruta are preferable due to computational efficiency. A significant contribution of this study is the ability to assess different variable selection techniques in the setting of random forest classification in order to identify preferable methods based on applications in expert and intelligent systems.

EJERCICIO (Entrega en Moodle: P-RF Breast cancer):
Utiliza los datos de cáncer de mama que hemos trabajado en clase “breast_train_prep” y “breast_test_prep” (que puedes encontrar en el fichero “breast.Rdata” del Moodle) para crear un modelo predictivo usando Random Forest. Compara los resultados con los obtenidos mediante KNN y LDA reportados en el bookdown del curso (y di si mejora los que obtuviste con CART y Bagged Trees).

14.6 Random Forest p>>n

Problema: Aplicar Random Forest para conjunto de datos con muchas variables (caso p>>n)

  • Posible estrategia:
  1. Creamos K subconjuntos de variable
  2. Llevamos a cabo una selección de las variables más importantes y nos quedamos con una parte de ellas. Por ejemplo, las M más informativas
  3. Combinamos las K*M variables y repetimos los pasos 1 y 2
  4. Acabamos con M variables seleccionadas
  5. Aplicamos Random Forest

Este enfoque podría causar la pérdida de algunas variables importantes, pero generalmente seleccionará las variables más informativas.

Selección de K y M Breiman (2001) recomienda \(m=p/3\) en clasificación y \(m=\sqrt{p}\) en regresión (mtry). ¿Puede servir esto de ayuda?