web-dev-qa-db-ja.com

GBM R関数:クラスごとに個別に変数の重要度を取得します

R(gbmパッケージ)の gbm 関数を使用して、マルチクラス分類の確率的勾配ブースティングモデルを適合させています。 Hastie book(Elements of Statistics Learning) のこの写真のように、クラスごとに各予測子の重要性を個別に取得しようとしています。 (p.382)。

enter image description here

ただし、関数summary.gbmは、予測子の全体の重要度のみを返します(すべてのクラスで平均された重要度)。

誰かが相対的な重要度の値を取得する方法を知っていますか?

18
Antoine

簡単な答えは、379ページで、Hastieが [〜#〜] mart [〜#〜] を使用していると述べていることだと思います。これは、Splusでのみ使用できるようです。

私は、gbmパッケージが個別の相対的な影響を見ることを可能にしないように思われることに同意します。それがマルチクラス問題に関心がある場合は、クラスごとに1対すべてのgbmを構築し、それらの各モデルから重要度の測定値を取得することで、かなり似たものを得ることができます。

したがって、クラスがa、b、c、dであるとします。対残りをモデル化し、そのモデルから重要性を取得します。次に、bと残りのモデルを比較し、そのモデルから重要性を取得します。等。

11
Tchotchke

この機能がお役に立てば幸いです。この例では、ElemStatLearnパッケージのデータを使用しました。関数は、列のクラスが何であるかを理解し、データをこれらのクラスに分割し、各クラスでgbm()関数を実行し、これらのモデルの棒グラフをプロットします。

# install.packages("ElemStatLearn"); install.packages("gbm")
library(ElemStatLearn)
library(gbm)

set.seed(137531)

# formula: the formula to pass to gbm()
# data: the data set to use
# column: the class column to use
classPlots <- function (formula, data, column) {

    class_column <- as.character(data[,column])
    class_values <- names(table(class_column))
    class_indexes <- sapply(class_values, function(x) which(class_column == x))
    split_data <- lapply(class_indexes, function(x) marketing[x,])
    object <- lapply(split_data, function(x) gbm(formula, data = x))
    rel.inf <- lapply(object, function(x) summary.gbm(x, plotit=FALSE))

    nobjs <- length(class_values)
    for( i in 1:nobjs ) {
        tmp <- rel.inf[[i]]
        tmp.names <- row.names(tmp)
        tmp <- tmp$rel.inf
        names(tmp) <- tmp.names

        barplot(tmp, horiz=TRUE, col='red',
                xlab="Relative importance", main=paste0("Class = ", class_values[i]))
    }
    rel.inf
}

par(mfrow=c(1,2))
classPlots(Income ~ Marital + Age, data = marketing, column = 2)

`

output

7
nathanesau

Gbmパッケージが重要度を計算する方法を掘り下げました。これは、結果のtrees要素に含まれているErrorReductionに基づいており、pretty.gbm.trees()でアクセスできます。相対的な影響は、各変数のすべてのツリーでこのErrorReductionの合計をとることによって得られます。マルチクラス問題の場合、実際にはモデルにn.trees*num.classesツリーがあります。したがって、3つのクラスがある場合、3つのツリーごとに各変数のErrorReductionの合計を計算して、1つのクラスの重要度を取得できます。これを実装して結果をプロットするために、次の関数を作成しました。

クラスごとに変数の重要度を取得

RelInf_ByClass <- function(object, n.trees, n.classes, Scale = TRUE){
  library(dplyr)
  library(purrr)
  library(gbm)
  Ext_ErrRed<- function(ptree){
    ErrRed <- ptree %>% filter(SplitVar != -1) %>% group_by(SplitVar) %>% 
      summarise(Sum_ErrRed = sum(ErrorReduction))
  }
  trees_ErrRed <- map(1:n.trees, ~pretty.gbm.tree(object, .)) %>% 
    map(Ext_ErrRed)

  trees_by_class <- split(trees_ErrRed, rep(1:n.classes, n.trees/n.classes)) %>% 
    map(~bind_rows(.) %>% group_by(SplitVar) %>% 
          summarise(rel_inf = sum(Sum_ErrRed)))
  varnames <- data.frame(Num = 0:(length(object$var.names)-1),
                         Name = object$var.names)
  classnames <- data.frame(Num = 1:object$num.classes, 
                           Name = object$classes)
  out <- trees_by_class %>% bind_rows(.id = "Class") %>%  
    mutate(Class = classnames$Name[match(Class,classnames$Num)],
    SplitVar = varnames$Name[match(SplitVar,varnames$Num)]) %>%
    group_by(Class) 
  if(Scale == FALSE){
    return(out)
    } else {
    out <- out %>% mutate(Scaled_inf = rel_inf/max(rel_inf)*100)
    }
}

クラスごとに変数の重要度をプロットする

これを実際に使用する場合、40を超えるフィーチャがあるため、プロットするフィーチャの数を指定するオプションを提供します。また、クラスごとにプロットを個別に並べ替えたい場合はファセットを使用できませんでした。そのため、gridExtraを使用しました。

plot_imp_byclass <- function(df, n) {
  library(ggplot2)
  library(gridExtra)
  plot_imp_class <- function(df){
    df %>% arrange(rel_inf) %>% 
      mutate(SplitVar = factor(SplitVar, levels = .$SplitVar)) %>% 
      ggplot(aes(SplitVar, rel_inf))+
      geom_segment(aes(x = SplitVar, 
                       xend = SplitVar, 
                       y = 0, 
                       yend = rel_inf))+
      geom_point(size=3, col = "cyan") + 
      coord_flip()+
      labs(title = df$Class[[1]], x = "Variable", y = "Importance")+
      theme_classic()+
      theme(plot.title = element_text(hjust = 0.5))
  }

  df %>% top_n(n, rel_inf) %>% split(.$Class) %>% 
    map(plot_imp_class) %>% map(ggplotGrob) %>% 
    {grid.arrange(grobs = .)}
}

それを試してみてください

gbm_iris <- gbm(Species~., data = iris)
imp_byclass <- RelInf_ByClass(gbm_iris, length(gbm_iris$trees), 
                              gbm_iris$num.classes, Scale = F)
plot_imp_byclass(imp_byclass, 4)

すべてのクラスの結果を合計すると、組み込みのrelative.influence関数と同じ結果が得られるようです。

relative.influence(gbm_iris)
# n.trees not given. Using 100 trees.
# Sepal.Length  Sepal.Width Petal.Length  Petal.Width 
# 0.00000     51.88684   2226.88017    868.71085 

imp_byclass %>% group_by(SplitVar) %>% summarise(Overall_rel_inf = sum(rel_inf))
# A tibble: 3 x 2
# SplitVar     Overall_rel_inf
# <fct>                  <dbl>
#   1 Petal.Length          2227. 
# 2 Petal.Width            869. 
# 3 Sepal.Width             51.9
2
see24