This article is available in: English
はじめに
深層学習による画像認識をする際、十分な学習画像データが無い場合に有用なテクニックが、データ拡張(データの水増し、data augmentation)です。この手法は、データを人工的に加工、合成することで学習画像データを新しく作成するものです。画像の加工方法は、一般に、画像の切り出し、反転、明度や彩度の変更などの変換があったりします。
本記事ではkerasのImageDataGeneratorを使用して、データ拡張をする方法を述べます。
本記事に記載されているコードは以下で試すことができます。
Google Colab
ImageDataGeneratorでデータ拡張をやってみる
ライブラリのimportと画像の読み込み
まずは必要なライブラリのimportと変換する画像を用意します。今回はこちらの画像で試してみます。
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, array_to_img
img_path = '対象の画像のpath'
target_img = load_img(img_path)
target_img = np.array(target_img) # numpyのndarray形式に変換
x = target_img.reshape((1,) + target_img.shape) # imgを4次元配列に変換
ここでは、ImageDataGeneratorクラスはndarray形式の4次元配列を要求するので、最後の行でその変換をおこなっています。
ImageDataGeneratorを作成
# [-90°, 90°]の範囲でランダムに回転
data_generator = ImageDataGenerator(rotation_range=90)
# ImageDataGeneratorから変換画像のジェネレータを取得
generator = data_generator.flow(x, batch_size=1)
# (NumBatches, Height, Width, Channels) の4次元配列を取得
batches = next(generator)
# 画像に変換
get_img = array_to_img(batches[0])
plt.imshow(gen_img)
plt.show()
ImageDataGeneratorから変換画像を取得して表示します。とりあえず今回は90度回転の変換を施すことにします。他にもさまざまな変換がありますが、それは後述します。
4行目で変換画像のジェネレータを取得し、変換後の画像情報を含む4次元配列を6行目で生成しています。
実行すると、このように回転処理がおこなわれた画像が表示されます。
変換した画像を保存する
flow()関数に下記の引数を追加すると、変換後の画像を保存できます。
import os
save_dir = 'output' # 保存先のフォルダ
os.makedirs(save_dir, exist_ok=True) # 保存先のフォルダが無い場合は作成
data_generator = ImageDataGenerator(rotation_range=90)
# flow()に引数 save_to_dir, save_prefix, save_formatを追加して変換後画像を保存
generator = data_generator.flow(x, batch_size=1, save_to_dir=save_dir, save_prefix='generated', save_format='png')
batches = next(generator)
さまざまな変換
ImageDataGeneratorに渡すパラメータを変えることでさまざまな画像変換ができます。ここでは次の関数定義を使用し、変換後の画像を4枚表示させてみます。
def show_gen_img(data_generator):
"""変換後画像を4枚表示"""
generator = data_generator.flow(x, batch_size=1)
plt.figure(figsize=(17,8))
for i in range(4):
batches = next(generator)
gen_img = array_to_img(batches[0])
plt.subplot(1, 4, i+1)
plt.imshow(gen_img)
plt.show()
回転
# [-120°, 120°]の範囲でランダムに回転
data_generator = ImageDataGenerator(rotation_range=120)
show_gen_img(data_generator)
上下反転
# ランダムに上下反転
data_generator = ImageDataGenerator(vertical_flip=True)
show_gen_img(data_generator)
左右反転
# ランダムに左右反転
data_generator = ImageDataGenerator(horizontal_flip=True)
show_gen_img(data_generator)
上下平行移動
# [-0.2*Height, 0.2*Height]の範囲でランダムに上下平行移動
data_generator = ImageDataGenerator(height_shift_range=0.2)
show_gen_img(data_generator)
左右平行移動
# [-0.2*Width, 0.2*Width]の範囲でランダムに左右平行移動
data_generator = ImageDataGenerator(width_shift_range=0.2)
show_gen_img(data_generator)
せん断
# [-20°, 20°]の範囲でランダムにせん断
data_generator = ImageDataGenerator(shear_range=20)
show_gen_img(data_generator)
拡大縮小
# [0.3, 1.5]の範囲でランダムに拡大縮小
data_generator = ImageDataGenerator(zoom_range=[0.3, 1.5])
show_gen_img(data_generator)
明度変更
# [0.3, 2.0]の範囲でランダムに明度を変更
data_generator = ImageDataGenerator(brightness_range=[0.3, 2.0])
show_gen_img(data_generator)
リンク
リンク