文系データサイエンティストの備忘録

英語できないのに外資系で働くデータ分析屋。

決定木から分岐条件を取得する

Rのrpartなどで決定木分析をした際、決定木の分岐条件を表にしてまとめたい時があるかと思います。 決定木を普通に実行するとこんな感じ

> # mtcarsを読込
> tmp <- data.frame(mtcars)
> 
> # シリンダー数,V/S,オートマ・マニュアル,ギア数,キャブ数をファクター化
> df <- tmp %>% 
+   mutate_each_(funs(as.factor), list("cyl","vs","am","gear","carb"))
> 
> # 燃費を目的変数にして回帰木生成
> ct <- rpart(mpg ~ ., data = df, method = "anova")
> (ct.party <- as.party(ct))

Model formula:
mpg ~ cyl + disp + hp + drat + wt + qsec + vs + am + gear + carb

Fitted party:
[1] root
|   [2] cyl in 6, 8
|   |   [3] hp >= 192.5: 13.414 (n = 7, err = 28.8)
|   |   [4] hp < 192.5: 18.264 (n = 14, err = 59.9)
|   [5] cyl in 4: 26.664 (n = 11, err = 203.4)

Number of inner nodes:    2
Number of terminal nodes: 3

ただこの分岐条件を使ってさらに分析を進めたい場合、この決定木から分岐条件を抜き出したい。 そのような場合、ベクトルとして分岐条件をテキストで抽出することが可能。

> # 条件をベクトルで抜き出し
> (cond <- partykit:::.list.rules.party(ct.party, i = nodeids(ct.party)))
                                       1 
                                      "" 
                                       2 
              "cyl %in% c(\"6\", \"8\")" 
                                       3 
"cyl %in% c(\"6\", \"8\") & hp >= 192.5" 
                                       4 
 "cyl %in% c(\"6\", \"8\") & hp < 192.5" 
                                       5 
                     "cyl %in% c(\"4\")" 

このままだとまだ扱いにくいので、分岐条件の変数と数値がそれぞれ列を持っている表を目指して整形する。

> # 決定木分岐条件一覧表に整形
> cond.df <- data.frame(INDEX=1:length(cond), CONDITIONS=cond)
> cond.df.split <- cond.df %>%
+   tidyr::separate(col=CONDITIONS, 
+                   into=paste("CONDITIONS",as.character(c(1:(1+max(str_count(.$CONDITIONS,(" & ")))))),sep=""), 
+                   sep=(" & "), remove=TRUE, extra='merge', fill='right')%>% 
+   tidyr::gather(key.con, CONDITIONS, -INDEX) %>% 
+   dplyr::filter(CONDITIONS != "NA") %>% 
+   dplyr::arrange(INDEX) %>% 
+   tidyr::separate(col=CONDITIONS, 
+                   into=c("CONDITIONS", "SIGN", "VALUE"),
+                   sep=(" "), remove=TRUE, extra='merge', fill='right') %>% 
+   dplyr::mutate(rVALUE= ifelse(test= SIGN=="<" | SIGN=="<=",
+                                yes= VALUE,
+                                no= Inf)) %>% 
+   dplyr::mutate(lVALUE= ifelse(test= SIGN==">" | SIGN==">=" | SIGN=="%in%",
+                                yes= VALUE,
+                                no= -Inf)) %>% 
+   dplyr::mutate(inVALUE= ifelse(test= SIGN=="%in%",
+                                yes= VALUE,
+                                no= NA)) %>% 
+   dplyr::mutate(rVALUE=as.numeric(rVALUE), lVALUE=as.numeric(lVALUE)) %>% 
+   dplyr::group_by(INDEX, CONDITIONS) %>% 
+   dplyr::summarise(rVALUE_min=min(rVALUE), lVALUE_max=max(lVALUE)) %>% 
+   dplyr::ungroup(.) %>% 
+   dplyr::mutate(VALUE= str_c(lVALUE_max, rVALUE_min, sep="~")) %>% 
+   dplyr::mutate(VALUE= str_replace_all(VALUE, c("-Inf"="","Inf"="","c[(]"="","[)]~"="","\""=""))) %>% 
+   dplyr::select(INDEX, CONDITIONS, VALUE)
> (cond.df.split)
# A tibble: 7 x 3
  INDEX CONDITIONS  VALUE
  <int>      <chr>  <chr>
1     1                 ~
2     2        cyl   6, 8
3     3        cyl   6, 8
4     3         hp 192.5~
5     4        cyl   6, 8
6     4         hp ~192.5
7     5        cyl      4

こんな感じ。 (9/15バグ修正)