新闻  |   论坛  |   博客  |   在线研讨会
热文 | 卷积神经网络入门案例,轻松实现花朵分类(1)
AI科技大本营 | 2021-05-15 12:17:44    阅读:580   发布文章

前言

本文介绍卷积神经网络的入门案例,通过搭建和训练一个模型,来对几种常见的花朵进行识别分类;使用到TF的花朵数据集,它包含5类,即:“雏菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”;共 3670 张彩色图片;通过搭建和训练卷积神经网络模型,对图像进行分类,能识别出图像是“蒲公英”,或“玫瑰”,还是其它。

1.png

本篇文章主要的意义是带大家熟悉卷积神经网络的开发流程,包括数据集处理、搭建模型、训练模型、使用模型等;更重要的是解在训练模型时遇到“过拟合”,如何解决这个问题,从而得到“泛化”更好的模型。

思路流程

  • 导入数据集

  • 探索集数据,并进行数据预处理

  • 构建模型(搭建神经网络结构、编译模型)

  • 训练模型(把数据输入模型、评估准确性、作出预测、验证预测)  

  • 使用训练好的模型

  • 优化模型、重新构建模型、训练模型、使用模型

目录

  • 导入数据集

  • 探索集数据,并进行数据预处理

  • 构建模型

  • 训练模型

  • 使用模型

  • 优化模型、重新构建模型、训练模型、使用模型(过拟合、数据增强、正则化、重新编译和训练模型、预测新数据)

导入数据集

使用到TF的花朵数据集,它包含5类,即:“雏菊”,“蒲公英”,“玫瑰”,“向日葵”,“郁金香”;共 3670 张彩色图片;数据集包含5个子目录,每个子目录种存放一个类别的花朵图片。

# 下载数据集
import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)
# 查看数据集图片的总数量
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)

探索集数据,并进行数据预处理

查看一张郁金香的图片: 

# 查看郁金香tulips目录下的第1张图片;
tulips = list(data_dir.glob('tulips/*'))
PIL.Image.open(str(tulips[0]))

2.png

加载数据集的图片,使用keras.preprocessing从磁盘上加载这些图像。

# 定义加载图片的一些参数,包括:批量大小、图像高度、图像宽度
batch_size = 32
img_height = 180
img_width = 180
# 将80%的图像用于训练
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
# 将20%的图像用于验证
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
# 打印数据集中花朵的类别名称,字母顺序对应于目录名称
class_names = train_ds.class_names
print(class_names)

查看一下训练数据集中的9张图像

# 查看一下训练数据集中的9张图像
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

图像形状

传递这些数据集来训练模型model.fit,可以手动遍历数据集并检索成批图像:

for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break

能看到输出:(32, 180, 180, 3)   (32,)

image_batch是图片形状的张量(32, 180, 180, 3)。32是指批量大小;180,180分别表示图像的高度、宽度,3是颜色通道RGB。32张图片组成一个批次。

label_batch是形状的张量(32,),对应32张图片的标签。

数据集预处理

下面进行数据集预处理,将像素的值标准化至0到1的区间内:

# 将像素的值标准化至0到1的区间内。
normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

为什么是除以255呢?由于图片的像素范围是0~255,我们把它变成0~1的范围,于是每张图像(训练集、测试集)都除以255。

标准化数据

# 调用map将其应用于数据集:
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixels values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))


*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。

参与讨论
登录后参与讨论
推荐文章
最近访客