交差検定を使用したロジスティック回帰モデルから確率を予測したいと思います。交差検証スコアを取得できることはわかっていますが、スコアの代わりにpredict_probaから値を返すことは可能ですか?
# imports
from sklearn.linear_model import LogisticRegression
from sklearn.cross_validation import (StratifiedKFold, cross_val_score,
train_test_split)
from sklearn import datasets
# setup data
iris = datasets.load_iris()
X = iris.data
y = iris.target
# setup model
cv = StratifiedKFold(y, 10)
logreg = LogisticRegression()
# cross-validation scores
scores = cross_val_score(logreg, X, y, cv=cv)
# predict probabilities
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y)
logreg.fit(Xtrain, ytrain)
proba = logreg.predict_proba(Xtest)
これは現在、scikit-learnバージョン0.18の一部として実装されています。 'method'文字列パラメーターをcross_val_predictメソッドに渡すことができます。ドキュメントは here です。
例:
proba = cross_val_predict(logreg, X, y, cv=cv, method='predict_proba')
また、これは新しいsklearn.model_selectionパッケージの一部であるため、このインポートが必要になることに注意してください。
from sklearn.model_selection import cross_val_predict
これの簡単な回避策は、ラッパークラスを作成することです。
class proba_logreg(LogisticRegression):
def predict(self, X):
return LogisticRegression.predict_proba(self, X)
次に、そのインスタンスを分類子オブジェクトとしてcross_val_predict
に渡します。
# cross validation probabilities
probas = cross_val_predict(proba_logreg(), X, y, cv=cv)
関数cross_val_predict
これは予測値を提供しますが、「predict_proba」にはそのような関数はまだありません。多分それをオプションにすることができるでしょう。
これは簡単に実装できます。
def my_cross_val_predict(
m, X, y, cv=KFold(),
predict=lambda m, x: m.predict_proba(x),
combine=np.vstack
):
preds = []
for train, test in cv.split(X):
m.fit(X[train, :], y[train])
pred = predict(m, X[test, :])
preds.append(pred)
return combine(preds)
これは、predict_probaを返します。 predictとpredict_probaの両方が必要な場合は、predict
およびcombine
引数を変更するだけです。
def stack(arrs):
if arrs[0].ndim == 1:
return np.hstack(arrs)
else:
return np.vstack(arrs)
def my_cross_val_predict(
m, X, y, cv=KFold(),
predict=lambda m, x:[ m.predict(x)
, m.predict_proba(x)
],
combine=lambda preds: list(map(stack, Zip(*preds)))
):
preds = []
for train, test in cv.split(X):
m.fit(X[train, :], y[train])
pred = predict(m, X[test, :])
preds.append(pred)
return combine(preds)