SSJB's blog

いろいろです。

Windowsでscikit-learn(sklearn)をインストールしてirisの予測をサクッとするまで

Windows機械学習周りの環境を再構築したので、まっさらな状態から、scikit-learn (sklearn) を導入し、Random Forestでirisを予測するところまでを行います。 github.com

環境

Python のインストール

Pythonインストーラーからサクッとインストールできます。
基本的にはコマンドプロンプトからPythonを実行しますので、インストール時に「Add Python to PATH」にチェックを入れると、インストール完了後、すぐにコマンドプロンプトからPythonが使えるようになります。

最新は、v3.6.2 ですが、後にTensorFlowをインストールするつもりで、TensorFlowはv3.5.xが推奨なので、v3.5.2を入れました。 www.python.org

C:\>python --version
Python 3.5.2

scikit-learn以外のパッケージのインストール

scikit-learnは、NumPyとSciPyが必要なのですが、まっさきにpipでインストールすると、SciPyのインストールでコケてしまいます。
そのため、非公式とはなりますが、パッケージのWindows用のバイナリが配布されていますので、それを使用します。

Python Extension Packages for Windows - Christoph Gohlke

既にpipでインストールしてしまった場合は、削除から。

C:\>pip uninstall numpy
C:\>pip uninstall scipy

上記サイトから必要なバイナリをダウンロードします。
ダウンロードするものは、環境によって決めます。
たとえば、NumPyであれば、Python v3.5.2で64bitの環境なのでnumpy‑1.13.1+mkl‑cp35‑cp35m‑win_amd64.whlをダウンロードします。

一先ず必要なものは、

  • NumPy (numpy+mkl)
  • SciPy の2つで、必ず NumPy を最初にインストールします。

pip installで先程ダウンロードしたファイルを指定します。

C:\>pip install numpy-1.13.1+mkl-cp35-cp35m-win_amd64.whl
C:\>pip install scipy-0.19.1-cp35-cp35m-win_amd64.whl

scikit-learn のインストール

あとは普通にpipで入れます。
ついでに、pandasもいれておきます。

C:\>pip install scikit-learn
C:\>pip install pandas

Random Forest でiris予測

sklearnには、いくつかデータセットが用意されています。
今回は、iris のデータセットを使用して、予測を行っていきます。

sklearn.datasets.load_iris — scikit-learn 0.19.0 documentation
のように簡単に読み込めますが、他のデータにも応用を効かせたいので、直接読み込むのではなく、訓練とテストにぶんりさせて、Pandasで読み込んで実行してみます。

CSVファイルはこちら。
scikit-learn/iris.csv at master · scikit-learn/scikit-learn · GitHub

ヘッダーが

150,4,setosa,versicolor,virginica

となっていますので、特徴名の

sepal_length,sepal_width,petal_length,petal_width,species

とします。
species は、[0,1,2]のいずれかで、名称は[setosa, versicolor, virginica] が対応します。

次に手動で、訓練ファイルとテストファイルに分離します。
ヘッダーを含めて、121行目までを訓練データ、その他をテストデータとします。
テストデータにもヘッダーを入れます。

コードはこんな感じです。

import os
import numpy as np
import pandas as pd
import time
from sklearn import ensemble, externals

