画深度学习网络可视化模型
1. 安装pytorch
2. 安装 pytorchviz
使用的是pip install git+https://github.com/szagoruyko/pytorchviz
安装完之后发现还必须要安装graphviz
3. 安装 graphviz
安装完之后报了一大串错误,查过博客发现系统还需要安装,使用sudo conda install graphviz 安装
4. 通过测试代码
import torch
from torchvision import models
from torchviz import make_dot
model = models.vgg19()
x = torch.randn(1, 3, 224, 224)
vis_graph = make_dot(model(x),params=dict(model.named_parameters()))
vis_graph.view()
结果如下: