kumilog.net

データ分析やプログラミングの話などを書いています。

Chainerでデータセットを作る

f:id:xkumiyu:20171101232226j:plain

Chainerでデータセットを作る方法についていくつか紹介したいと思います。なお、執筆時点のversionは3.0.0です。

シンプルなデータセット

基本的には、len()で要素数を取得できたり、[1:2]のようにスライスで要素を取得できれば、データセットとして扱うことができます。

ということで、最もシンプルなデータセットはNumpyの配列です。

dataset = numpy.array([1, 2, 3], dtype=numpy.float32)

また、TupleやDictを扱うTupleDataset()DictDataset()が用意されています。

ただ、メモリに収まらないような大規模なデータを扱う場合は使えないので、これらの形式のデータセットの出番はないかもしれません。

画像を扱うデータセット

画像を扱う場合は、ImageDatasetLabeledImageDatasetが便利です。画像処理にPILを使っているので、予めインストールが必要です。

$ pip install Pillow

ImageDataset

ImageDatasetは、画像ファイルからデータを読み込むデータセットです。学習に使うときに(データにアクセスするときに)読み込むので、大規模なデータでも扱うことができます。

PILで読み込んで、Numpyの配列(デフォルトでnumpy.float32型)に変換するだけで、リサイズや標準化などの処理は行わないので、必要であれば後述の方法などで実装する必要があります。

また、データのshapeは、(channels, height, width)になります。ImageDatasetに限らずChainerで画像を扱うときは、この順番にします。Tensorflowとは異なるので注意が必要です。

ImageDatasetの使い方は、引数に画像ファイルのパスのリストか、画像ファイルのパスが書かれたファイルを指定します。

  • 方法1: 画像ファイルのパスのリストを指定
image_files = ['1.jpg', '2.jpg']
dataset = chainer.datasets.ImageDataset(image_files)
  • 方法2: 画像ファイルのパスが書かれたファイルを指定
# images.txt
1.jpg
2.jpg

1行に1画像ファイル記載します。

image_files = 'images.txt'
dataset = chainer.datasets.ImageDataset(image_files)

どちらの方法でも、画像ファイルのrootディレクトリを指定できます。

dataset = chainer.datasets.ImageDataset(image_files, root='images/')

LabeledImageDataset

LabeledImageDatasetは、画像とラベルをセットにしたデータセットで教師あり学習の場合はこちらを使うと良いでしょう。

使い方は、基本的にImageDatasetと同じですが、画像ファイルのパスではなく、画像ファイルのパスとラベルのペアを指定します。ラベルは0からのint*1にします。

  • 方法1: 画像ファイルのパスとラベルのリストを指定
image_files = [('1.jpg', 0), ('2.jpg', 1)]
dataset = chainer.datasets.LabeledImageDataset(image_files)
  • 方法2: 画像ファイルのパスとラベルが書かれたファイルを指定
# images.txt
1.jpg 0
2.jpg 1

1行に1画像ファイルとラベルを記載します。ファイルが先でラベルとはspaceで区切ります。

image_files = 'images.txt'
dataset = chainer.datasets.LabeledImageDataset(image_files)

教師なし学習でImageDatasetを使う場合は、いろんな場所に画像が散らばっていない限り、方法1で読み込む方が便利だと思います。それに対して、教師あり学習でLabeledImageDatasetを使う場合は、予め画像とラベルを書いたファイルを用意しておいて、方法2を使う方が良いと思います。

データセットに処理を加える

ImageDatasetやLabeledImageDatasetでは、画像を読み込むだけでしたので、前処理を行いたい場合は、少し手を加える必要があります。

簡単な処理の場合、TransformDatasetを用いると良いです。TransformDatasetの引数は、データセットと変換する関数です。

dataset = chainer.datasets.ImageDataset(image_files)

def transform(data):
    return data / 255.
dataset = chainer.datasets.TransformDataset(dataset, transform)

LabeledImageDatasetのように、1データがtupleの場合は、tupleを受け取って、tupleを返す変換関数を用意します。

dataset = chainer.datasets.LabeledImageDataset(image_files)

def transform(data):
    img, lable = data
    img = img / 255.
    return img, lable
dataset = chainer.datasets.TransformDataset(dataset, transform)

独自のデータセット

TransformDatasetで対応できない処理や、画像以外のデータを扱うときは、DatasetMixinを継承して独自のデータセットクラスを用意します。

必要な関数は、データセットの要素数を返す__len__()とデータ自体を返すget_example()です。

データセットのデータにアクセスするときにget_example()が呼ばれます。なので、get_example()に前処理を書くことできます。

ただし、データにアクセスする度に処理が行われるということは、重い処理があると学習が遅くなる原因にもなり、注意が必要です。重い処理は事前に実施しておき、処理済みのデータを読み込む方が良いです。

DatasetMixinを使って、画像をリサイズする例を書いてみます。基本的にImageDatasetと同じように書いていますが、読み込み方法は方法1のみです。

import chainer
import numpy
from PIL import Image


class ResizedImageDataset(chainer.dataset.DatasetMixin):
    def __init__(self, paths, size, dtype=numpy.float32):
        self._paths = paths
        self._size = size
        self._dtype = dtype

    def __len__(self):
        # データセットの数を返します
        return len(self._paths)

    def get_example(self, i):
        # データセットのインデックスを受け取って、データを返します
        img= Image.open(self._paths[i])
        img = img.resize(self._size) # PILをつかってリサイズ
        img = numpy.asarray(img, dtype=self._dtype) # float32型のnumpy arrayに変換
        img = img.transpose(2, 0, 1) # PILのImageは(height, width, channel)なのでChainerの形式に変換
        return img

image_files = ['1.jpg', '2.jpg']
dataset = ResizedImageDataset(image_files, size=(32, 32))

データセットの分割

データセットを2分割するときは、split_datasetsplit_dataset_randomを使うと良いです。

split_datasetは単純に2つのデータセットに分割し、split_dataset_randomはランダムに分割します。必須の引数は、分割前のデータセットと分割するサイズです。

以下のように、学習データとテストデータを8:2に分けることができます。

split_at = int(len(dataset) * 0.8)
train, test = chainer.datasets.split_dataset(dataset, split_at)

データセットへのアクセス

最後にデータセットのデータの中身を取り出す方法です。シンプルなデータセットで説明したように、データセットはNumpy配列と同じように扱うことができます。

# 最初のデータ
dataset[0]

# スライスでアクセス
dataset[1:2]

# データセットの要素数
len(dataset)

*1:正確にはnumpy.int32型にします。LabeledImageDatasetのデフォルトの変換がnumpy.int32なので、普通のintでも問題ないです。