線形回帰で「アヤメ(iris)」の分類
octaveを使って線形回帰で「アヤメの分類」をやってみました。
なんで急にこんなことやり始めたかっていうと、先日、会社の方に初めて機械学習関連のお仕事が舞い込んできたようでして、自分は担当外なのですが「今の時代、それぐらいできなアカンやろーwww」とか、いつもの調子でついハッタリをかましてしまって(笑)・・・、
不安になって自宅で機械学習のコソ練せざるを得ないという。いくつになってもお勉強です(泣)
「線形回帰てなんやねん」については以下の記事で、ごく初歩の概念を説明してます。
takamints.hatenablog.jp
お仕事の方は、Chainer/XGBoost/Pythonという条件がついているようですが、我が師と(勝手に)仰いでいる Andrew Ng 先生は「最初は octave(matlab) でやりなはれや」とおっしゃられていたので、手元の octave でやってみましたという次第。
ということで、以降のスクリプトを実行するには octave または matlabが必要です。 データセットのダウンロードと変換に curl と Nodeを使っていますが、手動でやるなら不要ですよ。
octaveのWindowsへの導入は、以下の記事を参考に。
まずはアヤメ(iris)をダウンロードして変換します
機械学習のためのオープンな訓練データとして、アヤメ(iris)というのがあるらしいと、この度初耳。
データセットは150件と小さくて、4つのパラメータで3種類のアヤメを同定するというものです。
- iris.data - CSVのデータセット
- iris.names - データの説明などが記述されてる
上記のファイルをダウンロードするため以下のシェルスクリプトを書きました。
#!/bin/sh echo Downloading iris.data and iris.names. curl --silent https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data > iris.data curl --silent https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.names > iris.names echo echo Converting iris.data CSV. echo node conv-iris-csv.js
iris.dataをダウンロードしてCSVを変換しています(変換については次項↓参照)。
bashで実行するとこうなります。
$ ./get-iris-csv.sh Downloading iris.data and iris.names. Converting iris.data CSV. The class values in column 5: 0: Iris-setosa 1: Iris-versicolor 2: Iris-virginica
CSVの5列目を数値に変換
iris.data の5列目はアヤメの種類の名前であり、文字列です。 しかし、octave の csvread 関数では数値しか読み込めません。数学的な行列ですからね。
そこで以下のNodeスクリプト。 iris.data を読み込んで、全行の5列目の文字列を 0,1,2 の数値に変換。 iris.csv に出力してます(おっと、エラーを一切見ていませんが良い子は絶対マネしないでw)。
#!/usr/bin/env node var inputFile = "iris.data"; var outputFile = "iris.csv"; var fs = require("fs"); fs.readFile(inputFile, "utf-8", function(err, data) { var indexOfName = {}; fs.writeFile(outputFile, data.split(/\r*\n/).map(function(row, indexRow) { if(row == "") { return ""; } return row.split(",").map(function(column, indexCol) { if(indexCol < 4) { return column; } if(!(column in indexOfName)) { indexOfName[column] = Object.keys(indexOfName).length; } return indexOfName[column]; }).join(","); }).join("\n"), function(err) { var nameIndex = new Array(Object.keys(indexOfName).length); Object.keys(indexOfName).forEach(function(key) { nameIndex[indexOfName[key]] = "" + indexOfName[key] + ": " + key; }); console.log("The class values in column 5:"); console.log(nameIndex.join("\n")); }); });
データについて
iris.namesの「4. Relevant Information」には、
「ひとつは他の2つからリニアに分離できるが、他の2つは無理」と書いてありました(以下1>
の行)
また、「データに間違いがある」とも書かれていました(以下2>
の行)が、修正しても結果に変化はありませんでした。
4. Relevant Information: --- This is perhaps the best known database to be found in the pattern recognition literature. Fisher's paper is a classic in the field and is referenced frequently to this day. (See Duda & Hart, for example.) The data set contains 3 classes of 50 instances each, 1> where each class refers to a type of iris plant. One class is 1> linearly separable from the other 2; the latter are NOT linearly 1> separable from each other. --- Predicted attribute: class of iris plant. --- This is an exceedingly simple domain. --- This data differs from the data presented in Fishers article (identified by Steve Chadwick, spchadwick@espeedaz.net ) 2> The 35th sample should be: 4.9,3.1,1.5,0.2,"Iris-setosa" 2> where the error is in the fourth feature. 2> The 38th sample: 4.9,3.6,1.4,0.1,"Iris-setosa" 2> where the errors are in the second and third features. ```
線形回帰のoctave/matlabスクリプト
以下のoctave/matlabのスクリプトが今回の御本尊。
- CSVを読み込んで、
- 学習してから
- 検証してます。
本来ならばデータセットの3分の1程度をトレーニングには使わず検証用に残しておくべきなのですが、 データ件数が150件と非常に小さいため、全データでトレーニングして、全データで検証してます。手前味噌な感じですが仕方がない。
# # Classification of iris using linear regression # # Clear all mat clear D = csvread('iris.csv'); # Load iris.csv # D = D(randperm(size(D,1)),:); # Sort random # # Select training dataset # Dt = D(1:size(D, 1), :); # Dt = D(1:2*size(D, 1)/3, :); Xt = Dt(:, 1:(size(Dt,2)-1)); # Input Xt = [ones(size(Xt,1),1) Xt]; Yt = Dt(:, size(Dt,2)); # Result # # Select validation dataset # Dv = D(1:size(D, 1), :); # Dv = D(2*size(D, 1)/3+1:size(D,1), :); Xv = Dv(:, 1:(size(Dv,2)-1)); Xv = [ones(size(Xv,1),1) Xv]; Yv = Dv(:, size(Dv,2)); training_data_count = size(Xt,1); learning_rate = 0.00001 # learning rate iteration = 5000000 report_interval = round(iteration / 10); # # Training # theta = ones(1, size(Xt,2)); # factor numOfParam = size(theta, 2); update = [1, numOfParam]; for n = 1:iteration diff = Xt * theta' - Yt; for j = 1:numOfParam update(1,j) = diff' * Xt(:,j); endfor theta = theta - learning_rate * update / training_data_count; if mod(n,report_interval) == 0 cost = (diff' * diff) / ( 2 * size(Xt,1) ) endif endfor # # Validation # y = Xv * theta'; diff = y - Yv; #validationResult = [Yv round(y)] error_count = 0; for i = 1:size(diff,1) diffI = diff(i,1); if(diffI * diffI >= 0.5 * 0.5) error_index = i error_count = error_count + 1; endif endfor theta validationErrorRate = error_count / size(Xv, 1) validationTotalCost = diff' * diff / ( 2 * size(Xv,1) )
結果検証
以下が上記スクリプトの実行結果。 150件中3件の判定間違いがありますが、これが限界みたいです。
learning_rate = 1.0000e-005 iteration = 5000000 cost = 0.026944 cost = 0.025440 cost = 0.024725 cost = 0.024366 cost = 0.024169 cost = 0.024048 cost = 0.023963 cost = 0.023897 cost = 0.023841 cost = 0.023792 error_index = 71 error_index = 84 error_index = 134 theta = 0.586249 -0.177805 -0.067764 0.244015 0.621859 validationErrorRate = 0.020000 validationTotalCost = 0.023792
(2017-10-21追記①)
データの可視化をしようとしていて、単なる偶然ですが学習回数を少なくできるのを見つけました。 上の実行例では学習を500万回繰り返していますが、1万回で同じ結果を得られます。 learning_rateは1000倍で、学習は数秒で終わります。さらに誤差が若干少ない。
iris learning_rate = 0.010000 iteration = 10000 cost = 0.025441 cost = 0.024366 cost = 0.024048 cost = 0.023897 cost = 0.023792 cost = 0.023708 cost = 0.023636 cost = 0.023575 cost = 0.023522 cost = 0.023476 error_index = 71 error_index = 84 error_index = 134 theta = 0.464549 -0.152014 -0.066062 0.234802 0.621481 validationErrorRate = 0.020000 validationTotalCost = 0.023476
(2017-10-21追記①)おわり
別途「Iris-setosa」を、他の2つから完璧に(そしてかなり簡単に)分離できました。 しかし「Iris-versicolor」と「Iris-virginica」はパラメータ平面において数件が領域を共有してるため、線形モデルで分離するのは困難なようです。
これ以上判定精度を挙げるには別の新たなパラメータが必要になるはずです。 そんなこんなで、この分類は成功していると言えそうです。 当初はなんかおかしいと思っていましたが(笑)
結果が離散的な分類問題なので、ロジスティック回帰で解く必要があるのかも?と思いましたが、線形モデルである限りは結果は変わらないはずです。
また、ディープラーニングではキレイに分離できるはずですが、しかし、おそらくそれはオーバーフィッティングであって、一般的な解とはいえない(つまり別のデータセットを持ってきたら、やはり間違う)はず。そして、学習には膨大な中間層が必要で時間もかかると思います。
(2017-10-21追記②)
入力データと結果の可視化
入力データと判定エラーのチャートを描いてみました。 (赤:Setona、緑:Versicolour、青:Virginica)
○が入力データで、●は判定ミスです(間違ってこの色が示す種類に判定された)。
ご覧のように、緑:Versicolour、青:Virginicaとが入り混じっている箇所があります。 この部分がリニアには分離できないと書かれているところだと思います。
ちなみにチャートを描くスクリプトは、以下になります。
# Draw iris charts []; function show_chart(D, error_data, featureX, featureY, chartTitle) hold all Dy0 = D( D( :, 5 ) == 0, : ); scatter(Dy0(:,featureX), Dy0(:,featureY), 'r') Dy1 = D( D( :, 5 ) == 1, : ); scatter(Dy1(:,featureX), Dy1(:,featureY), 'g') Dy2 = D( D( :, 5 ) == 2, : ); scatter(Dy2(:,featureX), Dy2(:,featureY), 'b') error_data_r = error_data(error_data(:,5)==0, :); scatter(error_data_r(:, featureX), error_data_r(:, featureY), [], 'r', 'filled') error_data_g = error_data(error_data(:,5)==1, :); scatter(error_data_g(:, featureX), error_data_g(:, featureY), [], 'g', 'filled') error_data_b = error_data(error_data(:,5)==2, :); scatter(error_data_b(:, featureX), error_data_b(:, featureY), [], 'b', 'filled') title(chartTitle) endfunction subplot(1,2,1) show_chart(D, error_data, 1, 2, 'X:Sepal length, Y:Sepal width') subplot(1,2,2) show_chart(D, error_data, 3, 4, 'X:Petal length, Y:Petal width')
(2017-10-21追記②)おわり
公開リポジトリ
上記のスクリプトは、以下のGitリポジトリに置いてます。 (最新版では学習回数は1万回に設定されています)。 試してみたい人は是非どうぞ。