技術とか戦略とか

IT技術者が技術や戦略について書くブログです。

Python:scikit-learnのニューラルネットワークを試してみた(テストデータ付き)

何番煎じかわかりませんが、scikit-learnというライブラリを用いて、ニューラルネットワークのプログラミングを試してみました。
「トレーニングデータで学習→テストデータで学習結果を確認→学習したニューラルネットワークを保存→学習したニューラルネットワークをロード→ロードしたニューラルネットワークを使って処理する」という一連の流れを試しています。
 
今回は、20~24歳の体力測定結果から、男性か女性かを切り分けるというのを試してみました。
 
プログラムを作るにあたり、主に以下のページを参考にさせていただきました。
 
Windowsでscikit-learn(sklearn)をインストールしてirisの予測をサクッとするまで - SuprSonicJetBoy's blog

http://blog.suprsonicjetboy.com/entry/2017/08/30/171308

 
scikit-learnのディープラーニング実装簡単すぎワロタ  新規事業のつくり方

http://aiweeklynews.com/archives/50172518.html

 
scikit-learnで学習した分類器をjoblib.dumpで保存するときはcompressをTrueにするとファイルが一つにまとまって便利 - 洋食の日記

https://yoshoku.hatenablog.com/entry/2017/03/16/003000

 
ちなみに、今回試したニューラルネットワークの他に、線形回帰や決定木等もscikit-learnを用いて行うことができます。
ライブラリを使うだけで統計解析ができるなんて良い時代になりましたね。
もちろん、使いこなすには、統計的な知識やライブラリの知識が必要になりますが…。
 
【プログラム】
・AITest.py
# scikit-learnから必要な関数をインポート
# 事前にPythonとAnacondaをインストールしておく
# ライブラリのパスは各自の環境に合わせて変更する
# importで落ちる場合はコマンドプロンプトで下記を実行
# 「pip uninstall numpy」→「pip install numpy」
# 「pip uninstall scipy」→「pip install scipy」
# 「pip uninstall scikit-learn」→「pip install scikit-learn」
import sys
sys.path.append("C:/Anaconda3/Anaconda3/Lib/site-packages")
from sklearn.neural_network import MLPClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix
from sklearn.externals import joblib
import numpy as np

# ファイルから入力値と正解値を取得し、リストへ格納
# ファイルはCSV形式
# 1項目目:入力値…握力(kg)
# 2項目目:入力値…上体起こし(回)
# 3項目目:入力値…長座体前屈(cm)
# 4項目目:入力値…立ち幅とび(cm)
# 5項目目:正解値…性別(男性は1、女性は2)
x_data = # 入力値のリスト
y_data =
# 正解値のリスト
file = open("C:/tmp/physical_data.txt","r")
line = file.readline()
while line:
    param = line.rstrip().split(",")
    input = param[:4] # 入力値(1~4項目目)
    answer = param[4] # 正解値(5項目目)
    x_data.append(input)
    y_data.append(answer)
    line = file.readline()
file.close()

# 学習データとテストデータに分割
# (トレーニングデータ8割、テストデータ2割)
x_train, x_test, y_train, y_test = train_test_split(
    np.array(x_data, dtype=np.int32),
    np.array(y_data, dtype=np.int32),
    test_size=0.2
)

# ニューラルネットワーク生成
clf = MLPClassifier(
    max_iter=1000
)

# トレーニングデータで学習
clf.fit(x_train,y_train)

# テストデータで学習結果をテスト
result = clf.predict(x_test)
print("■テスト結果の見方")
print("     出力値")
print("       1  2")
print(" 正解値 1")
print("     2")
print("■テスト結果")
print(confusion_matrix(y_test,result))

# 学習結果を保存
# (上手く行った学習結果を使い回すため)
joblib.dump(clf,'C:/tmp/physical_learn.dmp',compress=True)

# 実務ではここから下は別のpyファイルになる
# 学習結果をロード
clf2 = joblib.load('C:/tmp/physical_learn.dmp')

# ロードした学習結果を用いて性別を切り分けてみる
x_test2 = np.array([[54,30,42,228],[29,21,43,164]])
result2 = clf2.predict(x_test2)
print("■上手く行った学習結果をロードして試した結果")
print("出力値:",result2[0],"、正解値:1")
print("出力値:",result2[1],"、正解値:2")


【テスト結果の一例】

■テスト結果の見方
     出力値
       1  2
 正解値 1
     2
■テスト結果
[[18  0]
 [ 1 21]]
■上手く行った学習結果をロードして試した結果
出力値: 1 、正解値:1
出力値: 2 、正解値:2

Process finished with exit code 0


【参考:入力データ】

 下記のページに記載されている20~24歳の平均値から、ランダムに±20%した値をテストデータとしています。
件数は、男性100件、女性100件の計200件です。
 
あなたはできる?これが年代別「体力測定」の平均値|「マイナビウーマン」

