XGBoost/lightGBM如何评估ndcg的排名任务?

问题描述:

我目前正在XGBoost/lightGBM之间进行测试,以便对项目进行排名。我正在复制这里提供的基准:https://github.com/guolinke/boosting_tree_benchmarksXGBoost/lightGBM如何评估ndcg的排名任务?

我已经能够成功地重现他们的工作中提到的基准。我想确保我正确实现了我自己的ndcg指标版本,并正确理解排名问题。

我的问题是:

  1. 当使用NDCG创建测试集验证 - 有一个test.group文件说,第一个X行是组0等。为了得到建议我得到预测值和已知的相关分数,并按照每个组的预测值降序对该列表进行排序?

  2. 为了从上面创建的列表中获得最终的ndcg分数 - 我是否获得ndcg分数并对所有分数取平均值?这与XGBoost/lightGBM在评估阶段的评估方法是否相同?

这是我在模型完成培训后评估测试集的方法。

对于最终的树,当我运行lightGBM我获得验证集这些值:

[500] valid_0's [email protected]: 0.513221 valid_0's [email protected]: 0.499337 valid_0's [email protected]: 0.505188 valid_0's [email protected]: 0.523407 

我的最后一步是走对测试集的预测输出和计算预测的NDCG值。

这里是计算NDCG我的Python代码:

import numpy as np 

def dcg_at_k(r, k): 
    r = np.asfarray(r)[:k] 
    if r.size: 
     return np.sum(np.subtract(np.power(2, r), 1)/np.log2(np.arange(2, r.size + 2))) 
    return 0. 


def ndcg_at_k(r, k): 
    idcg = dcg_at_k(sorted(r, reverse=True), k) 
    if not idcg: 
     return 0. 
    return dcg_at_k(r, k)/idcg 

后,我得到了一组特定的测试集的预测(GROUP-0)我有这些预言:

query_id predict 
0 0 (2.0, -0.221681199441) 
1 0 (1.0, 0.109895548348) 
2 0 (1.0, 0.0262799346312) 
3 0 (0.0, -0.595343431322) 
4 0 (0.0, -0.52689043426) 
5 0 (0.0, -0.542221350664) 
6 0 (1.0, -0.448015576024) 
7 0 (1.0, -0.357090949646) 
8 0 (0.0, -0.279677741045) 
9 0 (0.0, 0.2182200869) 

注意

集团0实际上有大约112行。

我再排序元组的列表按降序排列,其提供的相关评分列表:

def get_recommendations(x): 

    sorted_list = sorted(list(x), key=lambda i: i[1], reverse=True) 
    return [k for k, _ in sorted_list] 

relavance = evaluation.groupby('query_id').predict.apply(get_recommendations) 

query_id 
0 [4.0, 2.0, 2.0, 3.0, 2.0, 2.0, 2.0, 2.0, 2.0, ... 
1 [4.0, 2.0, 2.0, 2.0, 1.0, 1.0, 3.0, 2.0, 1.0, ... 
2 [2.0, 3.0, 2.0, 2.0, 1.0, 0.0, 2.0, 2.0, 1.0, ... 
3 [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, ... 
4 [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, ... 

最后,对于每个查询ID我计算的相关名单上的NDCG分数,然后取平均值所有NDCG得分为每个查询ID计算:

relavance.apply(lambda x: ndcg_at_k(x, 10)).mean() 

我得到的值是~0.497193

我认为问题是由同一个查询中具有相同标签的数据造成的。 在这种情况下,XGBoost和LightGBM都会为该查询生成ndcg 1。