web-dev-qa-db-ja.com

Pandas transform()とapply()

applytransformが同じデータフレームで呼び出されたときに異なるdtypeを返す理由がわかりません。 「apply」はデータを折りたたむ前に2つの関数を説明しましたが、transformapplyとまったく同じことを行いますが、元のインデックスと崩壊しません。」以下を検討してください。

_df = pd.DataFrame({'id': [1,1,1,2,2,2,2,3,3,4],
                   'cat': [1,1,0,0,1,0,0,0,0,1]})
_

id列にゼロ以外のエントリがあるcatsを特定しましょう。

_>>> df.groupby('id')['cat'].apply(lambda x: (x == 1).any())
id
1     True
2     True
3    False
4     True
Name: cat, dtype: bool
_

すごい。ただし、インジケーター列を作成する場合は、次のようにできます。

_>>> df.groupby('id')['cat'].transform(lambda x: (x == 1).any())
0    1
1    1
2    1
3    1
4    1
5    1
6    1
7    0
8    0
9    1
Name: cat, dtype: int64
_

any()関数から返されたブール値ではなく、dtypeが_int64_になっている理由がわかりません。

元のデータフレームをいくつかのブール値を含むように変更すると(ゼロが残っていることに注意)、変換アプローチはobject列にブール値を返します。すべての値がブール値であるため、これは私にとって特別な謎ですが、objectとしてリストされているのは、整数とブール値の元の混合型列のdtypeと一致するようです。

_df = pd.DataFrame({'id': [1,1,1,2,2,2,2,3,3,4],
                   'cat': [True,True,0,0,True,0,0,0,0,True]})

>>> df.groupby('id')['cat'].transform(lambda x: (x == 1).any())
0     True
1     True
2     True
3     True
4     True
5     True
6     True
7    False
8    False
9     True
Name: cat, dtype: object
_

ただし、すべてのブール値を使用すると、変換関数はブール値の列を返します。

_df = pd.DataFrame({'id': [1,1,1,2,2,2,2,3,3,4],
                   'cat': [True,True,False,False,True,False,False,False,False,True]})

>>> df.groupby('id')['cat'].transform(lambda x: (x == 1).any())
0     True
1     True
2     True
3     True
4     True
5     True
6     True
7    False
8    False
9     True
Name: cat, dtype: bool
_

私の鋭いパターン認識スキルを使用して、結果の列のdtypeは元の列のそれを反映しているようです。これがなぜ発生するのか、またはtransform関数の内部で何が起こっているのかについてのヒントを教えていただければ幸いです。乾杯。

16
3novak

SeriesGroupBy.transform()は結果のdtypeを元の列と同じものにキャストしようとするようですが、DataFrameGroupBy.transform()はそれを行っていないようです:

In [139]: df.groupby('id')['cat'].transform(lambda x: (x == 1).any())
Out[139]:
0    1
1    1
2    1
3    1
4    1
5    1
6    1
7    0
8    0
9    1
Name: cat, dtype: int64

#                         v       v
In [140]: df.groupby('id')[['cat']].transform(lambda x: (x == 1).any())
Out[140]:
     cat
0   True
1   True
2   True
3   True
4   True
5   True
6   True
7  False
8  False
9   True

In [141]: df.dtypes
Out[141]:
cat    int64
id     int64
dtype: object
9
MaxU

私がそれをより明白だと思うので、合計で別の説明的な例を追加するだけです:

df = (
    pd.DataFrame(pd.np.random.Rand(10, 3), columns=['a', 'b', 'c'])
        .assign(a=lambda df: df.a > 0.5)
)

Out[70]: 
       a         b         c
0  False  0.126448  0.487302
1  False  0.615451  0.735246
2  False  0.314604  0.585689
3  False  0.442784  0.626908
4  False  0.706729  0.508398
5  False  0.847688  0.300392
6  False  0.596089  0.414652
7  False  0.039695  0.965996
8   True  0.489024  0.161974
9  False  0.928978  0.332414

df.groupby('a').apply(sum)  # drop rows

         a         b         c
a                             
False  0.0  4.618465  4.956997
True   1.0  0.489024  0.161974


df.groupby('a').transform(sum)  # keep dims

          b         c
0  4.618465  4.956997
1  4.618465  4.956997
2  4.618465  4.956997
3  4.618465  4.956997
4  4.618465  4.956997
5  4.618465  4.956997
6  4.618465  4.956997
7  4.618465  4.956997
8  0.489024  0.161974
9  4.618465  4.956997

ただし、pd.DataFrameオブジェクトではなくpd.GroupByオブジェクトに適用すると、違いを確認できませんでした。

0
ClementWalter