pytorch tensorboard使用教程
[TOC]
安装
- 安装tensorboard
pip install tensorboard
- 启动
tensorboard --logdir=./log
启动成功后会显示如下
....
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.13.0 at http://localhost:6006/ (Press CTRL+C to quit)
浏览器打开http://localhost:6006/
,即可打开tensorboard。
如果想在指定ip和端口启动,可以如下使用:
tensorboard --logdir=./log --host=localhost --port=port
tensorboard的使用逻辑
TensorBoard的工作流程如下:
- 将代码运行过程中,某些关心的数据保存在一个文件件中
- 再使用
tensorboard
读取这个文件夹中的数据,在浏览器中显示
from torch.utils.tensorboard import SummaryWriter
# SummaryWriter将特定的数据存储在文件夹中
add_scalar方法
这个方法通常用来可视化网络训练时的各类标量参数,例如损失、学习率和准确率等。主要是数值类型的曲线图
import numpy as np
# 实例化writer
writer = SummaryWriter("./log/demo")
# 数值型
for n_iter in range(100):
writer.add_scalar(tag="Loss/train",scalar_value=np.random.random(),global_step=n_iter)
writer.add_scalar("Loss/test",np.random.random(),n_iter)
writer.close()
add_graph方法
add_graph方法是用于可视化模型的网络结构图
import torchvision
import torch
writer = SummaryWriter("./log/graph")
img = torch.rand([1,3,64,64],dtype=torch.float32)
model = torchvision.models.AlexNet(num_classes=10)
writer.add_graph(model=model,input_to_model=img)
writer.close()
add_scalars方法
这个方法与add_scalar的差别在于add_scalars在一张图中可以绘制多个曲线,我们只需要以字典的形式传入参数即可
writer = SummaryWriter("./log/scalars")
r = 5
for x in range(1, 101) :
writer.add_scalars('run_14h', {'xsinx' : x * np.sin(x / r),
'xcosx' : x * np.cos(x / r),
'xtanx' : x * np.tan(x / r)}, x)
writer.close()
add_histogram方法
直方图
writer = SummaryWriter("./log/historgram")
for step in range(10) :
x = np.random.randn(1000)
writer.add_histogram('distribution of gaussion', x, step)
writer.close()
add_image
显示图片
import cv2 as cv
writer = SummaryWriter("./log/image")
img = cv.imread('./data/img/watch.jpg', cv.IMREAD_COLOR)#输入图像要是3通道的,所以读取彩色图像
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
img = torch.tensor(img.transpose(2, 0, 1))#cv读取为numpy图像为(H * W * C),所以要进行轴转换
writer.add_image('watch', img, 0)
writer.close()
add_figure
add_figure() 方法是 PyTorch TensorBoard SummaryWriter 类的一个方法,用于记录和显示一组 matplotlib 绘制的 Figure 对象。它可以让您轻松地在 TensorBoard 上可视化像数据样本、样本预测、模型输出等自定义图形,以帮助您更好地了解和分析模型训练过程中的数据。
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import matplotlib.pyplot as plt
writer = SummaryWriter("./log/figure")
x = np.linspace(0, 10, 1000)
y = np.sin(x)
figure1 = plt.figure()
plt.plot(x, y, 'r-')
writer.add_figure('my_figure', figure1, 0)
writer.close()
参考
转载自:https://juejin.cn/post/7241078623553273915