銀の弾丸

プログラミングに関して、いろいろ書き残していければと思っております。

線形回帰で「アヤメ(iris)」の分類

f:id:takamints:20171015205935j:plain

octaveを使って線形回帰で「アヤメの分類」をやってみました。

なんで急にこんなことやり始めたかっていうと、先日、会社の方に初めて機械学習関連のお仕事が舞い込んできたようでして、自分は担当外なのですが「今の時代、それぐらいできなアカンやろーwww」とか、いつもの調子でついハッタリをかましてしまって(笑)・・・、
不安になって自宅で機械学習のコソ練せざるを得ないという。いくつになってもお勉強です(泣)

「線形回帰てなんやねん」については以下の記事で、ごく初歩の概念を説明してます。
takamints.hatenablog.jp

お仕事の方は、Chainer/XGBoost/Pythonという条件がついているようですが、我が師と(勝手に)仰いでいる Andrew Ng 先生は「最初は octave(matlab) でやりなはれや」とおっしゃられていたので、手元の octave でやってみましたという次第。

ということで、以降のスクリプトを実行するには octave または matlabが必要です。 データセットのダウンロードと変換に curl と Nodeを使っていますが、手動でやるなら不要ですよ。

octaveWindowsへの導入は、以下の記事を参考に。

takamints.hatenablog.jp

Chainer v2による実践深層学習
新納 浩幸
オーム社
売り上げランキング: 4,654
MATLABとOctaveによる科学技術計算

丸善出版
売り上げランキング: 130,528

まずはアヤメ(iris)をダウンロードして変換します

機械学習のためのオープンな訓練データとして、アヤメ(iris)というのがあるらしいと、この度初耳。

データセットは150件と小さくて、4つのパラメータで3種類のアヤメを同定するというものです。

上記のファイルをダウンロードするため以下のシェルスクリプトを書きました。

get-iris-csv.sh

#!/bin/sh
echo Downloading iris.data and iris.names.
curl --silent https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.data
curl --silent https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names > iris.names
echo
echo Converting iris.data CSV.
echo
node conv-iris-csv.js

iris.dataをダウンロードしてCSVを変換しています(変換については次項↓参照)。

bashで実行するとこうなります。

$ ./get-iris-csv.sh
Downloading iris.data and iris.names.

Converting iris.data CSV.

The class values in column 5:
0: Iris-setosa
1: Iris-versicolor
2: Iris-virginica

CSVの5列目を数値に変換

iris.data の5列目はアヤメの種類の名前であり、文字列です。 しかし、octave の csvread 関数では数値しか読み込めません。数学的な行列ですからね。

そこで以下のNodeスクリプト。 iris.data を読み込んで、全行の5列目の文字列を 0,1,2 の数値に変換。 iris.csv に出力してます(おっと、エラーを一切見ていませんが良い子は絶対マネしないでw)。

conv-iris-csv.js

#!/usr/bin/env node
var inputFile = "iris.data";
var outputFile = "iris.csv";
var fs = require("fs");
fs.readFile(inputFile, "utf-8", function(err, data) {
    var indexOfName = {};
    fs.writeFile(outputFile, data.split(/\r*\n/).map(function(row, indexRow) {
        if(row == "") {
            return "";
        }
        return row.split(",").map(function(column, indexCol) {
            if(indexCol < 4) {
                return column;
            }
            if(!(column in indexOfName)) {
                indexOfName[column] = Object.keys(indexOfName).length;
            }
            return indexOfName[column];
        }).join(",");
    }).join("\n"),
    function(err) {
        var nameIndex = new Array(Object.keys(indexOfName).length);
        Object.keys(indexOfName).forEach(function(key) {
            nameIndex[indexOfName[key]] = "" + indexOfName[key] + ": " + key;
        });
        console.log("The class values in column 5:");
        console.log(nameIndex.join("\n"));
    });
});

データについて

iris.namesの「4. Relevant Information」には、 「ひとつは他の2つからリニアに分離できるが、他の2つは無理」と書いてありました(以下1>の行)

