SSJB's blog

いろいろです。

TensorFlowで画像をTFRecordsで読み書きしてみる

TensorFlowでCNNを試しているのですが、教師データのJPEG画像をTFRecords形式への変換と読み書きに四苦八苦したので、そのメモ。

もくじ

まえおき

学習時にいろいろな画像サイズで試したくて、1枚の画像につき、

  • 32x32
  • 64x64
  • 96x96
  • 128x128

をそれぞれ用意していたのですが、どうにも面倒くさい。
さらに、ラベルは訳あって画像分のXMLファイルから学習時に毎回読み込んでいたので、読み込みにかなり時間がかかっていました。

訳というのは、教師データを作るために、まず最初に、SSD(Single Shot MultiBox Detector)で対象の領域を推測し、その推測された領域(ground truth bounding box)のデータとラベルをPASCAL VOCのAnnotation形式のXMLファイルに落として、labelimgを使って正確に修正して、最後に切り抜いています。
そのためにラベルが別になっているのですが、画像とラベルを一緒に格納できればかなり便利です。(普通はそうします)

最初はCIFAR-10のような形式も考えたのですが、TFRecordsが便利そうなのでこれを試してみました。

環境
・ tensorflow_gpu-1.1.0rc1
・ Python 3.5.2
・ Windows 10

TFRecords形式に書き出し

tf.python_io.TFRecordWriter() というのを使います。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from os.path import join, relpath
from glob import glob
from PIL import Image
import numpy as np
import tensorflow as tf


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('directory', 'data', """Directory where to write *.tfrecords.""")


def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def convert_to(dataset, name):
    filename = os.path.join(FLAGS.directory, name + '.tfrecords')
    writer = tf.python_io.TFRecordWriter(filename)

    for data in dataset:
        image_path = data[0]
        label = int(data[1])

        image_object = Image.open(image_path)
        image = np.array(image_object)

        height = image.shape[0]
        width = image.shape[1]
        depth = 3
        image_raw = image.tostring()

        example = tf.train.Example(features=tf.train.Features(feature={
            'height': _int64_feature(height),
            'width': _int64_feature(width),
            'depth': _int64_feature(depth),
            'label': _int64_feature(label),
            'image_raw': _bytes_feature(image_raw)}))
        writer.write(example.SerializeToString())

    writer.close()


def main(unused_argv):
    if not tf.gfile.Exists(FLAGS.directory):
        tf.gfile.MakeDirs(FLAGS.directory)

    label_data = [
        ['dog', 0],
        ['cat', 1]
    ]

    img_data = []
    for n, v in label_data:
        path = os.path.join('images', n)
        
        for file in [relpath(x, path) for x in glob(join(path, '*.jpg'))]:
            img_data.append([os.path.join(path, file), v])

    convert_to(img_data, 'train')


if __name__ == '__main__':
    tf.app.run()

画像の読み込みは、PILを使っています。
こうすると、train.tfrecordsで書き出されます。

中身を復元して確認する

書き出されたtfrecordsが正しく書き出されているか、実際に画像に復元して確かめてみます。

読み込みには、tf.TFRecordReader() というのを使います。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
from os.path import join, relpath
from glob import glob
from PIL import Image
import numpy as np
import tensorflow as tf


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('directory', 'data', """Directory where to read *.tfrecords.""")


def read_and_decode(filename_queue):
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    features = tf.parse_single_example(
        serialized_example,
        features={
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'label': tf.FixedLenFeature([], tf.int64),
            'depth': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        })

    image_raw = tf.decode_raw(features['image_raw'], tf.uint8)
    height = tf.cast(features['height'], tf.int32)
    width = tf.cast(features['width'], tf.int32)
    depth = tf.cast(features['depth'], tf.int32)
    label = tf.cast(features['label'], tf.int32)

    image = tf.reshape(image_raw, tf.stack([height, width, depth]))

    return image, label


def inputs():
    if not FLAGS.directory:
        raise ValueError('Please supply a directory')

    tfrecords_filename = 'train'
    filename = os.path.join(FLAGS.directory, tfrecords_filename + '.tfrecords')
    filename_queue = tf.train.string_input_producer([filename])

    image, label = read_and_decode(filename_queue)

    return image, label


def main(unused_argv):
    if not os.path.exists('output'):
        os.mkdir('output')

    images, labels = inputs()

    init = tf.global_variables_initializer()

    with tf.Session() as sess:

        sess.run(init)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            for i in range(6):
                e, l = sess.run([images, labels])
                img = Image.fromarray(e, 'RGB')
                img.save(os.path.join('output', "{0}-{1}.jpg".format(str(i), l)))
        finally:
            coord.request_stop()
            coord.join(threads)


if __name__ == '__main__':
    tf.app.run()

ポイントは、

tf.train.string_input_producer([filename])

でキューを作成し、

reader = tf.TFRecordReader()
reader.read(filename_queue)

で読み込み、

coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)

coord.request_stop()
coord.join(threads)

で処理をします。

画像を含めたサンプルを上げました。 github.com

参考