【Caffe】CIFAR-10の学習と識別
Caffe付属のサンプルから、CIFAR-10の学習と識別。
The CIFAR-10 dataset
CIFAR-10*1は、一般物体認識のベンチマークとしてよく利用される画像データセット。 80 million tiny imagesのサブセットで、Alex Krizhevsky氏*2らにより作成されている。
特徴
- 32x32 pixelのRGBカラー画像60000枚で構成される
- 全10クラス、各クラス6000枚
- クラスラベルは0~9
- 順に airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- 学習画像50000枚、テスト画像10000枚に分割される
- クラスラベルは相互排他的
- データファイルはピクセルの配列で、Python向け、Matlab向け、バイナリ形式の3種フォーマットで提供されている
データセットの取得
公式サイトからダウンロードし解凍するか、サンプルスクリプトcaffe/data/cifar10/get_cifar10.sh
で取得できる。
一例として、バイナリ形式のデータセットは次のファイルで構成されている(他フォーマットでも同じ構成)。
- batches_meta.txt:クラスラベルのリスト
- data_batch_1(~5).bin:学習画像(バッチ1)
- data_batch_2.bin:学習画像(バッチ2)
- data_batch_3.bin:学習画像(バッチ3)
- data_batch_4.bin:学習画像(バッチ4)
- data_batch_5.bin:学習画像(バッチ5)
- test_batch.bin:テスト画像
caffe/examples/cifar10/create_cifar10.sh
でLMDBへの変換し、平均画像を取得する。
- cifar10_train_lmdb:学習画像
- cifar10_test_lmdb:テスト画像
- mean.binaryproto:学習データの平均値
学習
caffe/examples/cifar10
以下の次のファイルを使用し、データセットの学習を行う(Momentum SGD、4000回)。
- cifar10_quick_train_test.prototxt
- cifar10_quick_solver.prototxt
学習するネットワークは次のような5層*3構成となっている。
$ ./examples/cifar10/train_quick.sh ... I0418 00:47:51.126336 6592 solver.cpp:44] Initializing solver from parameters: test_iter: 100 test_interval: 500 base_lr: 0.001 display: 100 max_iter: 4000 lr_policy: "fixed" momentum: 0.9 weight_decay: 0.004 snapshot: 4000 ... I0418 00:55:23.891779 1800 solver.cpp:397] Test net output #0: accuracy = 0.7497 I0418 00:55:23.891779 1800 solver.cpp:397] Test net output #1: loss = 0.738206 (* 1 = 0.738206 loss)
テストデータに対する精度は75 %程度となった。
識別
学習した重みを利用し、テストデータを識別するPythonスクリプトを以下に示す。
次の流れで処理を実施する。
- データセットからの画像抽出
- 識別器の生成
- 識別と結果の出力
CIFAR-10のデータからの画像の抽出、Python APIの使い方については、次の記事を参考にした。
""" classify_cifar10.py """ import sys import os import random import pickle import numpy as np import matplotlib.pyplot as plt import caffe def unpickle(f): """ pickleオブジェクトのロード """ with open(f,'rb') as fo: d = pickle.load(fo, encoding='bytes') return d d = unpickle("./cifar-10-batches-py/test_batch") # データとラベルの取得 data = d[b'data'] labels = d[b'labels'] # ラベル (0~9) に対応する文字列 label_str = [ "airplane", "automoblie", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck" ] # バイトデータを32x32[pix] RGB画像単位へreshapeし、transpose()により次元を変更 # (num, channel, height, width) -> (num, height, width, channel) img = data.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype('uint8') # 識別器を構成 mu = np.load("./mean.npy") model = "../examples/cifar10/cifar10_quick.prototxt" weights = "../examples/cifar10/cifar10_quick_iter_4000.caffemodel" net = caffe.Classifier(model, weights, mean=mu, image_dims=(32, 32)) # 正解数 corrects = 0 # 入力数 num_input = 100 print("------------------------------") for i in range(num_input): # ランダムなインデックスを取得 idx = random.randrange(10000) # 画像の表示(0.1[sec]):時間がかかるのでコメントアウト #plt.imshow(img[idx]) #plt.pause(0.1) #plt.clf() # 識別 pred = net.predict([img[idx]], False)[0].argmax() label = labels[idx] # 判定 if pred == label: corrects += 1 msg = "OK" else: msg = "NG" print("#{:2d} [{}] pred.: {}({}) label: {}({})".format(i, msg, pred, label_str[pred], label, label_str[label])) print("------------------------------") # 精度の表示 print("accuracy: {}[%]".format(corrects / num_input * 100))
出力結果は次のようになった。
$ python classify_cifar10.py ... ------------------------------ # 0 [NG] pred.: 7(horse) label: 6(frog) # 1 [OK] pred.: 3(cat) label: 3(cat) # 2 [OK] pred.: 4(deer) label: 4(deer) ... #97 [OK] pred.: 1(automoblie) label: 1(automoblie) #98 [OK] pred.: 0(airplane) label: 0(airplane) #99 [OK] pred.: 7(horse) label: 7(horse) ------------------------------ accuracy: 74.0[%]
おおむね検証時と同じ精度を達成できている。