14 Árboles de decisión
Los árboles de decisión, también conocidos como modelos de árbol de clasificación y regresión (CART), son métodos basados en árboles para el aprendizaje automático supervisado. Los árboles de clasificación y de regresión simples son fáciles de usar e interpretar, pero no son competitivos con los mejores métodos de aprendizaje automático. Sin embargo, forman la base para el conjunto de modelos de ensamblaje como “bagged trees”, “random forest” y “boosted trees”, que aunque son menos interpretables, son muy precisos.
Los modelos CART se puede definir en dos tipos de problemas
Árboles de clasificación: la variable resultado es categórica y el métodos se utiliza para identificar la “clase” dentro de la cual es más probable que caiga nuestra variable resultado. Un ejemplo de un problema de tipo clasificación sería determinar quién se suscribirá o no a una plataforma digital; o quién se graduará o no de la escuela secundaria; o si una persona tiene cáncer o no.
Árboles de regressión: la variable resultado es continua y el métodos se utiliza para predecir su valor. Un ejemplo de un problema de tipo regresión sería predecir los precios de venta de una casa residencial o el nivel de colesterol de una persona.
Los modelos CART segmentan el espacio predictor en \(K\) nodos terminales no superpuestos (hojas). Cada nodo se describe mediante un conjunto de reglas que se pueden utilizar para predecir nuevas respuestas. El valor predicho \(\hat{y}\) para cada nodo es la moda (clasificación) o la media (regresión).
Los modelos CART definen los nodos a través de un proceso top-down greedy llamado división binaria recursiva (recursive binary splitting). El proceso es de arriba hacia abajo porque comienza en la parte superior del árbol con todas las observaciones en una sola región y divide sucesivamente el espacio de predicción. Es greedy porque en cada paso de división, la mejor división se realiza en ese paso en particular sin tener en cuenta las divisiones posteriores. La siguiente figura muestra la idea general de esta metodología:
Como vemos en el ejemplo una de las ventajas de los modelos CART es que consideran interacciones. En este curso no vamos a ver la regresión lógica pero es una metodología muy interesante que extiende CART cuando las variables predictoras son binarias y las interacciones que buscamos son del tipo AND y OR. Esta metodología se ha empleado con éxito para analizar datos genéticos donde el interés radica en saber cuál es el riesgo de desarrolar una enfermeda si te tiene por ejemplo: “una mutación en un punto A del genoma (SNP) y otra mutación en el punto B ó si se tiene una mutación en el punto C pero no se tiene en el punto D.
También son interesantes porque permiten valores faltantes sin la necesidad de hacer imputaciones previas.
La mejor división es la variable predictora y el punto de corte que minimiza una función de costo. La función de costo más común para los árboles de regresión es la suma de los residuos al cuadrado,
\[RSS = \sum_{k=1}^K\sum_{i \in A_k}{\left(y_i - \hat{y}_{A_k} \right)^2}.\] Para árboles de clasificación, es el índice de Gini,
\[G = \sum_{c=1}^C{\hat{p}_{kc}(1 - \hat{p}_{kc})},\]
y la entropía (aka información estadística)
\[D = - \sum_{c=1}^C{\hat{p}_{kc} \log \hat{p}_{kc}}\]
dónde \(\hat{p}_{kc}\) es la proporción de observaciones de entrenamiento en el nodo \(k\) que son de clase \(c\). Un nodo completamente puro en un árbol binario tendría \(\hat{p} \in \{ 0, 1 \}\) y \(G=D=0\). Un nodo completamente impuro en un árbol binario tendría \(\hat{p}=0.5\) y \(G=0.5^2 \cdot 2 = 0.25\) y \(D = -(0.5 \log(0.5)) \cdot 2 = 0.69\).
CART repite el proceso de división para cada nodo hijo hasta que se satisface un criterio de detención, generalmente cuando ningún tamaño de nodo supera un máximo predefinido o la división no mejora el modelo de manera significativa. CART también puede imponer un número mínimo de observaciones en cada nodo.
Es probable que el árbol resultante esté sobre-entrenado (over-fitting) y, por lo tanto, no se generalice bien para los datos de prueba. Para evitar este problema CART poda el árbol, minimizando el error de predicción de validación cruzada. En este caso, el hiperparámetro que debermos seleccionar en este modelo es la profundidad del arbol (e.g. número de nodos).
En lugar de realizar una validación cruzada de todos los subárboles posibles para encontrar el que tenga el mínimo de error, CART utiliza la poda de complejidad de costos (cost-complexity pruning). Costo-complejidad es la compensación entre error (costo) y tamaño del árbol (complejidad) donde la compensación se cuantifica con el parámetro costo-complejidad \(c_p\). El costo-complejidad del árbol, \(R_{c_p}(T)\), es la suma de su riesgo (error) más un factor de “complejidad de costos” \(c_p\) multiplicado pro el tamaño del arbol \(|T|\).
\[R_{c_p}(T) = R(T) + c_p|T|\]
\(c_p\) puede tomar cualquier valor de \([0..\infty]\), pero resulta que hay un árbol óptimo para rangos de \(c_p\), por lo que solo hay un conjunto finito de valores interesantes para \(c_p\) (ver Therneau y Atkinson 2019. CART utiliza validación cruzada para determinar qué \(c_p\) es óptimo.
14.1 Árboles de clasificación
Veamos cómo crear árboles de clasificación usando el conjunto de datos ISLR::OJ
que se usaron para predecir qué marca de zumo de naranja, Citrus Hill (CH) o Minute Maid = (MM) toman los clientes (variable `Purchase) a partir de 17 variables predictoras.
Vamos a introducir la librería skimr
que es interesante para hacer descriptivas. Con ella podremos saber, por ejemplo, cuántos tipos de variables tenemos o ver qué distribuciones tienen las variables continuas
library(tidyverse)
library(caret)
library(rpart) # classification and regression trees
library(rpart.plot) # better formatted plots than the ones in rpart
<- ISLR::OJ
oj_dat ::skim(oj_dat) skimr
Name | oj_dat |
Number of rows | 1070 |
Number of columns | 18 |
_______________________ | |
Column type frequency: | |
factor | 2 |
numeric | 16 |
________________________ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
Purchase | 0 | 1 | FALSE | 2 | CH: 653, MM: 417 |
Store7 | 0 | 1 | FALSE | 2 | No: 714, Yes: 356 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
WeekofPurchase | 0 | 1 | 254.38 | 15.56 | 227.00 | 240.00 | 257.00 | 268.00 | 278.00 | ▆▅▅▇▇ |
StoreID | 0 | 1 | 3.96 | 2.31 | 1.00 | 2.00 | 3.00 | 7.00 | 7.00 | ▇▅▃▁▇ |
PriceCH | 0 | 1 | 1.87 | 0.10 | 1.69 | 1.79 | 1.86 | 1.99 | 2.09 | ▅▂▇▆▁ |
PriceMM | 0 | 1 | 2.09 | 0.13 | 1.69 | 1.99 | 2.09 | 2.18 | 2.29 | ▂▁▃▇▆ |
DiscCH | 0 | 1 | 0.05 | 0.12 | 0.00 | 0.00 | 0.00 | 0.00 | 0.50 | ▇▁▁▁▁ |
DiscMM | 0 | 1 | 0.12 | 0.21 | 0.00 | 0.00 | 0.00 | 0.23 | 0.80 | ▇▁▂▁▁ |
SpecialCH | 0 | 1 | 0.15 | 0.35 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▂ |
SpecialMM | 0 | 1 | 0.16 | 0.37 | 0.00 | 0.00 | 0.00 | 0.00 | 1.00 | ▇▁▁▁▂ |
LoyalCH | 0 | 1 | 0.57 | 0.31 | 0.00 | 0.33 | 0.60 | 0.85 | 1.00 | ▅▃▆▆▇ |
SalePriceMM | 0 | 1 | 1.96 | 0.25 | 1.19 | 1.69 | 2.09 | 2.13 | 2.29 | ▁▂▂▂▇ |
SalePriceCH | 0 | 1 | 1.82 | 0.14 | 1.39 | 1.75 | 1.86 | 1.89 | 2.09 | ▂▁▇▇▅ |
PriceDiff | 0 | 1 | 0.15 | 0.27 | -0.67 | 0.00 | 0.23 | 0.32 | 0.64 | ▁▂▃▇▂ |
PctDiscMM | 0 | 1 | 0.06 | 0.10 | 0.00 | 0.00 | 0.00 | 0.11 | 0.40 | ▇▁▂▁▁ |
PctDiscCH | 0 | 1 | 0.03 | 0.06 | 0.00 | 0.00 | 0.00 | 0.00 | 0.25 | ▇▁▁▁▁ |
ListPriceDiff | 0 | 1 | 0.22 | 0.11 | 0.00 | 0.14 | 0.24 | 0.30 | 0.44 | ▂▃▆▇▁ |
STORE | 0 | 1 | 1.63 | 1.43 | 0.00 | 0.00 | 2.00 | 3.00 | 4.00 | ▇▃▅▅▃ |
Dividiremos nuestra base de datos oj_dat
(n = 1070) en oj_train
(80%, n = 857) para estimar varios modelos, y oj_test
(20%, n = 213) para comparar su rendimiento con datos nuevos.
set.seed(12345)
<- createDataPartition(y = oj_dat$Purchase, p = 0.8, list = FALSE)
partition <- oj_dat[partition, ]
oj_train <- oj_dat[-partition, ] oj_test
La función rpart::rpart()
construye un árbol completo, minimizando el índice de Gini \(G\) por defecto (parms = list (split = "gini")
), hasta que se cumpla el criterio de parada. El criterio de parada predeterminado es:
- solo intenta una división si el nodo actual tiene al menos
minsplit = 20
observaciones, y - solo acepta una división si
- los nodos resultantes tienen al menos
minbucket = round (minsplit / 3)
observaciones, y - el ajuste general resultante mejora en
cp = 0.01
(es decir, \(\Delta G <= 0.01\)).
- los nodos resultantes tienen al menos
# Usar method = "class" para clasificación y method = "anova" para regresión
set.seed(123)
<- rpart(formula = Purchase ~ ., data = oj_train,
oj_mdl_cart_full method = "class")
oj_mdl_cart_full
n= 857
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 857 334 CH (0.61026838 0.38973162)
2) LoyalCH>=0.48285 537 94 CH (0.82495345 0.17504655)
4) LoyalCH>=0.7648795 271 13 CH (0.95202952 0.04797048) *
5) LoyalCH< 0.7648795 266 81 CH (0.69548872 0.30451128)
10) PriceDiff>=-0.165 226 50 CH (0.77876106 0.22123894) *
11) PriceDiff< -0.165 40 9 MM (0.22500000 0.77500000) *
3) LoyalCH< 0.48285 320 80 MM (0.25000000 0.75000000)
6) LoyalCH>=0.2761415 146 58 MM (0.39726027 0.60273973)
12) SalePriceMM>=2.04 71 31 CH (0.56338028 0.43661972) *
13) SalePriceMM< 2.04 75 18 MM (0.24000000 0.76000000) *
7) LoyalCH< 0.2761415 174 22 MM (0.12643678 0.87356322) *
La salida comienza con el nodo raíz. La clase predicha en la raíz es CH
y esta predicción produce 334 errores en las 857 observaciones para una tasa de éxito (precisión) del 61% y una tasa de error del 39%. Los nodos secundarios del nodo “x” están etiquetados como 2x) y 2x + 1), por lo que los nodos secundarios de 1) son 2) y 3), y los nodos secundarios de 2) son 4) y 5). Los nodos terminales están etiquetados con un asterisco (*).
Sorprendentemente, solo 3 de las 17 variables se utilizaron en el árbol completo: LoyalCH (lealtad de marca del cliente para CH), PriceDiff (precio relativo de MM sobre CH) y SalePriceMM (precio absoluto de MM). La primera división está en LoyalCH = 0.48285. Aquí hay un diagrama del árbol completo (sin podar).
rpart.plot(oj_mdl_cart_full, yesno = TRUE)
Las cajas muestran la clasificación del nodo (según la moda), la proporción de observaciones que no son CH
y la proporción de observaciones incluidas en el nodo.
rpart ()
no solo hizo crecer el árbol completo, sino que identificó el conjunto de parámetros de complejidad de costos y midió el rendimiento del modelo de cada árbol correspondiente mediante validación cruzada. printcp ()
muestra los posibles valores de \(c_p\). La siguiente tabla se puede utilizar esta tabla para decidir cómo podar el árbol.
printcp(oj_mdl_cart_full)
Classification tree:
rpart(formula = Purchase ~ ., data = oj_train, method = "class")
Variables actually used in tree construction:
[1] LoyalCH PriceDiff SalePriceMM
Root node error: 334/857 = 0.38973
n= 857
CP nsplit rel error xerror xstd
1 0.479042 0 1.00000 1.00000 0.042745
2 0.032934 1 0.52096 0.54192 0.035775
3 0.013473 3 0.45509 0.47006 0.033905
4 0.010000 5 0.42814 0.46407 0.033736
Hay 4 valores de \(c_p\) en este modelo. El modelo con el parámetro de complejidad más pequeño permite la mayoría de las divisiones (nsplit
). El parámetro de mayor complejidad corresponde a un árbol con solo un nodo raíz. rel error
es la tasa de error relativa al nodo raíz. El error absoluto del nodo raíz es 0.38973162 (la proporción de MM), por lo que su rel error
es 0.38973162 / 0.38973162 = 1.0. Eso significa que el error absoluto del árbol completo (en CP = 0.01) es 0.42814 * 0.38973162 = 0.1669. Podemos verificarlo calculando la tasa de error de los valores predichos:
<- predict(oj_mdl_cart_full, newdata = oj_train, type = "class")
pred mean(oj_train$Purchase != pred)
[1] 0.1668611
Para acaber de explicar toda la salida de la table CP, xerror
es la tasa de error relativa con validación cruzada y xstd
es su error estándar. Si se desea el error más bajo posible, deberíamos podar el árbol con el error de CV relativo más pequeño, \(c_p=0.01\). Si deseamos equilibrar el poder predictivo con la simplicidad, podaremos al árbol más pequeño dentro de 1 SE del que tiene el error relativo más pequeño. La tabla CP no es muy útil para encontrar ese árbol, así que añadiremos una columna para encontrarlo.
$cptable %>%
oj_mdl_cart_fulldata.frame() %>%
mutate(
min_idx = which.min(oj_mdl_cart_full$cptable[, "xerror"]),
rownum = row_number(),
xerror_cap = oj_mdl_cart_full$cptable[min_idx, "xerror"] +
$cptable[min_idx, "xstd"],
oj_mdl_cart_fulleval = case_when(rownum == min_idx ~ "min xerror",
< xerror_cap ~ "under cap",
xerror TRUE ~ "")
%>%
) ::select(-rownum, -min_idx) dplyr
CP nsplit rel.error xerror xstd xerror_cap eval
1 0.47904192 0 1.0000000 1.0000000 0.04274518 0.4978082
2 0.03293413 1 0.5209581 0.5419162 0.03577468 0.4978082
3 0.01347305 3 0.4550898 0.4700599 0.03390486 0.4978082 under cap
4 0.01000000 5 0.4281437 0.4640719 0.03373631 0.4978082 min xerror
El árbol más simple que usa la regla 1-SE es \(c_p = 0.01347305\) (error CV = 0.18). Afortunadamente, plotcp ()
nos da una representación gráfica de la relación entre xerror
y cp
.
plotcp(oj_mdl_cart_full, upper = "splits")
La línea discontinua se establece en el mínimo xerror + xstd
. El eje superior muestra el número de divisiones en el árbol. NOTA: No estoy seguro de por qué los valores de CP no son los mismos que en la tabla (están cerca, pero no son los mismos). La figura sugiere que debería podar a 5 o 3 divisiones. Vemos que esta curva nunca llega al mínimo, sigue disminuyendo en 5 divisiones. El valor del parámetro de ajuste predeterminado cp = 0.01
puede ser demasiado grande, así que lo cambiaremos a cp = 0.001
y empezaremos de nuevo.
set.seed(123)
<- rpart(
oj_mdl_cart_full formula = Purchase ~ .,
data = oj_train,
method = "class",
cp = 0.001
)print(oj_mdl_cart_full)
n= 857
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 857 334 CH (0.61026838 0.38973162)
2) LoyalCH>=0.48285 537 94 CH (0.82495345 0.17504655)
4) LoyalCH>=0.7648795 271 13 CH (0.95202952 0.04797048) *
5) LoyalCH< 0.7648795 266 81 CH (0.69548872 0.30451128)
10) PriceDiff>=-0.165 226 50 CH (0.77876106 0.22123894)
20) ListPriceDiff>=0.255 115 11 CH (0.90434783 0.09565217) *
21) ListPriceDiff< 0.255 111 39 CH (0.64864865 0.35135135)
42) PriceMM>=2.155 19 2 CH (0.89473684 0.10526316) *
43) PriceMM< 2.155 92 37 CH (0.59782609 0.40217391)
86) DiscCH>=0.115 7 0 CH (1.00000000 0.00000000) *
87) DiscCH< 0.115 85 37 CH (0.56470588 0.43529412)
174) ListPriceDiff>=0.215 45 15 CH (0.66666667 0.33333333) *
175) ListPriceDiff< 0.215 40 18 MM (0.45000000 0.55000000)
350) LoyalCH>=0.527571 28 13 CH (0.53571429 0.46428571)
700) WeekofPurchase< 266.5 21 8 CH (0.61904762 0.38095238) *
701) WeekofPurchase>=266.5 7 2 MM (0.28571429 0.71428571) *
351) LoyalCH< 0.527571 12 3 MM (0.25000000 0.75000000) *
11) PriceDiff< -0.165 40 9 MM (0.22500000 0.77500000) *
3) LoyalCH< 0.48285 320 80 MM (0.25000000 0.75000000)
6) LoyalCH>=0.2761415 146 58 MM (0.39726027 0.60273973)
12) SalePriceMM>=2.04 71 31 CH (0.56338028 0.43661972)
24) LoyalCH< 0.303104 7 0 CH (1.00000000 0.00000000) *
25) LoyalCH>=0.303104 64 31 CH (0.51562500 0.48437500)
50) WeekofPurchase>=246.5 52 22 CH (0.57692308 0.42307692)
100) PriceCH< 1.94 35 11 CH (0.68571429 0.31428571)
200) StoreID< 1.5 9 1 CH (0.88888889 0.11111111) *
201) StoreID>=1.5 26 10 CH (0.61538462 0.38461538)
402) LoyalCH< 0.410969 17 4 CH (0.76470588 0.23529412) *
403) LoyalCH>=0.410969 9 3 MM (0.33333333 0.66666667) *
101) PriceCH>=1.94 17 6 MM (0.35294118 0.64705882) *
51) WeekofPurchase< 246.5 12 3 MM (0.25000000 0.75000000) *
13) SalePriceMM< 2.04 75 18 MM (0.24000000 0.76000000)
26) SpecialCH>=0.5 14 6 CH (0.57142857 0.42857143) *
27) SpecialCH< 0.5 61 10 MM (0.16393443 0.83606557) *
7) LoyalCH< 0.2761415 174 22 MM (0.12643678 0.87356322)
14) LoyalCH>=0.035047 117 21 MM (0.17948718 0.82051282)
28) WeekofPurchase< 273.5 104 21 MM (0.20192308 0.79807692)
56) PriceCH>=1.875 20 9 MM (0.45000000 0.55000000)
112) WeekofPurchase>=252.5 12 5 CH (0.58333333 0.41666667) *
113) WeekofPurchase< 252.5 8 2 MM (0.25000000 0.75000000) *
57) PriceCH< 1.875 84 12 MM (0.14285714 0.85714286) *
29) WeekofPurchase>=273.5 13 0 MM (0.00000000 1.00000000) *
15) LoyalCH< 0.035047 57 1 MM (0.01754386 0.98245614) *
Este es un árbol mucho más grande. ¿Encontramo un valor cp
que produce un mínimo?
plotcp(oj_mdl_cart_full, upper = "splits")
Sí, el mínimo está en CP = 0.011 con 5 divisiones. El mínimo + 1 SE está en CP = 0.021 con 3 divisiones. Podaremos entonces el árbol en 3.
<- prune(
oj_mdl_cart
oj_mdl_cart_full,cp = oj_mdl_cart_full$cptable[oj_mdl_cart_full$cptable[, 2] == 3, "CP"]
)rpart.plot(oj_mdl_cart, yesno = TRUE)
El indicador de compra más “importante” parece ser LoyalCH
. De la vignette de rpart (página 12) tenemos que:
An overall measure of variable importance is the sum of the goodness of split measures for each split for which it was the primary variable, plus goodness (adjusted agreement) for all splits in which it was a surrogate.
“Surrogate” (variable subrogada[^La FDA define una variable subrogada como “una medida de laboratorio o signo físico que se usa en ensayos terapéuticos como sustituto de una variable clínicamente significativa que es una medida directa sobre lo que siente un paciente, sus funciones o su supervivencia y que se espera que prediga el efecto de la terapia]) se refieren a características alternativas para que un nodo maneje los datos faltantes. Para cada división, CART evalúa una variedad de divisiones alternativas”sustitutas” para usar cuando el valor de la característica para la división principal es NA. Las divisiones sustitutas son divisiones que producen resultados similares a la división original.
La importancia de una variable es la suma de la mejora en la medida general de Gini (o RMSE) producida por los nodos en los que aparece. En el siguiente gráfico podemos ver la importancia de cada variable para este modelo.
$variable.importance %>%
oj_mdl_cartdata.frame() %>%
rownames_to_column(var = "Feature") %>%
rename(Overall = '.') %>%
ggplot(aes(x = fct_reorder(Feature, Overall), y = Overall)) +
geom_pointrange(aes(ymin = 0, ymax = Overall), color = "cadetblue", size = .3) +
theme_minimal() +
coord_flip() +
labs(x = "", y = "", title = "Importancia mediante clasificación simple")
LoyalCH
es, con mucho, la variable más importante, como se esperaba de su posición en la parte superior del árbol así como en el siguiente nivel abajo.
Podemos ver cómo aparecen los variables subrogadas en el modelo con la función summary()
.
summary(oj_mdl_cart)
Call:
rpart(formula = Purchase ~ ., data = oj_train, method = "class",
cp = 0.001)
n= 857
CP nsplit rel error xerror xstd
1 0.47904192 0 1.0000000 1.0000000 0.04274518
2 0.03293413 1 0.5209581 0.5419162 0.03577468
3 0.01347305 3 0.4550898 0.4700599 0.03390486
Variable importance
LoyalCH PriceDiff SalePriceMM StoreID WeekofPurchase DiscMM
67 9 5 4 4 3
PriceMM PctDiscMM PriceCH
3 3 1
Node number 1: 857 observations, complexity param=0.4790419
predicted class=CH expected loss=0.3897316 P(node) =1
class counts: 523 334
probabilities: 0.610 0.390
left son=2 (537 obs) right son=3 (320 obs)
Primary splits:
LoyalCH < 0.48285 to the right, improve=132.56800, (0 missing)
StoreID < 3.5 to the right, improve= 40.12097, (0 missing)
PriceDiff < 0.015 to the right, improve= 24.26552, (0 missing)
ListPriceDiff < 0.255 to the right, improve= 22.79117, (0 missing)
SalePriceMM < 1.84 to the right, improve= 20.16447, (0 missing)
Surrogate splits:
StoreID < 3.5 to the right, agree=0.646, adj=0.053, (0 split)
PriceMM < 1.89 to the right, agree=0.638, adj=0.031, (0 split)
WeekofPurchase < 229.5 to the right, agree=0.632, adj=0.016, (0 split)
DiscMM < 0.77 to the left, agree=0.629, adj=0.006, (0 split)
SalePriceMM < 1.385 to the right, agree=0.629, adj=0.006, (0 split)
Node number 2: 537 observations, complexity param=0.03293413
predicted class=CH expected loss=0.1750466 P(node) =0.6266044
class counts: 443 94
probabilities: 0.825 0.175
left son=4 (271 obs) right son=5 (266 obs)
Primary splits:
LoyalCH < 0.7648795 to the right, improve=17.669310, (0 missing)
PriceDiff < 0.015 to the right, improve=15.475200, (0 missing)
SalePriceMM < 1.84 to the right, improve=13.951730, (0 missing)
ListPriceDiff < 0.255 to the right, improve=11.407560, (0 missing)
DiscMM < 0.15 to the left, improve= 7.795122, (0 missing)
Surrogate splits:
WeekofPurchase < 257.5 to the right, agree=0.594, adj=0.180, (0 split)
PriceCH < 1.775 to the right, agree=0.590, adj=0.173, (0 split)
StoreID < 3.5 to the right, agree=0.587, adj=0.165, (0 split)
PriceMM < 2.04 to the right, agree=0.587, adj=0.165, (0 split)
SalePriceMM < 2.04 to the right, agree=0.587, adj=0.165, (0 split)
Node number 3: 320 observations
predicted class=MM expected loss=0.25 P(node) =0.3733956
class counts: 80 240
probabilities: 0.250 0.750
Node number 4: 271 observations
predicted class=CH expected loss=0.04797048 P(node) =0.3162194
class counts: 258 13
probabilities: 0.952 0.048
Node number 5: 266 observations, complexity param=0.03293413
predicted class=CH expected loss=0.3045113 P(node) =0.3103851
class counts: 185 81
probabilities: 0.695 0.305
left son=10 (226 obs) right son=11 (40 obs)
Primary splits:
PriceDiff < -0.165 to the right, improve=20.84307, (0 missing)
ListPriceDiff < 0.235 to the right, improve=20.82404, (0 missing)
SalePriceMM < 1.84 to the right, improve=16.80587, (0 missing)
DiscMM < 0.15 to the left, improve=10.05120, (0 missing)
PctDiscMM < 0.0729725 to the left, improve=10.05120, (0 missing)
Surrogate splits:
SalePriceMM < 1.585 to the right, agree=0.906, adj=0.375, (0 split)
DiscMM < 0.57 to the left, agree=0.895, adj=0.300, (0 split)
PctDiscMM < 0.264375 to the left, agree=0.895, adj=0.300, (0 split)
WeekofPurchase < 274.5 to the left, agree=0.872, adj=0.150, (0 split)
SalePriceCH < 2.075 to the left, agree=0.857, adj=0.050, (0 split)
Node number 10: 226 observations
predicted class=CH expected loss=0.2212389 P(node) =0.2637106
class counts: 176 50
probabilities: 0.779 0.221
Node number 11: 40 observations
predicted class=MM expected loss=0.225 P(node) =0.04667445
class counts: 9 31
probabilities: 0.225 0.775
Una vez tenemos un modelo (o varios) los podemos evaluar en la muestra test con las medidas estándard
<- predict(oj_mdl_cart, newdata = oj_test, type = "class")
pred <- confusionMatrix(pred, oj_test$Purchase)
oj_cm_cart oj_cm_cart
Confusion Matrix and Statistics
Reference
Prediction CH MM
CH 113 13
MM 17 70
Accuracy : 0.8592
95% CI : (0.8051, 0.9029)
No Information Rate : 0.6103
P-Value [Acc > NIR] : 1.265e-15
Kappa : 0.7064
Mcnemar's Test P-Value : 0.5839
Sensitivity : 0.8692
Specificity : 0.8434
Pos Pred Value : 0.8968
Neg Pred Value : 0.8046
Prevalence : 0.6103
Detection Rate : 0.5305
Detection Prevalence : 0.5915
Balanced Accuracy : 0.8563
'Positive' Class : CH
También podemos representar gráficamente la tabla de confusión
plot(oj_test$Purchase, pred,
main = "Clasificación: Predicho vs. Observado",
xlab = "Observado",
ylab = "Predicho")
14.2 Área bajo la curva ROC
También podemos calcular el área bajo la curva ROC. La curva ROC (características operativas del receptor) es otra medida de precisión. Corresponde a un gráfico de la tasa de verdaderos positivos (TPR, sensibilidad) versus la tasa de falsos positivos (FPR, 1 - especificidad) para un conjunto de umbrales. De forma predeterminada, el umbral para predecir la clasificación predeterminada es 0.50, pero podría ser cualquier umbral. La función precrec::evalmod ()
calcula los valores de la matriz de confusión del modelo usando el conjunto de datos test. El AUC en el conjunto de datos test es 0.8848 y podemos calcularlo con varias funciones: pROC::plot.roc ()
, plotROC::geom_roc ()
, yardstick::roc_curve ()
y plotROC
para usar ggplot()
[geometría geom_roc ()
].
Nosotros usaremos pROC
. Para ello necesitamos tener las predicciones como probabilidades para la categoría de referencia. NOTA: El AUC es, pues, una medida útil para casos donde el predictor es binario.
library(pROC)
<- predict(oj_mdl_cart, newdata = oj_test, type = "prob")[,"CH"]
pred2 <- roc(oj_test$Purchase, pred2, print.auc=TRUE,
roc.car ci=TRUE,
plot=TRUE)
14.2.1 Entrenamiento con caret
También podemos ajustar el modelo con la función caret::train ()
. Recordemo que hay dos formas de ajustar los hiperparámetros cuando usamos train ()
:
- establecer el número de valores de parámetros de ajuste a considerar utilizando
tuneLength
, o - establecer ciertos valores para cada parámetro utilizando
tuneGrid
.
ESTRATEGIA: Construiremos el modelo usando una validación cruzada de 10 veces para optimizar el hiperparámetro CP. Si no tenemos idea de cuál es el parámetro de ajuste óptimo, empezaremos con tuneLength
para aproximarnos al valor óptimo y luego ajustaremos el valor con tuneGrid
. Crearemos un objeto de control de entrenamiento que puedo reutilizar en otras compilaciones de modelos.
= trainControl (method = "cv",
oj_trControl number = 10,
savePredictions = "final", # guardaremos preds para el valor óptimo del parámetro a tunear
classProbs = TRUE, # probs para las clases además de preds
summaryFunction = twoClassSummary
)
Ahora estimamos el modelo con
set.seed(1234)
<- train(
oj_mdl_cart2 ~ .,
Purchase data = oj_train,
method = "rpart",
tuneLength = 5,
metric = "ROC",
trControl = oj_trControl
)
caret
construye un árbol completo usando los parámetros predeterminados de rpart
que son: índice de división de Gini, al menos 20 observaciones en un nodo para considerar dividirlo, y al menos 6 observaciones en cada nodo. Luego, caret
calcula la precisión para cada valor candidato del hiperparámetro (CP). Estos son los resultados:
oj_mdl_cart2
CART
857 samples
17 predictor
2 classes: 'CH', 'MM'
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 772, 772, 771, 770, 771, 771, ...
Resampling results across tuning parameters:
cp ROC Sens Spec
0.005988024 0.8539885 0.8605225 0.7274510
0.008982036 0.8502309 0.8568578 0.7334225
0.013473054 0.8459290 0.8473149 0.7397504
0.032934132 0.7776483 0.8509071 0.6796791
0.479041916 0.5878764 0.9201379 0.2556150
ROC was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.005988024.
El segundo CP (0.008982036) produce la mayor precisión. Podemos profundizar en el mejor valor de CP usando un tuning grid.
set.seed(1234)
<- train(
oj_mdl_cart2 ~ .,
Purchase data = oj_train,
method = "rpart",
tuneGrid = expand.grid(cp = seq(from = 0.001, to = 0.010, length = 11)),
metric = "ROC",
trControl = oj_trControl
)print(oj_mdl_cart2)
CART
857 samples
17 predictor
2 classes: 'CH', 'MM'
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 772, 772, 771, 770, 771, 771, ...
Resampling results across tuning parameters:
cp ROC Sens Spec
0.0010 0.8513056 0.8529390 0.7182709
0.0019 0.8528471 0.8529753 0.7213012
0.0028 0.8524435 0.8510522 0.7302139
0.0037 0.8533529 0.8510522 0.7421569
0.0046 0.8540042 0.8491292 0.7333333
0.0055 0.8543820 0.8567126 0.7334225
0.0064 0.8539885 0.8605225 0.7274510
0.0073 0.8521076 0.8625181 0.7335116
0.0082 0.8521076 0.8625181 0.7335116
0.0091 0.8502309 0.8568578 0.7334225
0.0100 0.8507262 0.8510885 0.7424242
ROC was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.0055.
El mejor modelo se consigue con CP = 0.0082. A continuación podemos ver las precisiones de validación cruzada para los valores de CP candidatos.
plot(oj_mdl_cart2)
Estos son los resultados para el modelo final:
$finalModel oj_mdl_cart2
n= 857
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 857 334 CH (0.61026838 0.38973162)
2) LoyalCH>=0.48285 537 94 CH (0.82495345 0.17504655)
4) LoyalCH>=0.7648795 271 13 CH (0.95202952 0.04797048) *
5) LoyalCH< 0.7648795 266 81 CH (0.69548872 0.30451128)
10) PriceDiff>=-0.165 226 50 CH (0.77876106 0.22123894) *
11) PriceDiff< -0.165 40 9 MM (0.22500000 0.77500000) *
3) LoyalCH< 0.48285 320 80 MM (0.25000000 0.75000000)
6) LoyalCH>=0.2761415 146 58 MM (0.39726027 0.60273973)
12) SalePriceMM>=2.04 71 31 CH (0.56338028 0.43661972)
24) LoyalCH< 0.303104 7 0 CH (1.00000000 0.00000000) *
25) LoyalCH>=0.303104 64 31 CH (0.51562500 0.48437500)
50) WeekofPurchase>=246.5 52 22 CH (0.57692308 0.42307692)
100) PriceCH< 1.94 35 11 CH (0.68571429 0.31428571) *
101) PriceCH>=1.94 17 6 MM (0.35294118 0.64705882) *
51) WeekofPurchase< 246.5 12 3 MM (0.25000000 0.75000000) *
13) SalePriceMM< 2.04 75 18 MM (0.24000000 0.76000000)
26) SpecialCH>=0.5 14 6 CH (0.57142857 0.42857143) *
27) SpecialCH< 0.5 61 10 MM (0.16393443 0.83606557) *
7) LoyalCH< 0.2761415 174 22 MM (0.12643678 0.87356322) *
rpart.plot(oj_mdl_cart2$finalModel)
Veamos el rendimiento en la muestra test:
<- predict(oj_mdl_cart2, newdata = oj_test, type = "raw")
pred3 <- confusionMatrix(pred3, oj_test$Purchase)
oj_cm_cart2 oj_cm_cart2
Confusion Matrix and Statistics
Reference
Prediction CH MM
CH 117 18
MM 13 65
Accuracy : 0.8545
95% CI : (0.7998, 0.8989)
No Information Rate : 0.6103
P-Value [Acc > NIR] : 4.83e-15
Kappa : 0.6907
Mcnemar's Test P-Value : 0.4725
Sensitivity : 0.9000
Specificity : 0.7831
Pos Pred Value : 0.8667
Neg Pred Value : 0.8333
Prevalence : 0.6103
Detection Rate : 0.5493
Detection Prevalence : 0.6338
Balanced Accuracy : 0.8416
'Positive' Class : CH
La precisión es 0.8545, un poco peor que la 0.8592 del método directo. El AUC es 0.916 que es mejor que el obtenido con el método directo.
<- predict(oj_mdl_cart2, newdata = oj_test, type = "prob")[,"CH"]
pred4 <- roc(oj_test$Purchase, pred4, print.auc=TRUE,
roc.car2 ci=TRUE,
plot=TRUE)
Podemos comparar ambas curvas ROC mediante el test de DeLong
roc.test(roc.car, roc.car2)
DeLong's test for two correlated ROC curves
data: roc.car and roc.car2
Z = -2.4259, p-value = 0.01527
alternative hypothesis: true difference in AUC is not equal to 0
95 percent confidence interval:
-0.056801412 -0.006034547
sample estimates:
AUC of roc1 AUC of roc2
0.8848471 0.9162651
Finalmente, podemos crear fácilmente la gráfica de importancia de variables con la función varImp ()
. La lealtad a la marca es lo más importante, seguida de la diferencia de precio.
plot(varImp(oj_mdl_cart2), main="Importancia de variables con CART (caret)")
Parece que con la estrategia de caret
hemos conseguido un mejor modelo predictivo gracias, sobre todo, a la posibilidad de buscar el mejor hiperparámetro haciend fine tuning.
<- rbind(
oj_scoreboard data.frame(Modelo = "Single Tree",
Accuracy = oj_cm_cart$overall["Accuracy"],
ROC = roc.car$auc),
data.frame(Modelo = "Single Tree (caret)",
Accuracy = oj_cm_cart2$overall["Accuracy"],
ROC = roc.car2$auc)) %>%
arrange(desc(ROC))
::kable(oj_scoreboard, row.names = FALSE) knitr
Modelo | Accuracy | ROC |
---|---|---|
Single Tree (caret) | 0.8544601 | 0.9162651 |
Single Tree | 0.8591549 | 0.8848471 |
14.3 Árboles de regresión
Un árbol de regresión simple se construye de manera similar a un árbol de clasificación simple y, al igual que el árbol de clasificación, rara vez se usan por sí solo (sobre todo en problemas complejos o de big data). De nuevo, basaremos el aprendizaje de esta metodología partiendo de un ejemplo real. Usaremos el conjunto de datos ISLR::Carseats
que pretende predecir las ventas de sillitas de niños para coches (variable Sales
) en 400 tiendas usando 10 variables que contienen información de las características de las sillas.
<- ISLR::Carseats
cs_dat ::skim(cs_dat) skimr
Name | cs_dat |
Number of rows | 400 |
Number of columns | 11 |
_______________________ | |
Column type frequency: | |
factor | 3 |
numeric | 8 |
________________________ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
ShelveLoc | 0 | 1 | FALSE | 3 | Med: 219, Bad: 96, Goo: 85 |
Urban | 0 | 1 | FALSE | 2 | Yes: 282, No: 118 |
US | 0 | 1 | FALSE | 2 | Yes: 258, No: 142 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
Sales | 0 | 1 | 7.50 | 2.82 | 0 | 5.39 | 7.49 | 9.32 | 16.27 | ▁▆▇▃▁ |
CompPrice | 0 | 1 | 124.97 | 15.33 | 77 | 115.00 | 125.00 | 135.00 | 175.00 | ▁▅▇▃▁ |
Income | 0 | 1 | 68.66 | 27.99 | 21 | 42.75 | 69.00 | 91.00 | 120.00 | ▇▆▇▆▅ |
Advertising | 0 | 1 | 6.64 | 6.65 | 0 | 0.00 | 5.00 | 12.00 | 29.00 | ▇▃▃▁▁ |
Population | 0 | 1 | 264.84 | 147.38 | 10 | 139.00 | 272.00 | 398.50 | 509.00 | ▇▇▇▇▇ |
Price | 0 | 1 | 115.80 | 23.68 | 24 | 100.00 | 117.00 | 131.00 | 191.00 | ▁▂▇▆▁ |
Age | 0 | 1 | 53.32 | 16.20 | 25 | 39.75 | 54.50 | 66.00 | 80.00 | ▇▆▇▇▇ |
Education | 0 | 1 | 13.90 | 2.62 | 10 | 12.00 | 14.00 | 16.00 | 18.00 | ▇▇▃▇▇ |
De nuevo, partiremos nuestro conjunto de datos cs_dat
(n = 400) en cs_train
(80%, n = 321) y cs_test
(20%, n = 79).
set.seed(12345)
<- createDataPartition(y = cs_dat$Sales, p = 0.8, list = FALSE)
partition <- cs_dat[partition, ]
cs_train <- cs_dat[-partition, ] cs_test
El primer paso es construir un árbol completo y luego realizar una validación cruzada para ayudar a seleccionar la complejidad de costo óptima (cp). La única diferencia ahora es que usaremos method = "anova"
en la función rpart ()
para poder estimar un árbol de regresión.
set.seed(1234)
<- rpart(Sales ~ ., cs_train, method = "anova")
cs_mdl_cart_full cs_mdl_cart_full
n= 321
node), split, n, deviance, yval
* denotes terminal node
1) root 321 2567.76800 7.535950
2) ShelveLoc=Bad,Medium 251 1474.14100 6.770359
4) Price>=105.5 168 719.70630 5.987024
8) ShelveLoc=Bad 50 165.70160 4.693600
16) Population< 201.5 20 48.35505 3.646500 *
17) Population>=201.5 30 80.79922 5.391667 *
9) ShelveLoc=Medium 118 434.91370 6.535085
18) Advertising< 11.5 88 290.05490 6.113068
36) CompPrice< 142 69 193.86340 5.769420
72) Price>=132.5 16 50.75440 4.455000 *
73) Price< 132.5 53 107.12060 6.166226 *
37) CompPrice>=142 19 58.45118 7.361053 *
19) Advertising>=11.5 30 83.21323 7.773000 *
5) Price< 105.5 83 442.68920 8.355904
10) Age>=63.5 32 153.42300 6.922500
20) Price>=85 25 66.89398 6.160800
40) ShelveLoc=Bad 9 18.39396 4.772222 *
41) ShelveLoc=Medium 16 21.38544 6.941875 *
21) Price< 85 7 20.22194 9.642857 *
11) Age< 63.5 51 182.26350 9.255294
22) Income< 57.5 12 28.03042 7.707500 *
23) Income>=57.5 39 116.63950 9.731538
46) Age>=50.5 14 21.32597 8.451429 *
47) Age< 50.5 25 59.52474 10.448400 *
3) ShelveLoc=Good 70 418.98290 10.281140
6) Price>=107.5 49 242.58730 9.441633
12) Advertising< 13.5 41 162.47820 8.926098
24) Age>=61 17 53.37051 7.757647 *
25) Age< 61 24 69.45776 9.753750 *
13) Advertising>=13.5 8 13.36599 12.083750 *
7) Price< 107.5 21 61.28200 12.240000 *
Las ventas pronosticadas en la raíz son las ventas medias para el conjunto de datos de entrenamiento, 7.5 (los valores corresponden a miles de dolares). La primera división está en ShelveLoc = [Bad, Medium]
vs Good
(calidad). Aquí está el diagrama de árbol sin podar.
rpart.plot(cs_mdl_cart_full, yesno = TRUE)
Cada caja muestra el valor predicho del nodo (media) y la proporción de observaciones que están en el nodo (o nodos secundarios).
rpart ()
estima el árbol completo y utiliza validación cruzada para probar el rendimiento de los posibles hiperparámetros de complejidad. Como antes, printcp ()
muestra los valores de cp candidatos que pueden verse en esta tabla. Estos datos pueden ser utilizados para decidir cómo podar el árbol.
printcp(cs_mdl_cart_full)
Regression tree:
rpart(formula = Sales ~ ., data = cs_train, method = "anova")
Variables actually used in tree construction:
[1] Advertising Age CompPrice Income Population Price ShelveLoc
Root node error: 2567.8/321 = 7.9993
n= 321
CP nsplit rel error xerror xstd
1 0.262736 0 1.00000 1.00635 0.076664
2 0.121407 1 0.73726 0.74888 0.058981
3 0.046379 2 0.61586 0.65278 0.050839
4 0.044830 3 0.56948 0.67245 0.051638
5 0.041671 4 0.52465 0.66230 0.051065
6 0.025993 5 0.48298 0.62345 0.049368
7 0.025823 6 0.45698 0.61980 0.048026
8 0.024007 7 0.43116 0.62058 0.048213
9 0.015441 8 0.40715 0.58061 0.041738
10 0.014698 9 0.39171 0.56413 0.041368
11 0.014641 10 0.37701 0.56277 0.041271
12 0.014233 11 0.36237 0.56081 0.041097
13 0.014015 12 0.34814 0.55647 0.038308
14 0.013938 13 0.33413 0.55647 0.038308
15 0.010560 14 0.32019 0.57110 0.038872
16 0.010000 15 0.30963 0.56676 0.038090
Hay 16 posibles valores de cp en este modelo. El modelo con el parámetro de complejidad más pequeño permite la mayoría de las divisiones (nsplit). El parámetro de mayor complejidad corresponde a un árbol con solo un nodo raíz. rel error
es el SSE relativo al nodo raíz. El SSE del nodo raíz es 2567.76800, por lo que su error rel es 2567.76800 / 2567.76800 = 1.0. Eso significa que el error absoluto del árbol completo (en CP = 0.01) es 0.30963 * 2567.76800 = 795.058. Podemos verificar estos resultados calculando el SSE de los valores predichos del modelo:
data.frame(pred = predict(cs_mdl_cart_full, newdata = cs_train)) %>%
mutate(obs = cs_train$Sales,
sq_err = (obs - pred)^2) %>%
summarise(sse = sum(sq_err))
sse
1 795.0525
La tabla también muestra, xerror
que corresponde al SSE con validación cruzada y xstd
a su error estándar. Si deseamos el error más bajo posible, podaremos el árbol con el SSE relativo más pequeño (xerror
). Si deseamos equilibrar el poder predictivo con la simplicidad, podaremos al árbol más pequeño que esté dentro de 1 SE para el SSE relativo más pequeño. Al igual que en la sección anterior, la tabla CP no es muy útil para encontrar ese árbol, por lo que debemos añadir una columna para visualizar dicha información:
$cptable %>%
cs_mdl_cart_fulldata.frame() %>%
mutate(min_xerror_idx = which.min(cs_mdl_cart_full$cptable[, "xerror"]),
rownum = row_number(),
xerror_cap = cs_mdl_cart_full$cptable[min_xerror_idx, "xerror"] +
$cptable[min_xerror_idx, "xstd"],
cs_mdl_cart_fulleval = case_when(rownum == min_xerror_idx ~ "min xerror",
< xerror_cap ~ "under cap",
xerror TRUE ~ "")) %>%
::select(-rownum, -min_xerror_idx) dplyr
CP nsplit rel.error xerror xstd xerror_cap eval
1 0.26273578 0 1.0000000 1.0063530 0.07666355 0.5947744
2 0.12140705 1 0.7372642 0.7488767 0.05898146 0.5947744
3 0.04637919 2 0.6158572 0.6527823 0.05083938 0.5947744
4 0.04483023 3 0.5694780 0.6724529 0.05163819 0.5947744
5 0.04167149 4 0.5246478 0.6623028 0.05106530 0.5947744
6 0.02599265 5 0.4829763 0.6234457 0.04936799 0.5947744
7 0.02582284 6 0.4569836 0.6198034 0.04802643 0.5947744
8 0.02400748 7 0.4311608 0.6205756 0.04821332 0.5947744
9 0.01544139 8 0.4071533 0.5806072 0.04173785 0.5947744 under cap
10 0.01469771 9 0.3917119 0.5641331 0.04136793 0.5947744 under cap
11 0.01464055 10 0.3770142 0.5627713 0.04127139 0.5947744 under cap
12 0.01423309 11 0.3623736 0.5608073 0.04109662 0.5947744 under cap
13 0.01401541 12 0.3481405 0.5564663 0.03830810 0.5947744 min xerror
14 0.01393771 13 0.3341251 0.5564663 0.03830810 0.5947744 under cap
15 0.01055959 14 0.3201874 0.5710951 0.03887227 0.5947744 under cap
16 0.01000000 15 0.3096278 0.5667561 0.03808991 0.5947744 under cap
Bien, entonces el árbol más simple es el que tiene CP = 0.02599265 (5 divisiones). También podemos usar plotcp () para visualizar la relación entre
xerrory
cp`.
plotcp(cs_mdl_cart_full, upper = "splits")
La línea discontinua se establece en el mínimo xerror + xstd
. El eje superior muestra el número de divisiones en el árbol. El error relativo más pequeño está en CP = 0.01 (15 divisiones), pero el CP máximo debajo de la línea discontinua (una desviación estándar por encima del error mínimo) está en CP = 0.02599265 (5 divisiones). Utilizamos entonces la función prune ()
para podar el árbol especificando el coste-complejidad asociado a este CP.
<- prune(
cs_mdl_cart
cs_mdl_cart_full,cp = cs_mdl_cart_full$cptable[cs_mdl_cart_full$cptable[, 2] == 5, "CP"]
)rpart.plot(cs_mdl_cart, yesno = TRUE)
El indicador más “importante” de ventas es ShelveLoc
. Estos son los valores de importancia del modelo:
$variable.importance %>%
cs_mdl_cartdata.frame() %>%
rownames_to_column(var = "Feature") %>%
rename(Overall = '.') %>%
ggplot(aes(x = fct_reorder(Feature, Overall), y = Overall)) +
geom_pointrange(aes(ymin = 0, ymax = Overall), color = "cadetblue", size = .3) +
theme_minimal() +
coord_flip() +
labs(x = "", y = "", title = "Variable Importance with Simple Regression")
El indicador más importante de ventas es ShelveLoc
, luego Price
, luego Age
(edad media de la población donde está la tienda). Todas estas variables aparecen en el modelo final. CompPrice
(precio del competidor) también es relevante.
El último paso es hacer predicciones sobre el conjunto de datos de validación. Cuando la variable respuesta es continua usamos:
la raíz del error cuadrático medio \(RMSE = \sqrt{(1/2) \sum{(actual - pred)^2}})\) y
el errr absoluto medio \(MAE = (1/n) \sum{|actual - pred|}\)
La diferencia entre ambos es que RMSE penaliza más los errores grandes. Para un árbol de regresión, basta con indicar type="vector") en la función
predict ()` (que es el valor por defecto).
<- predict(cs_mdl_cart, cs_test, type = "vector")
cs_preds_cart
<- RMSE(
cs_rmse_cart pred = cs_preds_cart,
obs = cs_test$Sales
) cs_rmse_cart
[1] 2.363202
El proceso de poda conduce a un error de predicción promedio de 2.363 en el conjunto de datos de prueba. No está mal considerando que la desviación estándar de la variable Sales
es 2.8. Podemos visualizar la relación entre los datos predichos y los observados mediante:
data.frame(Predichos = cs_preds_cart, Observados = cs_test$Sales) %>%
ggplot(aes(x = Observados, y = Predichos)) +
geom_point(alpha = 0.6, color = "cadetblue") +
geom_smooth() +
geom_abline(intercept = 0, slope = 1, linetype = 2) +
labs(title = "Carseats CART, predichos vs observados")
14.3.1 Entrenamiento con caret
También podemos ajustar el modelo con caret::train ()
especificando method = "rpart"
. Construirmos el modelo usando 10-fold CV para optimizar el hiperparámetro CP.
= trainControl(
cs_trControl method = "cv",
number = 10,
savePredictions = "final"
)
Usaremos la misma estrategia que en el caso de los árboles de clasificación,
dejaremos que el modelo busque el mejor parámetro de ajuste de CP con tuneLength
y luego lo ajustaremos con tuneGrid
.
set.seed(1234)
= train(
cs_mdl_cart2 ~ .,
Sales data = cs_train,
method = "rpart",
tuneLength = 5,
metric = "RMSE",
trControl = cs_trControl
)
cs_mdl_cart2
CART
321 samples
10 predictor
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 289, 289, 289, 289, 289, 289, ...
Resampling results across tuning parameters:
cp RMSE Rsquared MAE
0.04167149 2.209383 0.4065251 1.778797
0.04483023 2.243618 0.3849728 1.805027
0.04637919 2.275563 0.3684309 1.808814
0.12140705 2.400455 0.2942663 1.936927
0.26273578 2.692867 0.1898998 2.192774
RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.04167149.
El primer cp (0.04167149) presenta el RMSE más pequeño. Puedemos hacer una búsqueda más fina para mejorar el valor de cp usando un grid:
set.seed(1234)
= train(
cs_mdl_cart2 ~ .,
Sales data = cs_train,
method = "rpart",
tuneGrid = expand.grid(cp = seq(from = 0, to = 0.1, by = 0.01)),
metric = "RMSE",
trControl = cs_trControl
) cs_mdl_cart2
CART
321 samples
10 predictor
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 289, 289, 289, 289, 289, 289, ...
Resampling results across tuning parameters:
cp RMSE Rsquared MAE
0.00 2.055676 0.5027431 1.695453
0.01 2.135096 0.4642577 1.745937
0.02 2.095767 0.4733269 1.699235
0.03 2.131246 0.4534544 1.690453
0.04 2.146886 0.4411380 1.712705
0.05 2.284937 0.3614130 1.837782
0.06 2.265498 0.3709523 1.808319
0.07 2.282630 0.3597216 1.836227
0.08 2.282630 0.3597216 1.836227
0.09 2.282630 0.3597216 1.836227
0.10 2.282630 0.3597216 1.836227
RMSE was used to select the optimal model using the smallest value.
The final value used for the model was cp = 0.
En este ejemplo, parece que el árbol con mejor rendimiento es el que no ha sido podado.
plot(cs_mdl_cart2)
Este sería el modelo final
rpart.plot(cs_mdl_cart2$finalModel)
y estas las variables más importantes
plot(varImp(cs_mdl_cart2), main="Importancia de variables para Regresión")
Como siempre, debemos evaluar el modelo en nuestra muestra test:
<- predict(cs_mdl_cart2, cs_test, type = "raw")
cs_preds_cart2 data.frame(Actual = cs_test$Sales, Predicted = cs_preds_cart2) %>%
ggplot(aes(x = Actual, y = Predicted)) +
geom_point(alpha = 0.6, color = "cadetblue") +
geom_smooth(method = "loess", formula = "y ~ x") +
geom_abline(intercept = 0, slope = 1, linetype = 2) +
labs(title = "Carseats CART, Predicted vs Actual (caret)")
Observamos como el modelo sobreestima en el extremo inferior y subestima en el extremo superior. Podemos calcular el RMSE para estos datos:
<- RMSE(pred = cs_preds_cart2, obs = cs_test$Sales)) (cs_rmse_cart2
[1] 2.298331
Caret mejora las predicciones:
<- rbind(
cs_scoreboard data.frame(Modelo = "Single Tree", RMSE = cs_rmse_cart),
data.frame(Modelo = "Single Tree (caret)", RMSE = cs_rmse_cart2)
%>% arrange(RMSE)
) ::kable(cs_scoreboard, row.names = FALSE) knitr
Modelo | RMSE |
---|---|
Single Tree (caret) | 2.298331 |
Single Tree | 2.363202 |
14.4 Bagged trees
Los CART tiene una capacidad predictiva moderada, es por ello que se han propuesto unos métodos que combinan varios árboles de decisión para producir un mejor rendimiento predictivo que utilizar un solo árbol de decisión. El principio fundamental detrás de estos modelos es que un grupo de predictores débiles puede conseguir un predictor con mejor capacidad predictiva.
Tenemos dos tipos de estrategias:
- Bagging
- Boosting
Bagging (Bootstrap Aggregation) se utiliza cuando nuestro objetivo es reducir la varianza de un árbol de decisión. La idea es crear varios subconjuntos de datos a partir de la muestra de entrenamiento elegida al azar con reemplazamiento. Cada subconjunto de datos se utiliza para entrenar un árbol de decisión. Como resultado, terminamos con un conjunto de diferentes modelos. Se utiliza el promedio de todas las predicciones de diferentes árboles, que es más robusto que considerar un solo árbol de decisión.
En el Boosting se aprende de forma secuencial. Ajustamos árboles consecutivos (muestra aleatoria) y en cada paso, el objetivo es mejorar el error del árbol anterior.
Como hemos dicho anteriormente, el algoritmo bagged construye B árboles decisión usando conjuntos de entrenamiento obtenidos mediante remuestreo y promedia las predicciones resultantes. Estos árboles crecen profundamente y no se podan. Por tanto, cada árbol individual tiene una alta varianza, pero un bajo sesgo. Promediar los B árboles ayuda a reducir la varianza. El valor predicho para una observación es la moda (clasificación) o la media (regresión) de los árboles. B generalmente es igual a ~ 25.
Para un conjunto de entrenamiento de tamaño \(n\), cada árbol se compone de \(\sim (1 - e^{-1})n = .632n\) observaciones únicas in-bag y \(.368n\) out-of-bag. Las observaciones que no han sido seleccionadas en el re-muestreo se usan para evaluar la precisión del modelo. La capacidad glogal del método se obtiene promediando la capacidad de cada árbol. Esto tiene una desventaja obvia y es que si cada árbol tiene un rendimiento deficiente, el rendimiento promedio de muchos árboles seguirá siendo deficiente. Además, otra desventaja de este método es que no existe un árbol único con un conjunto de reglas para interpretar. En consecuencia, no queda claro qué variables son más importantes que otras y en algunos problemas (sobre todo biomédicos) esto puede ser una limitación importante.
14.4.1 Bagging árboles de clasificación
Veamos de nuevo con un ejemplo cómo trabajar con estos métodos. Usaremos de nuevo los datos de zumos de naranja OJ
. Esta vez usaremos un método bagging especificando method="treebag"
. Caret no tiene hiperparámetros para este modelo, por lo que no es necesario usar tuneLegth
ni tuneGrid
. El tamaño de conjunto predeterminado es nbagg = 25
(a veces se puede tunear, pero en este caso lo dejaremos fijo).
set.seed(1234)
<- train(
oj_mdl_bag ~ .,
Purchase data = oj_train,
method = "treebag",
trControl = oj_trControl,
metric = "ROC"
)$finalModel oj_mdl_bag
Bagging classification trees with 25 bootstrap replications
Veamos el rendimiento en el conjunto de datos test.
<- predict(oj_mdl_bag, newdata = oj_test, type = "raw")
pred_bag <- confusionMatrix(pred_bag, oj_test$Purchase)
oj_cm_bag oj_cm_bag
Confusion Matrix and Statistics
Reference
Prediction CH MM
CH 113 16
MM 17 67
Accuracy : 0.8451
95% CI : (0.7894, 0.8909)
No Information Rate : 0.6103
P-Value [Acc > NIR] : 6.311e-14
Kappa : 0.675
Mcnemar's Test P-Value : 1
Sensitivity : 0.8692
Specificity : 0.8072
Pos Pred Value : 0.8760
Neg Pred Value : 0.7976
Prevalence : 0.6103
Detection Rate : 0.5305
Detection Prevalence : 0.6056
Balanced Accuracy : 0.8382
'Positive' Class : CH
La precisión es 0.8451, sorprendentemente peor que el 0.85915 del modelo de árbol único, pero esa es una diferencia que corresponde a tres predicciones en un conjunto de 213. Esta sería la curva ROC.
<- predict(oj_mdl_bag, newdata = oj_test, type = "prob")[,"CH"]
pred_bag2 <- roc(oj_test$Purchase, pred_bag2, print.auc=TRUE,
roc.bag ci=TRUE,
plot=TRUE)
Veamos cuáles son las variables más importantes
plot(varImp(oj_mdl_bag), main="Importancia de variables con Bagging")
Esta es la comparación entre métodos
<- rbind(oj_scoreboard,
oj_scoreboard data.frame(Modelo = "Bagging",
Accuracy = oj_cm_bag$overall["Accuracy"],
ROC = roc.bag$auc)
%>% arrange(desc(ROC))
) ::kable(oj_scoreboard, row.names = FALSE) knitr
Modelo | Accuracy | ROC |
---|---|---|
Single Tree (caret) | 0.8544601 | 0.9162651 |
Bagging | 0.8450704 | 0.9099166 |
Single Tree | 0.8591549 | 0.8848471 |
14.4.2 Bagging árboles de regresión
Usemos bagging para predecir las ventas en los datos Carseats
:
set.seed(1234)
<- train(
cs_mdl_bag ~ .,
Sales data = cs_train,
method = "treebag",
trControl = cs_trControl
) cs_mdl_bag
Bagged CART
321 samples
10 predictor
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 289, 289, 289, 289, 289, 289, ...
Resampling results:
RMSE Rsquared MAE
1.681889 0.675239 1.343427
Veamos el rendimiento en el conjunto de datos test. El RMSE es 1.9185, pero el modelo predice en exceso en el extremo inferior de ventas y tampoco predice bien en el extremo superior (como un árbol simple).
<- bind_cols(
cs_preds_bag Predicted = predict(cs_mdl_bag, newdata = cs_test),
Actual = cs_test$Sales
)<- RMSE(pred = cs_preds_bag$Predicted, obs = cs_preds_bag$Actual)) (cs_rmse_bag
[1] 1.918473
%>%
cs_preds_bag ggplot(aes(x = Actual, y = Predicted)) +
geom_point(alpha = 0.6, color = "cadetblue") +
geom_smooth(method = "loess", formula = "y ~ x") +
geom_abline(intercept = 0, slope = 1, linetype = 2) +
labs(title = "Carseats Bagging, Predicted vs Actual (caret)")
La importancia de las variables son:
plot(varImp(cs_mdl_bag), main="Importancia de variables con Bagging")
Y la comparación quedaría
<- rbind(cs_scoreboard,
cs_scoreboard data.frame(Modelo = "Bagging", RMSE = cs_rmse_bag)
%>% arrange(RMSE)
) ::kable(cs_scoreboard, row.names = FALSE) knitr
Modelo | RMSE |
---|---|
Bagging | 1.918473 |
Single Tree (caret) | 2.298331 |
Single Tree | 2.363202 |
EJERCICIO (Entrega en Moodle: P2-Bagged): |
Implementa una función que implemente el método “bagged tree”. Aplícalo a los datos “Carseats” y compara tus resultados con los que se obtienen usando caret. |
EJERCICIO (Entrega en Moodle: P-Bagged Breast cancer): |
Utiliza los datos de cáncer de mama que hemos trabajado en clase “breast_train_prep” y “breast_test_prep” (que puedes encontrar en el fichero “breast.Rdata” del Moodle) para crear un modelo predictivo usando CART y Bagged Trees. Compara los resultados con los obtenidos mediant KNN y LDA reportados en el bookdown del curso. |
14.5 Random Forest
Los Random Forest (bosques aleatorios) también son un conjunto de árboles de decisión (ensambladores) que mejoran los bagged trees mediante la creación de un bosque no correlacionados de árboles que, de nuevo, mejora la capacidad predictiva de un único árbol. Al igual que en el bagged (embolsado), el algoritmo construye varios árboles de decisión sobre muestras de entrenamiento bootstrap. Sin embargo, al construir estos árboles de decisión, cada vez que se considera una división en un árbol, se elige una muestra aleatoria de predictores (hiperparámetro m o mtry) como candidatos de división del conjunto completo de predictores \(p\). En cada división se toma una nueva muestra de predictores. Típicamente \(m \approx \sqrt{p}\). En consecuencia, los árboles bagged son un caso especial de los random forest cuando \(m = p\).
Cada árbol del modelo random forest se construye de la siguiente forma:
Si denotamos por \(N\) el número de casos en el conjunto de entrenamiento, seleccionaremos una muestra de esos \(N\) casos se forma aleatoria CON REEMPLAZAMIENTO. Esta muestra será el conjunto de entrenamiento para construir el árbol i-ésimo.
Si denotamos por \(M\) el número total de varibles predictoras, seleccionaremos un número \(m < M\) de variables y crearemos un árbol completo con esas variables. El valor \(m\) se mantiene constante durante la generación de todo el bosque.
Cada árbol crece hasta su máxima extensión posible y NO hay proceso de poda.
La predicción para nuevos individuos se hace a partir de la información obtenida de las predicciones de los \(B\) árboles (mayoría de votos para clasificación, promedio para regresión). La siguiente figura ilustra este proceso
Podemos estimar un random forest con cart
indicando el argumento method = "rf"
. El hiperparámetro mtry
(\(m\)) puede tomar cualquier valor de 1 a 17 (el número de predictores) y se espera que el valor óptimo esté cerca de \(\sqrt{17} \approx 4\). En cuanto al número de árboles (segundo hiperparámetro), hay estudios que demuestran que el rendimiento empeora cuando tenemos muchos árboles, sin embargo esto no está muy claro y por lo general se recomienda entrenar modelos con muchos árboles. Por defecto method = "rf"
tiene 500 (argumento num.trees
).
set.seed(1234)
<- train(
oj_mdl_rf ~ .,
Purchase data = oj_train,
method = "rf",
metric = "ROC",
tuneGrid = expand.grid(mtry = 3:10),
trControl = oj_trControl,
num.trees = 500
) oj_mdl_rf
Random Forest
857 samples
17 predictor
2 classes: 'CH', 'MM'
No pre-processing
Resampling: Cross-Validated (10 fold)
Summary of sample sizes: 772, 772, 771, 770, 771, 771, ...
Resampling results across tuning parameters:
mtry ROC Sens Spec
3 0.8655672 0.8565312 0.7185383
4 0.8685845 0.8641872 0.7122995
5 0.8682630 0.8470247 0.7183601
6 0.8672458 0.8412917 0.7124777
7 0.8695796 0.8412917 0.7183601
8 0.8668721 0.8393687 0.7213012
9 0.8652269 0.8432148 0.7153298
10 0.8671443 0.8413280 0.7152406
ROC was used to select the optimal model using the largest value.
The final value used for the model was mtry = 7.
El valor de ROC más alto se da en \(m = 7\) que es más alto de lo que esperábamos, pero fijémosnos que es un valor de ROC muy similar al que se obtiene con \(m=4\), por lo que por el principio de parsimonia podríamos usar dicho valor
plot(oj_mdl_rf)
También podemos visualizar los resultados con:
<- function(model) {
plot_rf theme_set(theme_minimal())
<- model$results %>%
u ::select(mtry, ROC, Sens, Spec) %>%
dplyrgather(a, b, -mtry)
%>% ggplot(aes(mtry, b)) + geom_line() + geom_point() +
u facet_wrap(~ a, scales = "free") +
labs(x = "Número de predictores", y = NULL,
title = "Relación entre el número de predictores y el comportamiento del modelo")
}
%>% plot_rf() oj_mdl_rf
EJERCICIO (Entrega en Moodle: P2-Plot RF): |
Si en el modelo train usamos metric = "prSummary" también podemos visualizar otras métricas para evaluar el comportamiento del modelo. En particular el AUC que puede ser interesante desde un punto de vista práctico. Crea una función que visualice dichas métricas y aplícalo a los datos oj_train (que obviamente tendrás que re-entrenar con esa nueva métrica). |
Podemos usar este modelo para hacer predicciones sobre la muestra test
<- predict(oj_mdl_rf, newdata = oj_test, type = "raw")
pred_rf <- confusionMatrix(pred_rf, oj_test$Purchase)
oj_cm_rf oj_cm_rf
Confusion Matrix and Statistics
Reference
Prediction CH MM
CH 112 16
MM 18 67
Accuracy : 0.8404
95% CI : (0.7841, 0.8869)
No Information Rate : 0.6103
P-Value [Acc > NIR] : 2.164e-13
Kappa : 0.6659
Mcnemar's Test P-Value : 0.8638
Sensitivity : 0.8615
Specificity : 0.8072
Pos Pred Value : 0.8750
Neg Pred Value : 0.7882
Prevalence : 0.6103
Detection Rate : 0.5258
Detection Prevalence : 0.6009
Balanced Accuracy : 0.8344
'Positive' Class : CH
Y el área bajo la curva ROC sería:
<- predict(oj_mdl_rf, newdata = oj_test, type = "prob")[,"CH"]
pred_rf2 <- roc(oj_test$Purchase, pred_bag2, print.auc=TRUE,
roc.rf ci=TRUE,
plot=TRUE)
que compara con los modelos anteriores de esta forma:
Y la comparación quedaría
<- rbind(oj_scoreboard,
oj_scoreboard data.frame(Modelo = "Random Forest",
Accuracy = oj_cm_rf$overall["Accuracy"],
ROC = roc.rf$auc)
%>% arrange(desc(ROC))
) ::kable(oj_scoreboard, row.names = FALSE) knitr
Modelo | Accuracy | ROC |
---|---|---|
Single Tree (caret) | 0.8544601 | 0.9162651 |
Bagging | 0.8450704 | 0.9099166 |
Random Forest | 0.8403756 | 0.9099166 |
Single Tree | 0.8591549 | 0.8848471 |
Recordemos que la importancia de las variables se puede ver con la función varImp
al igual que cualquier otro modelo basado en àrboles de decisión.
NOTA: El ejemplo para Random Forest con árboles de regresión es igual que lo que vimos en la sección anterior.
NOTA2: En este artículo se hace un “benchmarking” muy interesante para saber qué método y libería de R usar en función de las características de nuestro conjunto de datos.
Speiser JL et al. (2012). A Comparison of Random Forest Variable Selection Methods for Classification Prediction Modeling
Random forest classification is a popular machine learning method for developing prediction models in many research settings. Often in prediction modeling, a goal is to reduce the number of variables needed to obtain a prediction in order to reduce the burden of data collection and improve efficiency. Several variable selection methods exist for the setting of random forest classification; however, there is a paucity of literature to guide users as to which method may be preferable for different types of datasets. Using 311 classification datasets freely available online, we evaluate the prediction error rates, number of variables, computation times and area under the receiver operating curve for many random forest variable selection methods. We compare random forest variable selection methods for different types of datasets (datasets with binary outcomes, datasets with many predictors, and datasets with imbalanced outcomes) and for different types of methods (standard random forest versus conditional random forest methods and test based versus performance based methods). Based on our study, the best variable selection methods for most datasets are Jiang’s method and the method implemented in the VSURF R package. For datasets with many predictors, the methods implemented in the R packages varSelRF and Boruta are preferable due to computational efficiency. A significant contribution of this study is the ability to assess different variable selection techniques in the setting of random forest classification in order to identify preferable methods based on applications in expert and intelligent systems.
EJERCICIO (Entrega en Moodle: P-RF Breast cancer): |
Utiliza los datos de cáncer de mama que hemos trabajado en clase “breast_train_prep” y “breast_test_prep” (que puedes encontrar en el fichero “breast.Rdata” del Moodle) para crear un modelo predictivo usando Random Forest. Compara los resultados con los obtenidos mediante KNN y LDA reportados en el bookdown del curso (y di si mejora los que obtuviste con CART y Bagged Trees). |
14.6 Random Forest p>>n
Problema: Aplicar Random Forest para conjunto de datos con muchas variables (caso p>>n)
- Posible estrategia:
- Creamos K subconjuntos de variable
- Llevamos a cabo una selección de las variables más importantes y nos quedamos con una parte de ellas. Por ejemplo, las M más informativas
- Combinamos las K*M variables y repetimos los pasos 1 y 2
- Acabamos con M variables seleccionadas
- Aplicamos Random Forest
Este enfoque podría causar la pérdida de algunas variables importantes, pero generalmente seleccionará las variables más informativas.
Selección de K y M Breiman (2001) recomienda \(m=p/3\) en clasificación y \(m=\sqrt{p}\) en regresión (mtry). ¿Puede servir esto de ayuda?
Otra opción: Librería ranger