また、「データに間違いがある」とも書かれていました(以下2>の行)が、修正しても結果に変化はありませんでした。

4. Relevant Information:
   --- This is perhaps the best known database to be found in the pattern
       recognition literature.  Fisher's paper is a classic in the field
       and is referenced frequently to this day.  (See Duda & Hart, for
       example.)  The data set contains 3 classes of 50 instances each,
1>       where each class refers to a type of iris plant.  One class is
1>       linearly separable from the other 2; the latter are NOT linearly
1>       separable from each other.
   --- Predicted attribute: class of iris plant.
   --- This is an exceedingly simple domain.
   --- This data differs from the data presented in Fishers article
    (identified by Steve Chadwick,  spchadwick@espeedaz.net )
2>   The 35th sample should be: 4.9,3.1,1.5,0.2,"Iris-setosa"
2>   where the error is in the fourth feature.
2>   The 38th sample: 4.9,3.6,1.4,0.1,"Iris-setosa"
2>   where the errors are in the second and third features.  ```

線形回帰のoctave/matlabスクリプト

以下のoctave/matlabスクリプトが今回の御本尊。

  1. CSVを読み込んで、
  2. 学習してから
  3. 検証してます。

本来ならばデータセットの3分の1程度をトレーニングには使わず検証用に残しておくべきなのですが、 データ件数が150件と非常に小さいため、全データでトレーニングして、全データで検証してます。手前味噌な感じですが仕方がない。

iris.m

#
# Classification of iris using linear regression
#

# Clear all mat
clear

D = csvread('iris.csv');        # Load iris.csv
# D = D(randperm(size(D,1)),:); # Sort random

#
# Select training dataset
#
Dt = D(1:size(D, 1), :);    # Dt = D(1:2*size(D, 1)/3, :);
Xt = Dt(:, 1:(size(Dt,2)-1));    # Input
Xt = [ones(size(Xt,1),1) Xt];
Yt = Dt(:, size(Dt,2));      # Result

#
# Select validation dataset
#
Dv = D(1:size(D, 1), :);    # Dv = D(2*size(D, 1)/3+1:size(D,1), :);
Xv = Dv(:, 1:(size(Dv,2)-1));
Xv = [ones(size(Xv,1),1) Xv];
Yv = Dv(:, size(Dv,2));

training_data_count = size(Xt,1);
learning_rate = 0.00001     # learning rate
iteration = 5000000
report_interval = round(iteration / 10);

#
# Training
#
theta = ones(1, size(Xt,2)); # factor
numOfParam = size(theta, 2);
update = [1, numOfParam];
for n = 1:iteration
    diff = Xt * theta' - Yt;
    for j = 1:numOfParam
        update(1,j) = diff' * Xt(:,j);
    endfor
    theta = theta - learning_rate * update / training_data_count;
    if mod(n,report_interval) == 0
        cost = (diff' * diff) / ( 2 * size(Xt,1) )
    endif
endfor

#
# Validation
#
y = Xv * theta';
diff = y - Yv;
#validationResult = [Yv round(y)]

error_count = 0;
for i = 1:size(diff,1)
    diffI = diff(i,1);
    if(diffI * diffI >= 0.5 * 0.5)
        error_index = i
        error_count = error_count + 1;
    endif
endfor
theta
validationErrorRate = error_count / size(Xv, 1)
validationTotalCost = diff' * diff / ( 2 * size(Xv,1) )

結果検証

以下が上記スクリプトの実行結果。 150件中3件の判定間違いがありますが、これが限界みたいです。

learning_rate = 1.0000e-005
iteration =  5000000
cost =  0.026944
cost =  0.025440
cost =  0.024725
cost =  0.024366
cost =  0.024169
cost =  0.024048
cost =  0.023963
cost =  0.023897
cost =  0.023841
cost =  0.023792
error_index =  71
error_index =  84
error_index =  134
theta =

   0.586249  -0.177805  -0.067764   0.244015   0.621859

validationErrorRate =  0.020000
validationTotalCost =  0.023792

(2017-10-21追記①)

データの可視化をしようとしていて、単なる偶然ですが学習回数を少なくできるのを見つけました。 上の実行例では学習を500万回繰り返していますが、1万回で同じ結果を得られます。 learning_rateは1000倍で、学習は数秒で終わります。さらに誤差が若干少ない。

iris
learning_rate =  0.010000
iteration =  10000
cost =  0.025441
cost =  0.024366
cost =  0.024048
cost =  0.023897
cost =  0.023792
cost =  0.023708
cost =  0.023636
cost =  0.023575
cost =  0.023522
cost =  0.023476
error_index =  71
error_index =  84
error_index =  134
theta =

   0.464549  -0.152014  -0.066062   0.234802   0.621481

validationErrorRate =  0.020000
validationTotalCost =  0.023476

(2017-10-21追記①)おわり

別途「Iris-setosa」を、他の2つから完璧に(そしてかなり簡単に)分離できました。 しかし「Iris-versicolor」と「Iris-virginica」はパラメータ平面において数件が領域を共有してるため、線形モデルで分離するのは困難なようです。

これ以上判定精度を挙げるには別の新たなパラメータが必要になるはずです。 そんなこんなで、この分類は成功していると言えそうです。 当初はなんかおかしいと思っていましたが(笑)

結果が離散的な分類問題なので、ロジスティック回帰で解く必要があるのかも?と思いましたが、線形モデルである限りは結果は変わらないはずです。

また、ディープラーニングではキレイに分離できるはずですが、しかし、おそらくそれはオーバーフィッティングであって、一般的な解とはいえない(つまり別のデータセットを持ってきたら、やはり間違う)はず。そして、学習には膨大な中間層が必要で時間もかかると思います。

(2017-10-21追記②)

入力データと結果の可視化

入力データと判定エラーのチャートを描いてみました。 (赤:Setona、緑:Versicolour、青:Virginica)

f:id:takamints:20171021124312p:plain

○が入力データで、●は判定ミスです(間違ってこの色が示す種類に判定された)。

ご覧のように、緑:Versicolour、青:Virginicaとが入り混じっている箇所があります。 この部分がリニアには分離できないと書かれているところだと思います。

ちなみにチャートを描くスクリプトは、以下になります。

# Draw iris charts
[];

function show_chart(D, error_data, featureX, featureY, chartTitle)

    hold all

    Dy0 = D( D( :, 5 ) == 0, : );
    scatter(Dy0(:,featureX), Dy0(:,featureY), 'r')

    Dy1 = D( D( :, 5 ) == 1, : );
    scatter(Dy1(:,featureX), Dy1(:,featureY), 'g')

    Dy2 = D( D( :, 5 ) == 2, : );
    scatter(Dy2(:,featureX), Dy2(:,featureY), 'b')

    error_data_r = error_data(error_data(:,5)==0, :);
    scatter(error_data_r(:, featureX), error_data_r(:, featureY), [], 'r', 'filled')

    error_data_g = error_data(error_data(:,5)==1, :);
    scatter(error_data_g(:, featureX), error_data_g(:, featureY), [], 'g', 'filled')

    error_data_b = error_data(error_data(:,5)==2, :);
    scatter(error_data_b(:, featureX), error_data_b(:, featureY), [], 'b', 'filled')

    title(chartTitle)

endfunction

subplot(1,2,1)
show_chart(D, error_data, 1, 2, 'X:Sepal length, Y:Sepal width')

subplot(1,2,2)
show_chart(D, error_data, 3, 4, 'X:Petal length, Y:Petal width')

(2017-10-21追記②)おわり

公開リポジトリ

上記のスクリプトは、以下のGitリポジトリに置いてます。 (最新版では学習回数は1万回に設定されています)。 試してみたい人は是非どうぞ。

github.com

データ解析のためのロジスティック回帰モデル
Jr David W. Hosmer Stanley Lemeshow Rodney X. Sturdivant
共立出版
売り上げランキング: 178,009