ポケモントレーナーのためにポケモン図鑑を作る①
ポケモン図鑑をこれから作っていきたいと思います。まずはポケモンの判別プログラム(CNN)から作っていきます。
次回
とりあえず初代の御三家から
いきなり数百種類のポケモンを判別するのは大変そうだし、何よりデータセットを作るのが大変。なので、御三家(ヒトカゲ、フシギダネ、ゼニガメ)+ピカチュウに絞ります。
ポケモントレーナーが、最初に渡されるポケモンを知らないかもしれませんしね。
使用した環境
今回は以下の環境で学習しました。
お気づきだろうか。
友達に言ったら「へー、Core2quad使ってるんだ(笑)」と冷笑されるほどの糞スペック。
「機械学習をやってみた系」のPC構成では見たことがないぐらいの前世代PC。でも、まだ戦える!!
データの準備
調べてみると、Deeplearningは数千枚オーダーの画像がそれぞれ必要なようです。ですが、数千枚はさすがに面倒なので、Google画像検索でヒットした画像各200,300枚程度をかさ増し(約6倍)して正方形の画像に加工し使っています。
学習
学習ライブラリはTensorflowベースのKerasを用いました。もちろんCPUモードでの学習になります。逆にこの構成のPCって最新のグラフィックボードを増設できるのだろうか。
学習モデルはCNNで、
Keras : 画像分類 : AlexNet – PyTorchのAlexNetを元に使用させていただきました。
AlexNetって本当に優秀ですね。自作のCNNモデルだと全然学習が進まないのに、AlexNetだとどんどん学習が進む。
コードが汚くてすいません。もっと短く書けそう。。
import keras from keras.datasets import mnist from keras.models import Sequential from keras.layers import Dense, Dropout, Flatten from keras.layers import Conv2D, MaxPooling2D from keras.layers import BatchNormalization from keras import backend as K import numpy as np import glob import cv2 import random import math batch_size = 100 num_classes = 4 epochs = 20 img_rows, img_cols = 224, 224 x_train = [] y_train = [] x_test = [] y_test = [] def makeData(path,ans): img_src = cv2.imread(path) dst = img_src[:,:,0] dst = cv2.resize(dst,(img_rows, img_cols)) dst2 = img_src[:,:,1] dst2 = cv2.resize(dst2,(img_rows, img_cols)) dst3 = img_src[:,:,2] dst3 = cv2.resize(dst3,(img_rows, img_cols)) x_train.append([ dst.tolist(), dst2.tolist(), dst3.tolist()] ) y_train.append(ans) label = -1 for i in glob.glob('train/*'): if 'not_use' in i: continue if '.' not in i: label += 1 print(i + "のラベル" + str(label)) for j in glob.glob(i + '/*'): if '.jpg' in j: makeData( j, label) for i in range(200): num = int(random.random() * len(x_train)) x_test.append(x_train.pop(num)) y_test.append(y_train.pop(num)) x_train = np.array(x_train) y_train = np.array(y_train) x_test = np.array(x_test) y_test = np.array(y_test) if K.image_data_format() == 'channels_first': x_train = x_train.reshape(x_train.shape[0],5,img_rows,img_cols) x_test = x_test.reshape(x_test.shape[0], 5, img_rows, img_cols) input_shape = (5, img_rows, img_cols) else: x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 3) x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 3) input_shape = (img_rows, img_cols, 3) x_train = x_train.astype('float32') x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 y_train = keras.utils.to_categorical(y_train, num_classes) y_test = keras.utils.to_categorical(y_test, num_classes) model = Sequential() model.add(Conv2D(48, 11, strides=3,input_shape=input_shape, activation='relu', padding='same')) model.add(MaxPooling2D(3, strides=2)) model.add(BatchNormalization()) model.add(Conv2D(128, 5, strides=3, activation='relu', padding='same')) model.add(MaxPooling2D(3, strides=2)) model.add(BatchNormalization()) model.add(Conv2D(192, 3, strides=1, activation='relu', padding='same')) model.add(Conv2D(192, 3, strides=1, activation='relu', padding='same')) model.add(Conv2D(128, 3, strides=1, activation='relu', padding='same')) model.add(MaxPooling2D(3, strides=2)) model.add(BatchNormalization()) model.add(Flatten()) model.add(Dense(2048, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(2048, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(num_classes, activation='softmax')) model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adadelta(),metrics=['accuracy']) model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(x_test, y_test)) model_json_str = model.to_json() open('cnn.json', 'w').write(model_json_str) model.save_weights('cnn.h5'); score = model.evaluate(x_test, y_test, verbose=0) print('Test loss:', score[0]) print('Test accuracy:', score[1])
結果
実行結果は以下のようになりました。
Using TensorFlow backend. train\hitokageのラベル0 train\hushigidaneのラベル1 train\pikatyuのラベル2 train\zenigameのラベル3 (224, 224, 3) (6556, 224, 224, 3) x_train shape: (6556, 224, 224, 3) 6556 train samples 200 test samples Train on 6556 samples, validate on 200 samples Epoch 1/25 6556/6556 [==============================] - 412s - loss: 1.0334 - acc: 0.5961 - val_loss: 1.4738 - val_acc: 0.3400 Epoch 2/25 6556/6556 [==============================] - 408s - loss: 0.6979 - acc: 0.7088 - val_loss: 1.1469 - val_acc: 0.4150 Epoch 3/25 6556/6556 [==============================] - 409s - loss: 0.6265 - acc: 0.7431 - val_loss: 0.9915 - val_acc: 0.5400 Epoch 4/25 6556/6556 [==============================] - 407s - loss: 0.5694 - acc: 0.7677 - val_loss: 0.7536 - val_acc: 0.6650 Epoch 5/25 6556/6556 [==============================] - 409s - loss: 0.5232 - acc: 0.7784 - val_loss: 0.7851 - val_acc: 0.6850 Epoch 6/25 6556/6556 [==============================] - 405s - loss: 0.4937 - acc: 0.7979 - val_loss: 1.2362 - val_acc: 0.5800 Epoch 7/25 6556/6556 [==============================] - 403s - loss: 0.4510 - acc: 0.8208 - val_loss: 0.6217 - val_acc: 0.7450 Epoch 8/25 6556/6556 [==============================] - 403s - loss: 0.4028 - acc: 0.8388 - val_loss: 0.5831 - val_acc: 0.7650 Epoch 9/25 6556/6556 [==============================] - 408s - loss: 0.3928 - acc: 0.8411 - val_loss: 0.4908 - val_acc: 0.8300 Epoch 10/25 6556/6556 [==============================] - 409s - loss: 0.3528 - acc: 0.8565 - val_loss: 1.6967 - val_acc: 0.5750 Epoch 11/25 6556/6556 [==============================] - 407s - loss: 0.3120 - acc: 0.8763 - val_loss: 0.5483 - val_acc: 0.8050 Epoch 12/25 6556/6556 [==============================] - 409s - loss: 0.2787 - acc: 0.8868 - val_loss: 1.0185 - val_acc: 0.7200 Epoch 13/25 6556/6556 [==============================] - 408s - loss: 0.2538 - acc: 0.9012 - val_loss: 2.0394 - val_acc: 0.6150 Epoch 14/25 6556/6556 [==============================] - 406s - loss: 0.2131 - acc: 0.9137 - val_loss: 0.6719 - val_acc: 0.7850 Epoch 15/25 6556/6556 [==============================] - 402s - loss: 0.2005 - acc: 0.9218 - val_loss: 1.2824 - val_acc: 0.6750 Epoch 16/25 6556/6556 [==============================] - 403s - loss: 0.1903 - acc: 0.9253 - val_loss: 1.2046 - val_acc: 0.7050 Epoch 17/25 6556/6556 [==============================] - 408s - loss: 0.1525 - acc: 0.9433 - val_loss: 0.7733 - val_acc: 0.7750 Epoch 18/25 6556/6556 [==============================] - 408s - loss: 0.1384 - acc: 0.9475 - val_loss: 1.4794 - val_acc: 0.6550 Epoch 19/25 6556/6556 [==============================] - 409s - loss: 0.1386 - acc: 0.9466 - val_loss: 0.8382 - val_acc: 0.8100 Epoch 20/25 6556/6556 [==============================] - 408s - loss: 0.1092 - acc: 0.9565 - val_loss: 2.3111 - val_acc: 0.5850 Epoch 21/25 6556/6556 [==============================] - 404s - loss: 0.0981 - acc: 0.9660 - val_loss: 0.9810 - val_acc: 0.7800 Epoch 22/25 6556/6556 [==============================] - 402s - loss: 0.1022 - acc: 0.9620 - val_loss: 1.4912 - val_acc: 0.7400 Epoch 23/25 6556/6556 [==============================] - 403s - loss: 0.0797 - acc: 0.9677 - val_loss: 1.0995 - val_acc: 0.8000 Epoch 24/25 6556/6556 [==============================] - 407s - loss: 0.0895 - acc: 0.9657 - val_loss: 0.7260 - val_acc: 0.8050 Epoch 25/25 6556/6556 [==============================] - 409s - loss: 0.0896 - acc: 0.9687 - val_loss: 1.0438 - val_acc: 0.7750 Test loss: 1.04382741451 Test accuracy: 0.775
少ないデータをかさ増ししたからなのか、精度は77.5%とあまりよくないですが、大体分別出来ているみたいです。よかった。
おわりに
コードを貼り付けただけで特に解説等なくてすいません。以上である程度の精度でポケモンを判別できるようになりました。
汚いコードですいません。Deeplearning初心者なので、アドバイスをいただければ幸いです!