web-dev-qa-db-ja.com

Pytorchでネットを視覚化するにはどうすればよいですか?

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot

batch_size = 3
learning_rate =0.0002
Epoch = 50

resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)

Pytorchモデルからresnetを視覚化したいと思います。どうすればできますか? torchvizを使用しようとしましたが、エラーが発生しました。

'ResNet' object has no attribute 'grad_fn'
23
raaj

make_dotには変数が必要です(つまり、grad_fn)、モデル自体ではありません。
試してください:

x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out)  # plot graph of variable, not of a nn.Module
0
Shai

PyTorchViz( https://github.com/szagoruyko/pytorchviz )、「PyTorch実行グラフとトレースの視覚化を作成する小さなパッケージ」をご覧ください。

Example PyTorchViz visualization

0
David J.