Itsukaraの日記

最新IT技術を勉強・実践中。最近はDeep Learningに注力。

初等線形代数・統計でMNISTの94%の画像を正しく分離!

前回記事でVAE(Variational Autoencorder)の中間層のデータ構造について書きましたが、MNISTの画像データ群自体の構造を、同様の手法で分析しました。結果、MNISTの50,000個の画像のうち94.0%は正しく分離できました*1。また、50,000個の画像の分類に基づいてテストデータ10,000個を分類したところ、91.6%は正しく分類できました*2

今回使った手法は、非常に単純な形状である楕円体で画像データ群の形状を近似しており、それだけでVAE(94.2%)*3と殆ど同じ精度(94.0%)で分類ができています。つまり、高次元でのMNIST画像データ群多様体の形状は、元々かなり単純であることが分かりました。

以下、結果の概要を記載します。計算手法及び詳細は下記githubのファイル「mnist-cov.html (html版)」か「mnist-cov.ipynb(jupyter notebook版)」を参照願います。
github.com

目的

MNISTは、0〜9の数字の画像データ(縦28横28で784画素/画像)とそのラベル(どの数字かを示す情報)が入っており、各数字毎の画像データ群が、784次元の空間内で纏まった領域にあることを確認することが今回の目的でした。

分散共分散行列の計算と画像例の生成

まず、各数字の画像データ群に対して、画像データ群の分布状況を表す分散共分散行列Σを求めました。

次に、この分散共分散行列の位置づけについて説明します。

各画像データ群は、各次元が平均0、分散1の784次元正規分布*4(形状は球形)の標本値x=(x1, .., xn)が、変換行列Aとベクトルuで「z = A * X + u」という変換によって楕円体に変換され、画像データz=(z1, ..., zn)が生成された、と仮定します。

このような変換を受けた後の画像データの集合に対して平均と分散を計算すると、平均はu、分散は「A * A.T」になります。この「A * A.T」は、分散共分散行列Σそのものなので、後は下記式の「A * A.T」にΣの値を代入すれば良いように見えます。しかし実はΣがスパースな行列であるため、式中にある「A * A.T」の逆行列を求めることができません。
f:id:Itsukara:20161005055213p:plain

しかし、「A * A.T = Σ」となるAを求めることはできます。これには、Σを特異値分解して「Σ = U * s * V (UとVはユニタリ行列、sは対角行列(特異値の一覧)」と表現し、「A = U * s^(1/2) * V」とすることで計算できます。

この理由ですが、Σは対象行列のため、実はV = U.Tとなります。そのため、A * A.T = (U * s^(1/2) * U.T) * (U * s^(1/2) * U.T).T = U * s^(1/2) * U.T * U * s^(1/2) * U.T = U * s * U.T = Σとなります。実際に数値で計算して確かめたところ、少なくともMNISTの分析では10-e7程度の精度で等号が成立しています。

上記で求めたAを使うと、784次元の正規分布でxを発生させ、「z = A * x + u」という式でzを計算できます。このzは、対象とした数字の画像データ群内の1点であり、対象とした数字と似た形状になると予測されます。

下記が、その例です。左から1番目が画像データ群の中心で、2番めは中心を加える前の生成データ(A * x)、右端は中心を加えたもの(A * x + u)です。中心のデータを加えなくても、なんとなく、数字の特徴を捉えている感じが解ると思います。

  • 数字0の分散共分散行列から生成した画像

f:id:Itsukara:20161005062940p:plain

  • 数字8の分散共分散行列から生成した画像

f:id:Itsukara:20161005063002p:plain

トレインデータの分離状況

MNISTの50,000個のトレインデータに対して、画像が正しい画像データ群に所属しているか確認しました。具体的には、0〜9の各数字の画像データ群に所属するデータに対して、所属する確率が一番大きいデータ群が正しいデータ群である比率を計算しました。下記が結果です。94.0%と、かなり高精度で正しいデータ群への所属を確認できました。

0: ok=4842(97.7%), ng=113(2.3%)
1: ok=5152(91.2%), ng=497(8.8%)
2: ok=4814(96.7%), ng=163(3.3%)
3: ok=4747(92.6%), ng=382(7.4%)
4: ok=4639(96.5%), ng=169(3.5%)
5: ok=4202(92.8%), ng=327(7.2%)
6: ok=4689(95.0%), ng=246(5.0%)
7: ok=4801(93.2%), ng=351(6.8%)
8: ok=4573(93.4%), ng=322(6.6%)
9: ok=4522(91.0%), ng=449(9.0%)
OK=46981/50000(94.0%), NG=3019/50000(6.0%)

テストデータの分類結果

トレインデータの分類に基づいてテストデータ10,000個を分類した結果を示します。91.6%と、かなり良い精度では正しく分類できました。

0: ok=952(97.1%), ng=28(2.9%)
1: ok=1027(90.5%), ng=108(9.5%)
2: ok=977(94.7%), ng=55(5.3%)
3: ok=907(89.8%), ng=103(10.2%)
4: ok=938(95.5%), ng=44(4.5%)
5: ok=799(89.6%), ng=93(10.4%)
6: ok=879(91.8%), ng=79(8.2%)
7: ok=922(89.7%), ng=106(10.3%)
8: ok=873(89.6%), ng=101(10.4%)
9: ok=887(87.9%), ng=122(12.1%)
OK=9161/10000(91.6%), NG=839/10000(8.4%)

あとがき

今回の手法は、ニューラルネットやVAEは全く使わず、初等的な線形代数と統計の手法のみでの分類です。古典的かつ初等的な数学しか使っていないので、過去に同様の手法で分類された可能性があるかもしれませんが、未確認です。

今回の手法は、ニューラルネットでの学習と比べると、格段に高速です。VAEを使った場合は、GPUを使っても10分ぐらいかかりましたが、今回の計算はGPU無しでも1分ぐらいでした。ただ、VAEと異なり、画像生成に関しては、生成された画像の品質が非常に悪いです。

今回の実験は、Numpyやscipyの行列関連関数の活用とmatplotlibでの視覚化により、かなり簡単、高速、インタラクティブに実施することができました。これらのが簡単に使えるPythonは、やはり非常に便利ですね。今回行ったことを10年前にやったとすると、おそらく、必要となるソフトウェアの開発期間と手間の増大、パソコンの性能の低さによる実験時間の長大化、必要な情報を調べるために必要な時間の増大などにより、1か月程度は掛った気がします。というか、そもそも、やる気がしなかったと思います。

*1:分類したわけではなく、予め分類して作成したカテゴリーに入っている確率が一番高いことが確認できたということ

*2:テストデータの場合は、予め分類して作成したカテゴリーを元に分類できたということ

*3:以前に画像データ10,000個で測定した際は95.2%でしたが、データ数を50,000個にして再測定したところ94.2%となりました

*4:各次元が平均0、分散1の正規分布ということ