码迷,mamicode.com
首页 > 其他好文 > 详细

TensorFlow 2 基础

时间:2020-06-09 16:48:21      阅读:50      评论:0      收藏:0      [点我收藏+]

标签:shuff   sum   bat   show   scale   port   legend   save   data   

https://blog.csdn.net/lzs781/article/details/104742043/

 

官网

https://tensorflow.google.cn/tutorials/images/classification

一、生成模型。为了提高训练的准确率，可以适当增大 epochs 的值。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
import os
import matplotlib.pyplot as plt


# Training script for a binary (cats vs. dogs) image classifier.
# NOTE(review): the scraped original lost all string quotes; they are
# restored here so the script actually parses.


# 1. Training data paths. Set CATS_AND_DOGS_DIR to point at the dataset root
#    instead of editing the hard-coded default below.
PATH = os.environ.get(
    "CATS_AND_DOGS_DIR",
    r"C:\Users\wuhao\Desktop\cats_and_dogs_filtered\cats_and_dogs_filtered",
)
train_dir = os.path.join(PATH, "train")
train_cats_dir = os.path.join(train_dir, "cats")
train_dogs_dir = os.path.join(train_dir, "dogs")


batch_size = 128
epochs = 5
IMG_HEIGHT = 150
IMG_WIDTH = 150

# 2. Wrap the training directory in a shuffled, batched image generator.
#    rescale=1/255 scales pixel values (8-bit [0, 255] input -> [0, 1]).
train_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
train_data_gen = train_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=True,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode="binary",
)


# Pull one batch for the preview below; the labels are not needed.
sample_training_images, _ = next(train_data_gen)


# 3.展示图片(可有可无)
def plot_images(images_arr, ncols=5):
    """Show images side by side in a single row and display the figure.

    Args:
        images_arr: sequence of images accepted by ``Axes.imshow``. Only the
            first ``ncols`` images are drawn; extras are ignored.
        ncols: number of subplot columns (default 5, as in the original).
    """
    # squeeze=False always yields a 2-D array of Axes, so flatten() also
    # works when ncols == 1 (plain subplots(1, 1) returns a single Axes).
    _, axes = plt.subplots(1, ncols, figsize=(20, 20), squeeze=False)
    axes = axes.flatten()
    # Hide every frame first so unused slots (fewer images than ncols)
    # don't show empty tick boxes.
    for ax in axes:
        ax.axis("off")
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()


# Preview 5 images from the sampled training batch.
plot_images(sample_training_images[:5])

# 4. Build the model: three Conv2D + MaxPooling2D stages, then a dense head.
#    The final Dense(1) has no activation, so it outputs a raw logit; the loss
#    below is configured with from_logits=True to match.

model = Sequential([
    Conv2D(16, 3, padding="same", activation="relu", input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Conv2D(64, 3, padding="same", activation="relu"),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation="relu"),
    Dense(1),
])


# 5. Compile the model.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

model.summary()


# 6. Train the model.
#    Count the training files so that one epoch covers the whole training set.
num_cats_tr = len(os.listdir(train_cats_dir))
num_dogs_tr = len(os.listdir(train_dogs_dir))
total_train = num_cats_tr + num_dogs_tr
# Ceiling division: floor division silently dropped the final partial batch
# and produced 0 steps whenever total_train < batch_size.
steps_per_epoch = (total_train + batch_size - 1) // batch_size
# Model.fit accepts generators directly in TF 2; fit_generator is deprecated.
history = model.fit(
    train_data_gen,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
)
# 7. Plot the training curves. No validation data is passed to fit(), so only
#    training metrics exist in history (titles say so accordingly).
acc = history.history["accuracy"]
loss = history.history["loss"]
epochs_range = range(len(acc))
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.legend(loc="lower right")
plt.title("Training Accuracy")
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.legend(loc="upper right")
plt.title("Training Loss")
plt.show()

# 8. Save the trained model; the .h5 extension selects the HDF5 format.
model.save("path_to_my_model.h5")

二、加载模型

import tensorflow as tf
import os

# Evaluation script: reload the saved model and score it on the validation split.
# NOTE(review): batch size and image size must match the training script.

batch_size = 128
IMG_HEIGHT = 150
IMG_WIDTH = 150
# Set CATS_AND_DOGS_DIR to point at the dataset root instead of editing this default.
PATH = os.environ.get(
    "CATS_AND_DOGS_DIR",
    r"C:\Users\wuhao\Desktop\cats_and_dogs_filtered\cats_and_dogs_filtered",
)
validation_dir = os.path.join(PATH, "validation")
# 1. Load the model saved by the training script.
new_model = tf.keras.models.load_model("path_to_my_model.h5")

new_model.summary()

# 2. Build a generator over the validation images, preprocessed the same way
#    as the training images (same rescale factor and target size).
validation_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
val_data_gen = validation_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode="binary",
)
# 3. Evaluate the model. Presumably res == [loss, accuracy], assuming the
#    compile config saved in the .h5 file is the one used for training.
res = new_model.evaluate(val_data_gen)
print(res)

 

TensorFlow 2 基础

标签:shuff   sum   bat   show   scale   port   legend   save   data   

原文地址:https://www.cnblogs.com/wt7018/p/13073035.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!