気が向いたら書くやつ

気が向いたら何か書きます

【Caffe】CIFAR-10の学習と識別

Caffe付属のサンプルから、CIFAR-10の学習と識別。

The CIFAR-10 dataset

CIFAR-10*1は、一般物体認識のベンチマークとしてよく利用される画像データセット80 million tiny imagesのサブセットで、Alex Krizhevsky氏*2らにより作成されている。

www.cs.toronto.edu

特徴

  • 32x32 pixelのRGBカラー画像60000枚で構成される
  • 全10クラス、各クラス6000枚
  • クラスラベルは0~9
    • 順に airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
  • 学習画像50000枚、テスト画像10000枚に分割される
  • クラスラベルは相互排他的
  • データファイルはピクセルの配列で、Python向け、Matlab向け、バイナリ形式の3種フォーマットで提供されている

データセットの取得

公式サイトからダウンロードし解凍するか、サンプルスクリプトcaffe/data/cifar10/get_cifar10.shで取得できる。

一例として、バイナリ形式のデータセットは次のファイルで構成されている(他フォーマットでも同じ構成)。

  • batches_meta.txt:クラスラベルのリスト
  • data_batch_1(~5).bin:学習画像(バッチ1)
  • data_batch_2.bin:学習画像(バッチ2)
  • data_batch_3.bin:学習画像(バッチ3)
  • data_batch_4.bin:学習画像(バッチ4)
  • data_batch_5.bin:学習画像(バッチ5)
  • test_batch.bin:テスト画像

caffe/examples/cifar10/create_cifar10.shでLMDBへの変換し、平均画像を取得する。

  • cifar10_train_lmdb:学習画像
  • cifar10_test_lmdb:テスト画像
  • mean.binaryproto:学習データの平均値

学習

caffe/examples/cifar10以下の次のファイルを使用し、データセットの学習を行う(Momentum SGD、4000回)。

  • cifar10_quick_train_test.prototxt
  • cifar10_quick_solver.prototxt

学習するネットワークは次のような5層*3構成となっている。

f:id:soratobi96:20190420225043p:plain

$ ./examples/cifar10/train_quick.sh
...
I0418 00:47:51.126336  6592 solver.cpp:44] Initializing solver from parameters:
test_iter: 100
test_interval: 500
base_lr: 0.001
display: 100
max_iter: 4000
lr_policy: "fixed"
momentum: 0.9
weight_decay: 0.004
snapshot: 4000
...
I0418 00:55:23.891779  1800 solver.cpp:397]     Test net output #0: accuracy = 0.7497
I0418 00:55:23.891779  1800 solver.cpp:397]     Test net output #1: loss = 0.738206 (* 1 = 0.738206 loss)

テストデータに対する精度は75 %程度となった。

識別

学習した重みを利用し、テストデータを識別するPythonスクリプトを以下に示す。

次の流れで処理を実施する。

  • データセットからの画像抽出
  • 識別器の生成
  • 識別と結果の出力

CIFAR-10のデータからの画像の抽出、Python APIの使い方については、次の記事を参考にした。

qiita.com

qiita.com

"""
classify_cifar10.py
"""
import sys
import os
import random
import pickle
import numpy as np
import matplotlib.pyplot as plt
import caffe

def unpickle(f):
    """
    pickleオブジェクトのロード
    """
    with open(f,'rb') as fo:
        d = pickle.load(fo, encoding='bytes')
    return d

d = unpickle("./cifar-10-batches-py/test_batch")

# データとラベルの取得
data   = d[b'data']
labels = d[b'labels']

# ラベル (0~9) に対応する文字列
label_str = [ "airplane", "automoblie", "bird", "cat", "deer",
              "dog", "frog", "horse", "ship", "truck" ]

# バイトデータを32x32[pix] RGB画像単位へreshapeし、transpose()により次元を変更
# (num, channel, height, width) -> (num, height, width, channel)
img = data.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype('uint8')

# 識別器を構成
mu = np.load("./mean.npy")

model   = "../examples/cifar10/cifar10_quick.prototxt"
weights = "../examples/cifar10/cifar10_quick_iter_4000.caffemodel"
net     = caffe.Classifier(model, weights, mean=mu, image_dims=(32, 32))

# 正解数
corrects = 0
# 入力数
num_input = 100

print("------------------------------")

for i in range(num_input):
    # ランダムなインデックスを取得
    idx = random.randrange(10000)

    # 画像の表示(0.1[sec]):時間がかかるのでコメントアウト
    #plt.imshow(img[idx])
    #plt.pause(0.1)
    #plt.clf()

    # 識別
    pred  = net.predict([img[idx]], False)[0].argmax()
    label = labels[idx]

    # 判定
    if pred == label:
        corrects += 1
        msg = "OK"
    else:
        msg = "NG"

    print("#{:2d} [{}] pred.: {}({}) label: {}({})".format(i, msg,
        pred, label_str[pred], label, label_str[label]))

print("------------------------------")
# 精度の表示
print("accuracy: {}[%]".format(corrects / num_input * 100))

出力結果は次のようになった。

$ python classify_cifar10.py
...
------------------------------
# 0 [NG] pred.: 7(horse) label: 6(frog)
# 1 [OK] pred.: 3(cat) label: 3(cat)
# 2 [OK] pred.: 4(deer) label: 4(deer)
...
#97 [OK] pred.: 1(automoblie) label: 1(automoblie)
#98 [OK] pred.: 0(airplane) label: 0(airplane)
#99 [OK] pred.: 7(horse) label: 7(horse)
------------------------------
accuracy: 74.0[%]

おおむね検証時と同じ精度を達成できている。

*1:Canadian Institute For Advanced Research。Wikipedia

*2:AlexNetの人

*3:一般的な数え方として、重みを持つレイヤのみ数えている