web-dev-qa-db-ja.com

lightgbmのf1_scoreメトリック

カスタム指標:f1_scoreweighted averageでlgbモデルをトレーニングしたい。

here でlightgbmの高度な例を調べたところ、カスタムのバイナリエラー関数の実装がわかりました。以下に示すように、同様の機能としてf1_scoreを返すように実装しました。

def f1_metric(preds, train_data):

    labels = train_data.get_label()

    return 'f1', f1_score(labels, preds, average='weighted'), True

以下に示すように、fevalパラメータをf1_metricとして渡して、モデルをトレーニングしようとしました。

evals_results = {}

bst = lgb.train(params, 
                     dtrain, 
                     valid_sets= [dvalid], 
                     valid_names=['valid'], 
                     evals_result=evals_results, 
                     num_boost_round=num_boost_round,
                     early_stopping_rounds=early_stopping_rounds,
                     verbose_eval=25, 
                     feval=f1_metric)

その後、ValueError: Found input variables with inconsistent numbers of samples:を取得しています

トレーニングセットが検証セットではなく関数に渡されています。

検証セットが渡され、f1_scoreが返されるように構成するにはどうすればよいですか?

4
Sreeram TP

ドキュメントは少し混乱しています。 fevalに渡す関数のシグニチャーを説明するとき、それらはそのパラメーターpredsおよびtrain_dataを呼び出しますが、これは少し誤解を招きます。

しかし、次のように機能するようです:

from sklearn.metrics import f1_score

def lgb_f1_score(y_hat, data):
    y_true = data.get_label()
    y_hat = np.round(y_hat) # scikits f1 doesn't like probabilities
    return 'f1', f1_score(y_true, y_hat), True

evals_result = {}

clf = lgb.train(param, train_data, valid_sets=[val_data, train_data], valid_names=['val', 'train'], feval=lgb_f1_score, evals_result=evals_result)

lgb.plot_metric(evals_result, metric='f1')

複数のカスタム指標を使用するには、上記と同様に、1つの全体的なカスタム指標関数を定義します。この関数では、すべての指標を計算し、タプルのリストを返します。

編集:修正されたコードはもちろん、F1が大きいほどTrueに設定する必要があります。

12
Toby

トビーの答えについて:

def lgb_f1_score(y_hat, data):
    y_true = data.get_label()
    y_hat = np.round(y_hat) # scikits f1 doesn't like probabilities
    return 'f1', f1_score(y_true, y_hat), True

Y_hatの部分を次のように変更することをお勧めします。

y_hat = np.where(y_hat < 0.5, 0, 1)  

理由:私はy_hat = np.round(y_hat)を使用しましたが、トレーニング中にlightgbmモデルがy予測をバイナリではなくマルチクラスと見なすことがある(非常にありそうもないが、それでも変更される)ことを確認しました。

私の推測:時々、y予測は負の値または2に丸めるのに十分なほど小さいか、わかりませんが、よくわかりませんが、npを使用してコードを変更したところ、バグがなくなりました。

Np.whereソリューションが優れているかどうかは確かではありませんが、このバグを理解するのに朝かかります。

0
GISH