https://woman.mynavi.jp/article/140709-32/

 
・C:/tmp/physical_data.txt
47,27,52,231,1
30,20,51,163,2
44,25,44,204,1
29,20,41,167,2
45,29,44,255,1
29,20,45,168,2
53,34,44,212,1
27,21,43,190,2
51,27,54,209,1
25,18,45,179,2
44,27,44,231,1
29,20,45,165,2
53,30,47,239,1
28,24,45,144,2
50,29,50,233,1
30,18,43,167,2
46,24,48,238,1
27,18,44,153,2
44,27,46,225,1
31,22,42,169,2
50,29,40,240,1
27,21,51,170,2
47,30,40,203,1
30,21,48,154,2
45,29,43,224,1
26,19,40,166,2
49,27,52,212,1
31,22,44,162,2
42,30,41,257,1
25,20,48,148,2
49,25,41,212,1
26,20,47,138,2
48,27,41,258,1
32,20,49,177,2
52,28,50,239,1
32,20,47,175,2
46,26,51,223,1
26,21,43,182,2
47,28,46,230,1
29,19,48,156,2
48,31,49,204,1
27,17,43,157,2
40,29,50,246,1
26,24,47,189,2
42,26,46,230,1
28,23,43,152,2
49,30,51,201,1
30,21,41,157,2
49,29,50,221,1
28,21,45,151,2
45,31,51,256,1
29,21,38,159,2
50,28,43,264,1
28,18,52,166,2
48,31,40,240,1
32,23,51,159,2
41,25,40,199,1
25,21,48,154,2
53,31,48,218,1
32,21,47,167,2
50,26,50,229,1
29,21,43,168,2
42,27,43,193,1
32,20,39,181,2
53,33,45,220,1
30,22,52,148,2
53,27,54,226,1
26,21,39,173,2
48,26,46,230,1
27,19,45,158,2
52,28,50,196,1
31,21,41,137,2
43,23,39,234,1
30,21,46,172,2
43,28,51,247,1
26,20,51,187,2
46,29,45,225,1
25,20,47,168,2
44,28,42,227,1
25,20,45,182,2
45,25,44,236,1
24,18,45,143,2
54,30,43,231,1
27,20,52,186,2
51,30,40,202,1
27,23,52,174,2
41,30,45,208,1
28,21,48,181,2
48,32,47,201,1
27,20,41,164,2
48,32,45,228,1
28,19,42,170,2
44,32,41,243,1
30,19,50,178,2
47,27,49,217,1
26,23,52,151,2
45,26,46,215,1
28,21,47,174,2
47,33,44,203,1
27,23,42,158,2
44,33,41,200,1
29,21,42,165,2
46,30,49,242,1
29,20,49,166,2
50,30,47,251,1
28,23,45,164,2
46,30,49,215,1
28,17,48,169,2
43,28,46,234,1
29,23,48,157,2
48,31,41,199,1
31,22,47,185,2
46,25,41,237,1
26,21,45,150,2
42,31,48,227,1
27,19,46,155,2
48,27,47,197,1
29,20,43,168,2
43,26,53,245,1
24,22,53,164,2
51,26,50,195,1
27,19,37,165,2
46,25,43,243,1
31,21,54,160,2
55,32,46,200,1
29,18,41,170,2
49,27,45,224,1
27,24,38,158,2
48,28,42,210,1
27,19,48,176,2
48,28,45,189,1
29,21,44,161,2
45,29,52,235,1
32,22,39,147,2
50,25,47,198,1
30,18,48,167,2
38,29,43,246,1
28,21,46,187,2
54,31,39,189,1
29,23,50,178,2
40,28,42,264,1
27,24,41,139,2
50,31,44,198,1
27,21,48,197,2
43,30,52,246,1
25,20,44,141,2
52,31,48,244,1
27,20,49,182,2
48,31,46,258,1
31,22,51,144,2
47,29,51,239,1
28,24,45,177,2
46,29,50,214,1
30,19,49,155,2
43,29,50,227,1
24,22,51,164,2
45,30,51,248,1
28,17,48,154,2
47,30,48,227,1
25,19,46,188,2
46,25,43,213,1
31,22,52,172,2
43,25,47,192,1
26,25,51,155,2
41,28,54,265,1
33,24,47,164,2
47,28,49,229,1
23,25,50,158,2
45,28,54,221,1
30,20,51,173,2
46,24,48,261,1
24,20,42,150,2
46,25,43,219,1
27,20,43,155,2
43,30,45,255,1
27,18,44,143,2
44,25,48,189,1
27,19,46,165,2
43,32,41,246,1
29,21,43,180,2
44,24,43,231,1
28,22,39,156,2
46,24,46,199,1
28,22,46,171,2
46,32,47,220,1
32,20,50,186,2
46,25,42,246,1
25,19,50,186,2
44,28,43,210,1
31,22,44,161,2
51,29,43,244,1
30,23,45,155,2
50,29,50,219,1
31,22,42,170,2
46,29,48,217,1
23,21,43,160,2
53,25,45,219,1
24,20,47,153,2
40,28,43,205,1
32,20,44,142,2


【参考:入力データ(余談)】
後日知ったのですが、機械学習を紹介しているページに良く出てくるアヤメの分類のテストデータって有名なやつだったのですね。
自作せずにこれ使えば良かったですね…。
 
機械学習とは - アイリスデータ

http://home.a00.itscom.net/hatada/ml/data/iris01.html