决策树中的类别特征问题(关于label encode还是one-hot的讨论)

就决策树来说,算法本身是(为数不多的)天然支持categorical feature的机器学习算法,但是如果是high cardinality,那么理论上最优的split要遍历所有二分组合,是指数级的复杂度,Python的implementation只解决数值型feature,把这个难题丢给用户了,xgboost也是一样,作者的解释是为了给用户更多自主权决定如何处理categorical feature。

one-hot不是一个好的方案

one-hot coding是类别特征的一种通用解决方法,然而在树模型里面,这并不是一个比较好的方案,尤其当类别特征维度很高的时候。主要的问题是:

  1. 可能无法在这个类别特征上进行切分。使用one-hot coding的话,意味着在每一个决策节点上只能用 one-vs-rest (例如是不是狗,是不是猫,等等) 的切分方式。当特征纬度高时,每个类别上的数据都会比较少,这时候产生的切分不平衡,切分增益(split gain)也会很小(比较直观的理解是,不平衡的切分和不切分几乎没有区别)。
  2. 会影响决策树的学习。因为就算可以在这个类别特征进行切分,也会把数据切分到很多零散的小空间上,如图1左所示。而决策树学习时利用的是统计信息,在这些数据量小的空间上,统计信息不准确,学习会变差。但如果使用图1右边的切分方法,数据会被切分到两个比较大的空间,进一步的学习也会更好。

决策树中的类别特征问题(关于label encode还是one-hot的讨论)

在决策树上类别特征的处理方法

  1. 类别特征的最优切分。这个方法需要对应工具的支持,支持这个方法的工具有h2o.gbm和LightGBM。用LightGBM可以直接输入类别特征,并产生同图1右边的最优切分。在一个k维的类别特征寻找最优切分,朴素的枚举算法的复杂度是指数的 O(2^k)。LightGBM 用了一个 O(klogk)的算法。算法流程如下图所示:在枚举分割点之前,先把直方图按照每个类别对应的label均值进行排序;然后按照排序的结果依次枚举最优分割点。当然,这个方法很容易过拟合,所以LightGBM里面还增加了很多对于这个方法的约束和正则化。klogk的算法只能用于 regression 或 binary classification,在 multi-class classification 上是np-hard的。不过multi-class classification 也可以用one-vs-rest(多个binary classification)的方法来解决,在GBDT里面一般也是这么解决multi-class问题的。决策树中的类别特征问题(关于label encode还是one-hot的讨论)决策树中的类别特征问题(关于label encode还是one-hot的讨论)
  2. 转成数值特征。在使用 sklearn 或 XGBoost 等不支持类别特征的最优切分工具时,可以用这个方法。常见的转换方法有: a) 把类别特征转成one-hot coding扔到NN里训练个embedding;b) 类似于CTR特征,统计每个类别对应的label(训练目标)的均值。统计的时候有一些小技巧,比如不把自身的label算进去(leave-me-out, leave-one-out)统计, 防止信息泄露。
  3. 其他的编码方法,比如binary coding等等,同样可以用于不支持类别特征的算法。这里有一个比较好的开源项目,封装了常见的各种编码方法: https://github.com/scikit-learn-contrib/categorical-encoding

具体处理方法

  • label encoding
    • 特征存在内在顺序 (ordinal feature)
  • one hot encoding
    • 特征无内在顺序,category数量 < 4
  • target encoding (mean encoding, likelihood encoding, impact encoding)
    • 特征无内在顺序,category数量 > 4
  • beta target encoding
    • 特征无内在顺序,category数量 > 4, K-fold cross validation
  • 不做处理(模型自动编码)
    • CatBoost(leave-one-out统计label均值),lightgbm(最优切分点)

1. Label encoding

对于一个有m个category的特征,经过label encoding以后,每个category会映射到0到m-1之间的一个数。label encoding适用于ordinal feature (特征存在内在顺序)。

2. One-hot encoding (OHE)

对于一个有m个category的特征,经过独热编码(OHE)处理后,会变为m个二元特征,每个特征对应于一个category。这m个二元特征互斥,每次只有一个**。

独热编码解决了原始特征缺少内在顺序的问题,但是缺点是对于high-cardinality categorical feature (category数量很多),编码之后特征空间过大(此处可以考虑PCA降维),而且由于one-hot feature 比较unbalanced,树模型里每次的切分增益较小,树模型通常需要grow very deep才能得到不错的精度。因此OHE一般用于category数量 <4的情况。

3. Target encoding (or likelihood encoding, impact encoding, mean encoding)

Target encoding 采用 target mean value (among each category) 来给categorical feature做编码。为了减少target variable leak,主流的方法是使用2 levels of cross-validation求出target mean,思路如下:

  • 把train data划分为20-folds (举例:infold: fold #2-20, out of fold: fold #1)
    • 将每一个 infold (fold #2-20) 再次划分为10-folds (举例:inner_infold: fold #2-10, Inner_oof: fold #1)
      • 计算 10-folds的 inner out of folds值 (举例:使用inner_infold #2-10 的target的均值,来作为inner_oof #1的预测值)
      • 对10个inner out of folds 值取平均,得到 inner_oof_mean
    • 计算oof_mean (举例:使用 infold #2-20的inner_oof_mean 来预测 out of fold #1的oof_mean
  • 将train data 的 oof_mean 映射到test data完成编码

4. beta target encoding

kaggle竞赛Avito Demand Prediction Challenge 第14名的solution分享: 14th Place Solution: The Almost Golden Defenders

和target encoding 一样,beta target encoding 也采用 target mean value (among each category) 来给categorical feature做编码。不同之处在于,为了进一步减少target variable leak,beta target encoding发生在在5-fold CV内部,而不是在5-fold CV之前:

  • 把train data划分为5-folds (5-fold cross validation)
    • target encoding based on infold data
    • train model
    • get out of fold prediction

同时beta target encoding 加入了smoothing term,用 bayesian mean 来代替mean。Bayesian mean (Bayesian average) 的思路: 某一个category如果数据量较少(<N_min),noise就会比较大,需要补足数据,达到smoothing 的效果。补足数据值 = prior mean。N_min 是一个regularization term,N_min 越大,regularization效果越强。

另外,对于target encoding和beta target encoding,不一定要用target mean (or bayesian mean),也可以用其他的统计值包括 medium, frqequency, mode, variance, skewness, and kurtosis — 或任何与target有correlation的统计值。

5. 不做任何处理(模型自动编码)

参考

[1]https://www.zhihu.com/question/266195966

[2]https://zhuanlan.zhihu.com/p/40231966(代码实现)

[3]https://zhuanlan.zhihu.com/p/26308272

Tagged DataminingkaggleMachinelearning