web-dev-qa-db-ja.com

PyTorchでBCELossを使用する方法は?

PyTorchで簡単なオートエンコーダーを作成し、 BCELoss を使用したいのですが、ターゲットが0〜1の間であると予想されるため、NaNを取得します。

9
Qubix

更新

BCELoss関数は、数値的に安定するために使用しませんでした。この問題を参照してください https://github.com/pytorch/pytorch/issues/751 。ただし、この問題は Pull#1792 で解決されているため、BCELossは数値的に安定しています。


古い答え

ソースからPyTorchを構築する場合、数値的に安定した関数BCEWithLogitsLosshttps://github.com/pytorch/pytorch/pull/1792 )を使用できます。入力として。

それ以外の場合は、次の関数を使用できます(上記の問題でyzgaoによって提供されています)。

class StableBCELoss(nn.modules.Module):
       def __init__(self):
             super(StableBCELoss, self).__init__()
       def forward(self, input, target):
             neg_abs = - input.abs()
             loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
             return loss.mean()
17
cheezer

ネットワークの終わりにシグモイド層を使用したい場合があります。そのようにして、数字は確率を表します。また、ターゲットが2進数であることを確認してください。完全なコードを投稿すると、さらに役立つ場合があります。

4
Roger Trullo