K平均法でMNIST文字データの分類を試行
Do2dle勉強会という機械学習関連の勉強会に参加しており、月に2〜3回のペースで集まって、本の輪講などを行っています。一人で本を読んでいるのと異なり、勉強会では本の内容に関して色々と議論するので、とても刺激になって楽しいです。
ちなみに、Do2dle(ドゥードゥル)とは、勉強会で目標としている「深層学習の深い理解(Deep Understanding of Deep Learning)」の頭文字DUDLの音読みDoodleを省略した表記です。現在、下記にて新メンバ募集中です。
do2dle.connpass.com
先日(8/6)は、下記2つ輪講を行いました。議論の中で、下記1つ目の「K平均法」をMNIST文字データの分類に適用するとどうなるか、という話になりました(私が言い出した話ではありますが...)。そこで、実際に試してみましたので、結果をまとめておきます。
ちなみに、MNISTはTensorflowインストール後のテストプログラムなどとして実行されることが多く、トレーニングデータでは認識率100%、テストデータでも認識率99.8%ぐらいになります。古典的な手法であるK平均法だと、どのくらい分類できるのか、気になるところです。
K平均法
K平均法とは、データの分類方法の一種で、データの平均値(物理的には重心)を使って、与えられたデータの集合を、K個の部分集合(クラスタ)に分類する方法です。
- 最初に、分類の起点として、各クラスタの重心データをK個生成します。
- 重心データは任意の値でOK。入力データの一部でも良い。
- それぞれの重心との距離を基に、データをクラスタに分類します。
- 各データは、一番近い距離にある重心のクラスタに所属します。
- 各クラスタに所属するデータの平均値を新しい重心にします。
- 新しい重心が決まったところで、「2.」に戻って同じことを繰り返します。
- 上記を、重心の変化が小さくなるまで続けます。
なんだか、中学生でも思いつきそうなアルゴリズムですが、これで重心の変化が少なくなっていき、収束することが証明されています。証明は、下記の本を参照ください。下記の本では、カラー写真の減色処理にK平均法を使った例が紹介されています。
- 作者: 中井悦司
- 出版社/メーカー: 技術評論社
- 発売日: 2015/10/17
- メディア: 単行本(ソフトカバー)
- この商品を含むブログ (1件) を見る
ただ、残念ながら、結果が白黒で印刷されていますので、カラーが良い方は本の著者の方が公開しているスライドを参照願います。ちなみに、このスライドだけでも、かなり勉強になります。もちろん、本にしか書かれていないことも結構ありますが...
既存プログラムの改修
実は、上記の本の「第7章 EMアルゴリズム:最尤推定法による教師なし学習」に、K平均法よりも高度な手法を用いた、MNIST文字データ分類方法が書かれています。また、この本では、それぞれの章で章で紹介しているアルゴリズムを実装したプログラムをGithubに公開しています。
そこで、7章のプログラムに、6章のK平均法のアルゴリズムを取り入れて、新しいプログラムを作成しました。また、7章のプログラムも、情報表示の追加や性能向上などの修正を行いました。修正後のソースコードと実行結果を、下記Githubに置きましたので、良かったら試してみてください。
github.com
- 追加修正部部分
- scripts/07-k_means.py (MNIST文字データへのK平均法適用)
- scripts/07-mix_em.py (情報表示や性能向上)
- scripts/logs/* (実行結果)
なお、元のプログラムでは、結果が図で表示されるものの、統計的な情報が表示されませんでした。また、最初の重心データは乱数で決めていました。これら含め、次のような修正を行いました。
- K平均法を取り込んだプログラムを作成
- 文字データをサンプリングして重心データを計算する処理追加
- 計算に使う文字データ数は可変(0:乱数のみ、1〜:サンプル数)
- 表示した図を自動的にファイルに保存する処理追加
- 処理時間を表示する処理追加(性能が知りたいため)
- 設定情報を表示する処理追加(各実行の設定内容の確認容易化)
- 図の表示カラム数をパラメータ化しファイル先頭に配置
- 07-mix_em.pyの性能向上(関数bern()の改修で約4倍性能向上)
- 各種統計情報の表示追加
MNIST文字データへのK平均法適用結果
下記にK平均法の適用結果を示します。実行結果は全てGithubのscripts/logsにありますので、性能などの詳細が気になる方は、そちらを参照ください。
重心データを乱数で生成した場合
この場合、どの文字がどのクラスタに分類されるかは確定しません。下記出力は、各クラスタ毎に、各文字の出現頻度を表にしたものです。これを見ると、例えばCluster04では4を26個、9を29個含んでおり、4と9がうまく分離できていないことが分かります。他のクラスタも、複数の文字が混在しており、分類は、ほとんどうまく行っていない感じです。参考までに、各クラスタで出現が一番大きい数字に分類されたと考えると、正しく分類できたのは50%強です(手計算)。
***** Histogram of labels ***** 0 1 2 3 4 5 6 7 8 9 -------------------------------------------------------------------------------- [Cluster00] 41 0 0 0 0 1 1 0 0 0 [Cluster01] 2 1 2 0 1 24 0 0 1 1 [Cluster02] 0 76 9 1 2 1 3 4 4 0 [Cluster03] 3 1 11 31 0 3 0 0 31 1 [Cluster04] 1 0 7 4 26 1 5 18 3 29 [Cluster05] 7 1 4 23 0 19 1 0 9 3 [Cluster06] 0 0 2 0 21 1 0 38 1 31 [Cluster07] 4 0 28 0 2 1 44 0 0 0 [Cluster08] 0 0 0 0 7 0 0 2 0 0 [Cluster09] 0 0 1 0 0 0 0 0 0 0
ちなみに、下記に分類後の各クラスタを載せました。ただ、あくまでもクラスタに所属する文字データのサンプルが示されているだけなので、全体としてどうなっているかがわかりません。上記の統計情報を出すようにしたのは、全体の状況が知りたいためです。
各文字毎のデータ20個で重心を計算した場合
この場合、各クラスタ毎に、サンプル文字データから生成した重心が初期設定されているので、分類の正確さを表示できます。下記が結果です。全体の正解率(accuracy)は71.5%と、ニューラルネットなどと比べるとかなり低いです。
正解率が特に低いのは、Label9(42.9%)、Label5(53.3%)、Label3(62.2%)、Label7(62.5%)です。クラスタのHistogramを見ると、(Label09に対応した)Cluster09には7が22個も含まれ、Cluster03には5が19個も含まれ、Cluster07には9が14個も含まれているのが分かります。単に画像データ間の距離だけで分類すると、この程度の精度しかでないのですね。逆に、別々の数字でも、距離的*1には、かなり近いものもあるのですね。
[Label0] #members= 53, #correct-members= 51, accuracy=96.2% [Label1] #members= 96, #correct-members= 74, accuracy=77.1% [Label2] #members= 41, #correct-members= 39, accuracy=95.1% [Label3] #members= 74, #correct-members= 46, accuracy=62.2% [Label4] #members= 56, #correct-members= 45, accuracy=80.4% [Label5] #members= 45, #correct-members= 24, accuracy=53.3% [Label6] #members= 58, #correct-members= 47, accuracy=81.0% [Label7] #members= 48, #correct-members= 30, accuracy=62.5% [Label8] #members= 38, #correct-members= 34, accuracy=89.5% [Label9] #members= 91, #correct-members= 39, accuracy=42.9% [TOTAL ] #members=600, #correct-members=429, accuracy=71.5% ***** Histogram of labels ***** 0 1 2 3 4 5 6 7 8 9 -------------------------------------------------------------------------------- [Cluster00] 51 0 0 0 0 1 1 0 0 0 [Cluster01] 0 74 11 1 1 1 2 3 3 0 [Cluster02] 0 0 39 1 0 0 0 1 0 0 [Cluster03] 0 0 2 46 0 19 0 0 6 1 [Cluster04] 0 0 1 0 45 0 2 2 1 5 [Cluster05] 2 2 2 0 2 24 2 4 1 6 [Cluster06] 4 0 3 0 2 1 47 0 1 0 [Cluster07] 0 0 2 2 0 0 0 30 0 14 [Cluster08] 1 1 0 1 0 1 0 0 34 0 [Cluster09] 0 2 4 8 9 4 0 22 3 39
各文字毎の全データで重心を計算した場合
この場合、全ての文字データを使って重心が初期設定されています。しかし、結果は、上記とあまり変わりませんね。
[Label0] #members= 51, #correct-members= 49, accuracy=96.1% [Label1] #members= 79, #correct-members= 57, accuracy=72.2% [Label2] #members= 42, #correct-members= 39, accuracy=92.9% [Label3] #members= 71, #correct-members= 46, accuracy=64.8% [Label4] #members= 57, #correct-members= 45, accuracy=78.9% [Label5] #members= 84, #correct-members= 30, accuracy=35.7% [Label6] #members= 57, #correct-members= 46, accuracy=80.7% [Label7] #members= 49, #correct-members= 38, accuracy=77.6% [Label8] #members= 39, #correct-members= 34, accuracy=87.2% [Label9] #members= 71, #correct-members= 47, accuracy=66.2% [TOTAL ] #members=600, #correct-members=431, accuracy=71.8% ***** Histogram of labels ***** 0 1 2 3 4 5 6 7 8 9 -------------------------------------------------------------------------------- [Cluster00] 49 0 0 0 0 1 1 0 0 0 [Cluster01] 0 57 10 1 1 0 1 6 2 1 [Cluster02] 1 0 39 1 0 0 0 1 0 0 [Cluster03] 0 0 2 46 0 16 0 0 6 1 [Cluster04] 0 0 1 0 45 0 3 2 1 5 [Cluster05] 3 21 5 7 5 30 3 4 2 4 [Cluster06] 4 0 2 0 2 1 46 1 1 0 [Cluster07] 0 0 2 2 0 0 0 38 0 7 [Cluster08] 1 1 0 1 0 2 0 0 34 0 [Cluster09] 0 0 3 1 6 1 0 10 3 47
最尤推定法での分類結果(ご参考)
「第7章 EMアルゴリズム:最尤推定法による教師なし学習」で書かれたプログラムを適用した場合の結果を参考までに載せます。
重心データを乱数で生成した場合
この場合、どの文字がどのクラスタに分類されるかは確定しません。下記出力は、各クラスタ毎に、各文字の出現頻度を表にしたものです。これを見ると、例えばCluster00では4を20個、7を24個、9を34個含んでおり、4と7と9がうまく分離できていないことが分かります。他のクラスタも、複数の文字が混在しており、分類は、ほとんどうまく行っていない感じです。参考までに、各クラスタで出現が一番大きい数字に分類されたと考えると、正しく分類できたのは50%弱です(手計算)。
***** Histogram of labels ***** 0 1 2 3 4 5 6 7 8 9 -------------------------------------------------------------------------------- [Cluster00] 1 0 10 8 20 4 4 24 4 34 [Cluster01] 1 0 1 0 33 1 36 2 0 2 [Cluster02] 0 0 0 0 1 0 0 31 0 24 [Cluster03] 0 40 3 1 0 1 0 1 2 0 [Cluster04] 0 36 0 0 2 15 1 3 3 0 [Cluster05] 19 0 0 0 2 1 2 0 0 0 [Cluster06] 0 1 2 14 1 4 0 1 5 3 [Cluster07] 2 1 23 29 0 15 2 0 31 1 [Cluster08] 2 1 23 6 0 6 9 0 4 1 [Cluster09] 33 0 2 1 0 4 0 0 0 0
各文字毎のデータ20個で重心を計算した場合
この場合、各クラスタ毎に、サンプル文字データから生成した重心が初期設定されているので、分類の正確さを表示できます。下記が結果です。全体の正解率(accuracy)は60.5%と、K平均法(71.5%)よりもかなり低いです。
正解率が特に低いのは、0(23.4%)です。クラスタのHistogramを見ると、Cluster09にはほぼ全ての数字が10個以上含まれ、ほとんど分類できていません。それにしても、K平均法よりも高度な手法なのに正解率が低いとは、非常に残念な結果でした。
[Label0] #members=214, #correct-members= 50, accuracy=23.4% [Label1] #members= 56, #correct-members= 54, accuracy=96.4% [Label2] #members= 36, #correct-members= 30, accuracy=83.3% [Label3] #members= 28, #correct-members= 27, accuracy=96.4% [Label4] #members= 31, #correct-members= 31, accuracy=100.0% [Label5] #members= 51, #correct-members= 27, accuracy=52.9% [Label6] #members= 44, #correct-members= 41, accuracy=93.2% [Label7] #members= 29, #correct-members= 28, accuracy=96.6% [Label8] #members= 53, #correct-members= 32, accuracy=60.4% [Label9] #members= 58, #correct-members= 43, accuracy=74.1% [TOTAL ] #members=600, #correct-members=363, accuracy=60.5% ***** Histogram of labels ***** 0 1 2 3 4 5 6 7 8 9 -------------------------------------------------------------------------------- [Cluster00] 50 4 31 22 24 18 10 24 14 17 [Cluster01] 0 54 0 1 0 0 0 0 1 0 [Cluster02] 0 1 30 4 0 0 1 0 0 0 [Cluster03] 0 0 0 27 0 0 0 1 0 0 [Cluster04] 0 0 0 0 31 0 0 0 0 0 [Cluster05] 0 17 2 2 0 27 1 0 1 1 [Cluster06] 0 2 0 0 0 1 41 0 0 0 [Cluster07] 0 0 0 0 0 0 0 28 0 1 [Cluster08] 8 1 1 1 2 3 1 1 32 3 [Cluster09] 0 0 0 2 2 2 0 8 1 43
各文字毎の全データで重心を計算した場合
この場合、全ての文字データを使って重心が初期設定されています。結果ですが、正答率が83.8%と、見違えるように良くなっています。K平均法(71.8%)と較べても、10ポイント以上良い値です。高度な手法の効果をやっと示すことができたと思います。それでも、ニューラルネット(99.8%)などと比べると、まだまだですね。正解率ではなく、エラー率で考えると、最尤推定法は16.2%、ニューラルネットは0.2%なので、80倍以上異なります。ニューラルネット(CNN)は特に画像認識に強く、AlphaGoもCNNを画像認識に使っているくらいなので、さすがに差が大きいですね。
[Label0] #members= 70, #correct-members= 48, accuracy=68.6% [Label1] #members= 76, #correct-members= 72, accuracy=94.7% [Label2] #members= 61, #correct-members= 53, accuracy=86.9% [Label3] #members= 58, #correct-members= 50, accuracy=86.2% [Label4] #members= 62, #correct-members= 53, accuracy=85.5% [Label5] #members= 49, #correct-members= 36, accuracy=73.5% [Label6] #members= 53, #correct-members= 48, accuracy=90.6% [Label7] #members= 52, #correct-members= 51, accuracy=98.1% [Label8] #members= 46, #correct-members= 35, accuracy=76.1% [Label9] #members= 73, #correct-members= 57, accuracy=78.1% [TOTAL ] #members=600, #correct-members=503, accuracy=83.8% ***** Histogram of labels ***** 0 1 2 3 4 5 6 7 8 9 -------------------------------------------------------------------------------- [Cluster00] 48 0 6 4 3 2 2 0 3 2 [Cluster01] 0 72 1 1 0 0 0 0 2 0 [Cluster02] 3 0 53 2 0 0 3 0 0 0 [Cluster03] 0 1 0 50 0 6 0 1 0 0 [Cluster04] 0 0 0 0 53 2 0 4 1 2 [Cluster05] 0 3 3 0 0 36 1 0 5 1 [Cluster06] 1 1 1 0 1 1 48 0 0 0 [Cluster07] 0 0 0 0 0 0 0 51 0 1 [Cluster08] 6 2 0 0 0 1 0 0 35 2 [Cluster09] 0 0 0 2 2 3 0 6 3 57