import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.models as models
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.models.vgg import model_urls
from torchviz import make_dot
batch_size = 3
learning_rate =0.0002
Epoch = 50
resnet = models.resnet50(pretrained=True)
print resnet
make_dot(resnet)
Pytorchモデルからresnet
を視覚化したいと思います。どうすればできますか? torchviz
を使用しようとしましたが、エラーが発生しました。
'ResNet' object has no attribute 'grad_fn'
make_dot
には変数が必要です(つまり、grad_fn
)、モデル自体ではありません。
試してください:
x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
out = resnet(x)
make_dot(out) # plot graph of variable, not of a nn.Module
PyTorchViz( https://github.com/szagoruyko/pytorchviz )、「PyTorch実行グラフとトレースの視覚化を作成する小さなパッケージ」をご覧ください。