kerasで学習画像のデータ拡張をする ImageDataGenerator

プログラミング

This article is available in: English

はじめに

 深層学習による画像認識をする際、十分な学習画像データが無い場合に有用なテクニックが、データ拡張(データの水増し、data augmentation)です。この手法は、データを人工的に加工、合成することで学習画像データを新しく作成するものです。画像の加工方法は、一般に、画像の切り出し、反転、明度や彩度の変更などの変換があったりします。

本記事ではkerasのImageDataGeneratorを使用して、データ拡張をする方法を述べます。

本記事に記載されているコードは以下で試すことができます。

Google Colaboratory

ImageDataGeneratorでデータ拡張をやってみる

ライブラリのimportと画像の読み込み

 まずは必要なライブラリのimportと変換する画像を用意します。今回はこちらの画像で試してみます。

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, array_to_img

img_path = '対象の画像のpath'
target_img = load_img(img_path)
target_img = np.array(target_img)  # numpyのndarray形式に変換
x = target_img.reshape((1,) + target_img.shape)  # imgを4次元配列に変換

ここでは、ImageDataGeneratorクラスはndarray形式の4次元配列を要求するので、最後の行でその変換をおこなっています。

ImageDataGeneratorを作成
# [-90°, 90°]の範囲でランダムに回転
data_generator = ImageDataGenerator(rotation_range=90)  
# ImageDataGeneratorから変換画像のジェネレータを取得
generator = data_generator.flow(x, batch_size=1)  
# (NumBatches, Height, Width, Channels) の4次元配列を取得
batches = next(generator)  
# 画像に変換
get_img = array_to_img(batches[0])

plt.imshow(gen_img)
plt.show()

ImageDataGeneratorから変換画像を取得して表示します。とりあえず今回は90度回転の変換を施すことにします。他にもさまざまな変換がありますが、それは後述します。

4行目で変換画像のジェネレータを取得し、変換後の画像情報を含む4次元配列を6行目で生成しています。

実行すると、このように回転処理がおこなわれた画像が表示されます。

▲画像の回転変換ができた
変換した画像を保存する

flow()関数に下記の引数を追加すると、変換後の画像を保存できます。

import os

save_dir = 'output'  # 保存先のフォルダ
os.makedirs(save_dir, exist_ok=True)  # 保存先のフォルダが無い場合は作成

data_generator = ImageDataGenerator(rotation_range=90)  
# flow()に引数 save_to_dir, save_prefix, save_formatを追加して変換後画像を保存
generator = data_generator.flow(x, batch_size=1, save_to_dir=save_dir, save_prefix='generated', save_format='png')
batches = next(generator)

さまざまな変換

 ImageDataGeneratorに渡すパラメータを変えることでさまざまな画像変換ができます。ここでは次の関数定義を使用し、変換後の画像を4枚表示させてみます。

def show_gen_img(data_generator):
    """変換後画像を4枚表示"""
    generator = data_generator.flow(x, batch_size=1)
    plt.figure(figsize=(17,8))
    for i in range(4):
        batches = next(generator)
        gen_img = array_to_img(batches[0])
        plt.subplot(1, 4, i+1)
        plt.imshow(gen_img)
    plt.show()
回転
# [-120°, 120°]の範囲でランダムに回転
data_generator = ImageDataGenerator(rotation_range=120)
show_gen_img(data_generator)
上下反転
# ランダムに上下反転
data_generator = ImageDataGenerator(vertical_flip=True)
show_gen_img(data_generator)
左右反転
# ランダムに左右反転
data_generator = ImageDataGenerator(horizontal_flip=True)
show_gen_img(data_generator)
上下平行移動
# [-0.2*Height, 0.2*Height]の範囲でランダムに上下平行移動
data_generator = ImageDataGenerator(height_shift_range=0.2)
show_gen_img(data_generator)
左右平行移動
# [-0.2*Width, 0.2*Width]の範囲でランダムに左右平行移動
data_generator = ImageDataGenerator(width_shift_range=0.2)
show_gen_img(data_generator)
せん断
# [-20°, 20°]の範囲でランダムにせん断 
data_generator = ImageDataGenerator(shear_range=20)  
show_gen_img(data_generator)
拡大縮小
# [0.3, 1.5]の範囲でランダムに拡大縮小
data_generator = ImageDataGenerator(zoom_range=[0.3, 1.5])  
show_gen_img(data_generator)
明度変更
# [0.3, 2.0]の範囲でランダムに明度を変更
data_generator = ImageDataGenerator(brightness_range=[0.3, 2.0])  
show_gen_img(data_generator)
タイトルとURLをコピーしました