Itsukaraの日記

最新IT技術を勉強・実践中。最近はDeep Learningに注力。

Variational Autoencoderの中間層データの構造(MNIST)

結構前(9/6)に、Variational Autoencoderの中間層の構造に関する記事を書きました。その後、前回の記事で気になっていた点(中間層データの構造)を確認したので、メモにします。

以前の記事の要約(今回記事関連)

MNISTのデータ(0〜9の各数字の画像データ集)で、中間層を30次元にしてVAEで学習させた場合、中間層のうち11次元のみが活性化し、残りの次元は不活性(値がほぼ0)というのが、前回記事の概要です。この11次元多様体上では、数字の0〜9の画像群が、ある程度分離して分布していると思われますが、11次元多様体の構造を目で見ることは難しく、なんか良い確認手段はないかという話を「あとがき」に書きました。

ちなみに、11次元の球で近似しただけでも、画像データの約9割は、正しく分離できます。つまり、「各数字の画像データ群の中心(計算上は平均値、以下「数字の中心」)」と対象画像データの距離(ユークリッド距離)を計算した時、その画像データが属する数字の中心が一番近くにある比率は約9割になり、各数字毎に纏まって分布している事が分かります。

下記は、各数字の画像データ群に対して、その画像が属する数字の中心が直近の中心である比率を求めたものです(サンプルデータ数は10000個)。

0: nearest1=949(96.8%), nearest2=952(97.1%), nearest=952(97.1%)
1: nearest1=1105(97.4%), nearest2=1106(97.4%), nearest=1106(97.4%)
2: nearest1=888(86.0%), nearest2=891(86.3%), nearest=892(86.4%)
3: nearest1=863(85.4%), nearest2=868(85.9%), nearest=870(86.1%)
4: nearest1=855(87.1%), nearest2=863(87.9%), nearest=865(88.1%)
5: nearest1=782(87.7%), nearest2=786(88.1%), nearest=786(88.1%)
6: nearest1=835(87.2%), nearest2=831(86.7%), nearest=835(87.2%)
7: nearest1=927(90.2%), nearest2=926(90.1%), nearest=927(90.2%)
8: nearest1=832(85.4%), nearest2=821(84.3%), nearest=838(86.0%)
9: nearest1=857(84.9%), nearest2=817(81.0%), nearest=857(84.9%)
OK=8928/10000(89.3%), NG=1072/10000(10.7%)

nearest1は、対象画像データが属する数字の中心が一番近くにある比率です。nearest2は、原点から対象画像データへのベクトルと、原点から各数字の中心へのベクトルのなす角を求め(実際はcos(角度))、角度が一番小さい中心が近くにあると考え、その比率を求めたものです。nearestは、nearest1と"nearest2のいずれかの条件が満たされる比率です。

約9割(89.3%)の画像データは、その画像が属する数字の中心が直近の中心であることが分かります。残り1割強(10.7%)は、上手く分類できていません。

中間層データの形状を多変量正規分布で近似

前回の記事を書いた際、各数字の画像群を含む多様体として、球よりも扁平した形状である楕円体で近似すると、上手く分離できているかもしれないと思いました。直感的には、箱の中に、チョークが沢山入っているイメージです。チョーク同士は3次元空間で互いに分離していますが、2次元、つまり上から見ると重なって見えてしまいます。11次元の場合も、2次元に射影すると重なりが多いように見えますが、11次元では、上手く分離して分布していると考えました。

しかし、11次元の空間での分離状況を確認する良い方法が見つからず、放おっていました。

その後、もう少しVAEの中身をちゃんと理解するためにWebや本を見ていたところ、多変量正規分布で各数字の画像群の分布を近似すれば、楕円体のような扁平した形状を近似できることに気が付きました。例えば2次元ならば、下記のように扁平した形も表現できます。

f:id:Itsukara:20161002172718p:plain

f:id:Itsukara:20161002172711p:plain

各数字のデータ群の形状を上記のような多変量正規分布で近似するには、分散共分散行列Σが必要ですが、これは、各数字の画像データ群から次の式で計算できます。
f:id:Itsukara:20161002180524p:plain

後は、分散共分散行列Σ、中心の座標μ、画像データxから、多変量正規分布の式(下記)に従って、その画像がμを中心とする画像群に属する確率を計算できます。
f:id:Itsukara:20161002180310p:plain

この確率が一番大きな中心を直近の中心とし、各数字の画像データの直近の中心が、該当の数字である比率を計算した結果が下記です。

0: ok=950(96.9%), ng=30(3.1%)
1: ok=1065(93.8%), ng=70(6.2%)
2: ok=1023(99.1%), ng=9(0.9%)
3: ok=948(93.9%), ng=62(6.1%)
4: ok=969(98.7%), ng=13(1.3%)
5: ok=850(95.3%), ng=42(4.7%)
6: ok=903(94.3%), ng=55(5.7%)
7: ok=975(94.8%), ng=53(5.2%)
8: ok=946(97.1%), ng=28(2.9%)
9: ok=890(88.2%), ng=119(11.8%)
OK=9519/10000(95.2%), NG=481/10000(4.8%)

全画像データのうち、95.2%の画像データは、その画像が属する数字の中心が直近と分類されました。残り4.8%は上手く分類できていませんが、球形で近似した場合(10.7%)に比べると、半分以下になっており、分類精度が向上しています。

あとがき

VAEは急激に進化しているので、4.8%という数字は余り価値がありませんが、気になっていた点が確認できたので、スッキリしました。つまり、VAEの中間データは、各数字毎にかなり纏まった形で分布している事が確認できました。

分類から外れるのは、僅か4.8%であり、教師無しで、このような高精度の分類が出来るのは、元々の画像データの高次元空間(28x28=784次元)内で、同じ数字の画像データが纏まって分布しているためと思われます。

ちなみに、中間データではなく、画像データの空間内において、上記と同じようなことをやろうとしたのですが、画像データで0である部分が多いため、分散共分散行列のランクが391次元になり、逆行列を求めることができず(逆行列を求めるにはランクが784である必要あり)、上手く行きませんでした。784次元空間内の391次元部分空間のみを対象とすれば、うまく分類できる可能性がありますが、今回は、ここまでとしました。

補足

上手く分類できなかったデータ、つまり、上記4.8%に入っているデータを表示して見たのが下記です。左端は、各数字の中心です。これをみると、人間が見ても分類が難しいデータが多少ありますが、多くは、目視で正しく分類可能です。やはり、CNNに比べるとVAEでは、まだまだですね。
f:id:Itsukara:20161002182644p:plain

補足2

この文章を書きながら、プログラムを再確認したところ、実は、下記式でΣの行列式平方根で割る部分が抜けていました。ただ、これを入れると上手く分類できません。Σの行列式はかなり小さな値となり(10e-30以下)、誤差が多くなるためかもしれませんが、正しい定義にすると上手く行かない理由、および、定義が正しくなくても上手く行く理由は不明です。おそらく、正しくなくても上手く行く理由は、各数字の画像データ群の形状は近似できているためと思いますが(正しい場合との違いは、形状の大きさだけなので)
f:id:Itsukara:20161002180310p:plain

本件関連ソースコード(ipynb)とhtml

  • ipynb: vae-cov.ipynb (表示するにはjupyterが必要。実行確認可)
  • html : vae-cov.html (ブラウザで表示可能。実行確認不可)

上記ファイルを、下記サイトからダウンロード願います。
github.com