web-dev-qa-db-ja.com

列に基づくsklearn層別サンプリング

pandasデータフレームに読み込むAmazonレビューデータを含むかなり大きなCSVファイルがあります。データ80-20(train-test)を分割したいのですが、そうする間に分割データが1つの列(カテゴリ)の値を比例的に表していることを確認してください。つまり、すべての異なるカテゴリのレビューが、トレインデータとテストデータの両方に比例して存在します。

データは次のようになります。

**ReviewerID**       **ReviewText**        **Categories**       **ProductId**

1212                   good product         Mobile               14444425
1233                   will buy again       drugs                324532
5432                   not recomended       dvd                  789654123 

そのために次のコードを使用しています。

import pandas as pd
Meta = pd.read_csv('C:\\Users\\xyz\\Desktop\\WM Project\\Joined.csv')
import numpy as np
from sklearn.cross_validation import train_test_split

train, test = train_test_split(Meta.categories, test_size = 0.2, stratify=y)

次のエラーが発生します

NameError: name 'y' is not defined

私はpythonに比較的新しいので、私は間違っているのか、このコードが列カテゴリに基づいて階層化されるかどうかを理解することはできません。stratifyオプションを削除するとうまく機能するようです電車とテストの分割のカテゴリ列も同様です。

任意の助けをいただければ幸いです。

14
    >>> import pandas as pd
    >>> Meta = pd.read_csv('C:\\Users\\*****\\Downloads\\so\\Book1.csv')
    >>> import numpy as np
    >>> from sklearn.model_selection import train_test_split
    >>> y = Meta.pop('Categories')
    >>> Meta
        ReviewerID      ReviewText  ProductId
        0        1212    good product   14444425
        1        1233  will buy again     324532
        2        5432  not recomended  789654123
    >>> y
        0    Mobile
        1     drugs
        2       dvd
        Name: Categories, dtype: object
    >>> X = Meta
    >>> X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.33, random_state=42, stratify=y)
    >>> X_test
        ReviewerID    ReviewText  ProductId
        0        1212  good product   14444425
18
nEO

sklearn.model_selection.train_test_split

stratify:配列のようまたはなし(デフォルトはなし)

Noneでない場合、データは層状に分割され、これをクラスラベルとして使用します。

APIドキュメントに沿って、X_train, X_test, y_train, y_test = train_test_split(Meta_X, Meta_Y, test_size = 0.2, stratify=Meta_Y)のように試す必要があると思います。

Meta_XMeta_Yはあなたが適切に割り当てる必要があります(Meta_YMeta.categoriesコードに基づきます)。

10
su79eu7k