TensorFlowでCNNを試しているのですが、教師データのJPEG画像をTFRecords形式への変換と読み書きに四苦八苦したので、そのメモ。
もくじ
まえおき
学習時にいろいろな画像サイズで試したくて、1枚の画像につき、
- 32x32
- 64x64
- 96x96
- 128x128
をそれぞれ用意していたのですが、どうにも面倒くさい。
さらに、ラベルは訳あって画像分のXMLファイルから学習時に毎回読み込んでいたので、読み込みにかなり時間がかかっていました。
訳というのは、教師データを作るために、まず最初に、SSD(Single Shot MultiBox Detector)で対象の領域を推測し、その推測された領域(ground truth bounding box)のデータとラベルをPASCAL VOCのAnnotation形式のXMLファイルに落として、labelimgを使って正確に修正して、最後に切り抜いています。
そのためにラベルが別になっているのですが、画像とラベルを一緒に格納できればかなり便利です。(普通はそうします)
最初はCIFAR-10のような形式も考えたのですが、TFRecordsが便利そうなのでこれを試してみました。
環境 ・ tensorflow_gpu-1.1.0rc1 ・ Python 3.5.2 ・ Windows 10
TFRecords形式に書き出し
tf.python_io.TFRecordWriter() というのを使います。
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys from os.path import join, relpath from glob import glob from PIL import Image import numpy as np import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('directory', 'data', """Directory where to write *.tfrecords.""") def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def convert_to(dataset, name): filename = os.path.join(FLAGS.directory, name + '.tfrecords') writer = tf.python_io.TFRecordWriter(filename) for data in dataset: image_path = data[0] label = int(data[1]) image_object = Image.open(image_path) image = np.array(image_object) height = image.shape[0] width = image.shape[1] depth = 3 image_raw = image.tostring() example = tf.train.Example(features=tf.train.Features(feature={ 'height': _int64_feature(height), 'width': _int64_feature(width), 'depth': _int64_feature(depth), 'label': _int64_feature(label), 'image_raw': _bytes_feature(image_raw)})) writer.write(example.SerializeToString()) writer.close() def main(unused_argv): if not tf.gfile.Exists(FLAGS.directory): tf.gfile.MakeDirs(FLAGS.directory) label_data = [ ['dog', 0], ['cat', 1] ] img_data = [] for n, v in label_data: path = os.path.join('images', n) for file in [relpath(x, path) for x in glob(join(path, '*.jpg'))]: img_data.append([os.path.join(path, file), v]) convert_to(img_data, 'train') if __name__ == '__main__': tf.app.run()
画像の読み込みは、PILを使っています。
こうすると、train.tfrecordsで書き出されます。
中身を復元して確認する
書き出されたtfrecordsが正しく書き出されているか、実際に画像に復元して確かめてみます。
読み込みには、tf.TFRecordReader() というのを使います。
from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys from os.path import join, relpath from glob import glob from PIL import Image import numpy as np import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('directory', 'data', """Directory where to read *.tfrecords.""") def read_and_decode(filename_queue): reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, features={ 'height': tf.FixedLenFeature([], tf.int64), 'width': tf.FixedLenFeature([], tf.int64), 'label': tf.FixedLenFeature([], tf.int64), 'depth': tf.FixedLenFeature([], tf.int64), 'image_raw': tf.FixedLenFeature([], tf.string), }) image_raw = tf.decode_raw(features['image_raw'], tf.uint8) height = tf.cast(features['height'], tf.int32) width = tf.cast(features['width'], tf.int32) depth = tf.cast(features['depth'], tf.int32) label = tf.cast(features['label'], tf.int32) image = tf.reshape(image_raw, tf.stack([height, width, depth])) return image, label def inputs(): if not FLAGS.directory: raise ValueError('Please supply a directory') tfrecords_filename = 'train' filename = os.path.join(FLAGS.directory, tfrecords_filename + '.tfrecords') filename_queue = tf.train.string_input_producer([filename]) image, label = read_and_decode(filename_queue) return image, label def main(unused_argv): if not os.path.exists('output'): os.mkdir('output') images, labels = inputs() init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) try: for i in range(6): e, l = sess.run([images, labels]) img = Image.fromarray(e, 'RGB') img.save(os.path.join('output', "{0}-{1}.jpg".format(str(i), l))) finally: coord.request_stop() coord.join(threads) if __name__ == '__main__': tf.app.run()
ポイントは、
tf.train.string_input_producer([filename])
でキューを作成し、
reader = tf.TFRecordReader() reader.read(filename_queue)
で読み込み、
coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(coord=coord) coord.request_stop() coord.join(threads)
で処理をします。
画像を含めたサンプルを上げました。 github.com