パッケージscikit Learnのtrain_test_split
を使用しようとしていますが、パラメータstratify
に問題があります。コードは次のとおりです。
from sklearn import cross_validation, datasets
X = iris.data[:,:2]
y = iris.target
cross_validation.train_test_split(X,y,stratify=y)
ただし、次の問題が引き続き発生します。
raise TypeError("Invalid parameters passed: %s" % str(options))
TypeError: Invalid parameters passed: {'stratify': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}
誰かが何が起こっているか考えていますか?以下は関数のドキュメントです。
[...]
stratify:array-likeまたはNone(デフォルトはNone)
Noneでない場合、データは層状に分割され、これをラベル配列として使用します。
バージョン0.17の新機能:stratifysplitting
[...]
Scikit-Learnは、「stratify」という引数を認識していないことを伝えているだけで、誤って使用しているわけではありません。これは、引用したドキュメントに示されているように、パラメーターがバージョン0.17で追加されたためです。
したがって、Scikit-Learnを更新するだけです。
このstratify
パラメーターは、生成されるサンプルの値の割合がパラメーターstratify
に提供される値の割合と同じになるように分割します。
たとえば、変数y
が値0
および1
を持つバイナリカテゴリ変数であり、ゼロの25%と1の75%がある場合、stratify=y
はランダム分割には、25%の0
と75%の1
があります。
Google経由でここに来る私の将来の自己のために:
train_test_split
はmodel_selection
にあるため、次のとおりです。
from sklearn.model_selection import train_test_split
# given:
# features: xs
# ground truth: ys
x_train, x_test, y_train, y_test = train_test_split(xs, ys,
test_size=0.33,
random_state=0,
stratify=ys)
それを使用する方法です。 random_state
の設定は、再現性のために望ましいです。
このコンテキストでは、階層化とは、train_test_splitメソッドが、入力データセットと同じ割合のクラスラベルを持つトレーニングおよびテストサブセットを返すことを意味します。
このコードを実行してみてください。「うまくいく」だけです。
from sklearn import cross_validation, datasets
iris = datasets.load_iris()
X = iris.data[:,:2]
y = iris.target
x_train, x_test, y_train, y_test = cross_validation.train_test_split(X,y,train_size=.8, stratify=y)
y_test
array([0, 0, 0, 0, 2, 2, 1, 0, 1, 2, 2, 0, 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2,
1, 2, 1, 1, 0, 2, 1])