def main():
    START_TIME = time.time()

    SPECIES = ['setosa', 'versicolor', 'virginica']
    FEATURES = [
        'sepal_length',
        'sepal_width',
        'petal_length',
        'petal_width'
    ]
    LABEL = 'species'

    TRAIN_FILE = "./dataset/iris_training.csv"
    TEST_FILE = "./dataset/iris_test.csv"

    OUTPUT_DIR = "./models"
    OUTPUT_FILE = "{0}/iris-model.ckpt".format(OUTPUT_DIR)

    # Dataset
    training_dataset = pd.read_csv(TRAIN_FILE)
    test_dataset = pd.read_csv(TEST_FILE)

    # Shuffle test dataset
    #test_dataset= test_dataset.sample(frac=1).reset_index(drop=True)


    # Training data
    train_x = np.array(training_dataset[FEATURES].astype(np.float32))
    train_y = np.array(training_dataset[LABEL].astype(np.float32))

    # Test data
    test_x = np.array(test_dataset[FEATURES].astype(np.float32))
    test_y = np.array(test_dataset[LABEL].astype(np.float32))


    print("\n---------- INFORMATION ----------")
    print("FEATURES       : {0}".format(FEATURES))
    print("LABEL          : {0}".format(LABEL))
    print("TRAINING FILE  : {0}".format(TRAIN_FILE))
    print("TRAINING DATA  : {0}".format(len(training_dataset)))
    print("TEST FILE      : {0}".format(TEST_FILE))
    print("TEST DATA      : {0}".format(len(test_dataset)))
    print("OUTPUT         : {0}".format(OUTPUT_FILE))


    print("\n------------ ALGORITHM ------------")
    print("Randam Forest")

    # Random forest
    model = ensemble.RandomForestClassifier()

    # Reconstruct the model 
    # model = externals.joblib.load(OUTPUT_FILE)

    model.fit(train_x, train_y)
    importances = model.feature_importances_

    print("\n------------ IMPORTANCES ------------")
    for i in range(len(FEATURES)):
        print("{0}: {1} ".format(FEATURES[i].ljust(15), importances[i]))


    result = model.predict_proba(test_x)
    print("\n------------ PREDICTION 1 ------------")
    print(result[:10])


    result = model.predict(test_x)
    print("\n------------ PREDICTION 2 ------------")
    print(result[:10])


    print("\n------------ PREDICTION 3 ------------")
    for i in range(10):
        if int(result[i]) == int(test_y[i]):
            print("True  => {1}   {0}  @@ {2}".format(int(result[i]), int(test_y[i]), SPECIES[int(test_y[i])]))
        else:
            print("False => {1}   {0}  @@ {2}".format(int(result[i]), int(test_y[i]), SPECIES[int(test_y[i])]))


    print("\n------------ SCORE ------------")
    print(model.score(test_x, test_y))


    # Persist the model 
    joblib = externals.joblib
    if not os.path.exists(OUTPUT_DIR):
        os.mkdir(OUTPUT_DIR)
    joblib.dump(model, "{0}".format(OUTPUT_FILE))

    print("\n--------------------------------")
    print("Time: {0} sec".format(round(time.time() - START_TIME, 3)))
    print("--------------------------------")


if __name__ == '__main__':
    main()

このようにデータを分離させて、特徴で指定するようにできると、正規化等必要ですが、タイタニックなどにも使いまわせます。

実行すると、こんな感じになります。

---------- INFORMATION ----------
FEATURES       : ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
LABEL          : species
TRAINING FILE  : ./dataset/iris_training.csv
TRAINING DATA  : 120
TEST FILE      : ./dataset/iris_test.csv
TEST DATA      : 30
OUTPUT         : ./models/iris-model.ckpt

------------ ALGORITHM ------------
Randam Forest

------------ IMPORTANCES ------------
sepal_length   : 0.03482778065132767
sepal_width    : 0.017052383871736687
petal_length   : 0.3757864060484355
petal_width    : 0.5723334294285002

------------ PREDICTION 1 ------------
[[ 0.   0.2  0.8]
 [ 0.   0.9  0.1]
 [ 1.   0.   0. ]
 [ 1.   0.   0. ]
 [ 0.   1.   0. ]
 [ 0.   0.   1. ]
 [ 0.   1.   0. ]
 [ 0.   1.   0. ]
 [ 0.   1.   0. ]
 [ 0.   0.6  0.4]]

------------ PREDICTION 2 ------------
[ 2.  1.  0.  0.  1.  2.  1.  1.  1.  1.]

------------ PREDICTION 3 ------------
True  => 2   2  @@ virginica
True  => 1   1  @@ versicolor
True  => 0   0  @@ setosa
True  => 0   0  @@ setosa
True  => 1   1  @@ versicolor
True  => 2   2  @@ virginica
True  => 1   1  @@ versicolor
True  => 1   1  @@ versicolor
True  => 1   1  @@ versicolor
False => 2   1  @@ virginica

------------ SCORE ------------
0.833333333333

--------------------------------
Time: 0.056 sec
--------------------------------

github.com