码迷,mamicode.com
首页 > 其他好文 > 详细

tensorflow2.0——各批次loss、acc及可视化

时间:2020-11-30 16:04:44      阅读:7      评论:0      收藏:0      [点我收藏+]

标签:pyplot   poc   ros   mic   cat   his   layer   test   puts   

一、loss、acc提取

  有时候我们需要查看每个batch训练时候的损失loss与准确率acc,这样可以帮助我们挑选合适的epoch以及查看模型是否收敛。

  Model.fit()在调用时会返回一个History类,这个类的一个属性Historty.history是一个字典,里面就包含了每一个batch的测试集与验证集的loss、acc。

# 模型训练
history = model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1, verbose=1)

history.history.keys()  # 查看字典的键

loss = history.history[loss]  # 测试集损失
acc = history.history[acc]  # 测试集准确率
val_loss = history.history[val_loss]  # 验证集损失
val_acc = history.history[val_acc]  # 验证集准确率

  

二、使用matplotlib可视化

  这里可视化用到的包是matplotlib,暂不提供在tensorboard上的可视化,详细使用如下。

import tensorflow as tf
import matplotlib.pyplot as plt

# 读取数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# 数据集归一化
train_images = train_images / 255
train_labels = train_labels / 255  # 进行数据的归一化,加快计算的进程

# 创建模型结构
net_input = tf.keras.Input(shape=(28, 28))
fl = tf.keras.layers.Flatten()(net_input)  # 调用input
l1 = tf.keras.layers.Dense(32, activation="relu")(fl)
l2 = tf.keras.layers.Dropout(0.5)(l1)
net_output = tf.keras.layers.Dense(10, activation="softmax")(l2)

# 创建模型类
model = tf.keras.Model(inputs=net_input, outputs=net_output)

# 查看模型的结构
model.summary()

# 模型编译
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=[acc])

# 模型训练
history = model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1, verbose=1)

history.history.keys()  # 查看字典的键

loss = history.history[loss]  # 测试集损失
acc = history.history[acc]  # 测试集准确率
val_loss = history.history[val_loss]  # 验证集损失
val_acc = history.history[val_acc]  # 验证集准确率

# 可视化,定义2*2的画布
plt.figure()
plt.subplot(221)
plt.plot(loss)
plt.title(loss)
plt.subplot(222)
plt.plot(acc)
plt.title(acc)
plt.subplot(223)
plt.plot(val_loss)
plt.title(val_loss)
plt.subplot(224)
plt.plot(val_acc)
plt.title(val_acc)
plt.show()

输出结果:

技术图片

 

tensorflow2.0——各批次loss、acc及可视化

标签:pyplot   poc   ros   mic   cat   his   layer   test   puts   

原文地址:https://www.cnblogs.com/dwithy/p/14036881.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!