Itsukaraの日記

最新IT技術を勉強・実践中。最近はDeep Learningに注力。

Variational Autoencoderの中間層データの性質

9/1の記事9/3の記事で、Variational Autoencoderの中間層の話を書きましたが、その後、山田さんから「中間層が5次元での画像の再現状況は?」との質問を頂きました。これを改めて確認しました(後述)。

ちなみに、当方は、VAEに詳しいわけではなく、単にVAEの特性が面白いと思ったので、他の方が作成されたVAEのプログラムをいじって遊んでいるレベルです。

高次元多様体の低次元構造の分析に関する論文

また、MNISTの文字画像全体が属する高次元空間(画素数(28x28=784)次元)内での、MNISTの文字画像全体が成す多様体の低次元構造を分析した論文を、NTT研究所の山田さんから紹介頂きました。

ニューラルネット写像関数と考えて、その微分(ヤコビアン行列)を求め、ヤコビアン行列の固有値を調べると、0でない固有値の数が多様体の次元が分かるとのこと。理由は未説明だったため今ひとつ分かりませんでしたが、中間層で活性化されたノード(値がほぼゼロでないノード)の数がMNIST多様体の次元であることを考えると、当然な気もします。ヤコビアン固有値が0位外の固有ベクトルのみが、中間ノードを変化させるわけですから...

逆に、高次元多様体の次元を調べるには、わざわざヤコビアン行列を求めなくても、VAE適用後の中間層の活性化ノード数を調べれば十分な気がしました。

ちなみに、この論文によると、トレーニングを進めると非ゼロ固有値の数(正確には値が1位上の固有値の数)が10個程度(論文のグラフから目算で推定)になったようです。この結果は、VAE適用後の活性化ノード数と一致します。

中間層の次元が5の場合

理由が不明ですが、色々な環境で10回くらい試しても、下記にように上手く収束しませんでした(cost = nanになってしまう)。

Epoch: 0001 cost= 187.885618730
Epoch: 0006 cost= 124.540811934
Epoch: 0011 cost= 119.978319744
Epoch: 0016 cost= 117.629997753
Epoch: 0021 cost= 116.178550346
Epoch: 0026 cost= nan
Epoch: 0031 cost= nan
Epoch: 0036 cost= nan
Epoch: 0041 cost= nan
Epoch: 0046 cost= nan
Epoch: 0051 cost= nan
Epoch: 0056 cost= nan
Epoch: 0061 cost= nan
Epoch: 0066 cost= nan
Epoch: 0071 cost= nan

上記の状態では、下記のように、入力画像は上手く再現されませんでした(全て真白)。
f:id:Itsukara:20160905231609p:plain

中間層が2次元の場合

気を取り直して、中間層が4次元の場合に、再現画像を表示してみました。2次元では計算が収束し、再現画像は下記のようになりました。20次元の場合と比べると、再現性が今ひとつです。とくに「2」の再現性が悪いですね。
f:id:Itsukara:20160905223940p:plain

後から考えたら、MNIST多様体(下記)で数字の画像が再現出来ていたわけなので、その一部を取り出したようなものですね。
f:id:Itsukara:20160906023941p:plain

中間層の値の分布

中間層が2次元の場合は、VAE関連の文献によく載っているように、下図で中間層の値の分布を示すのが一番分かりやすいですね。下図では、各数字毎に中間層の2ノードうち1つのノード値をx軸の値にし、もう1つのノードの値をy軸の値にしてプロットしたものです。また、数字毎に別々の色が付いています。0, 1, 2, 7は他の数字から離れた場所に分布していますが、他は、分布が重なっているものが多いです。
f:id:Itsukara:20160905224040p:plain

中間層が4次元の場合

中間層が4次元の場合も、計算は収束し、再現画像は下記のようになりました。2次元と比べると、かなり再現性が良くなっています。
f:id:Itsukara:20160905230658p:plain

中間層の値の分布

中間層が4次元の場合、2次元のような分布図を作成することはできないため、前々回記事記載の方法で、各数字に対する中間層の値の分布を作成してみました(今回は1つの図に纏めました)。4次元ぐらいならば、数字毎の中間層グラフの形が異なるのが、なんとなく見て取れると思います。特に、0と1は、グラフの形が大きく異なります。ただ、2と3は、グラフの形がかなり似ており、区別が付きませんね。
f:id:Itsukara:20160905230719p:plain

4次元空間を2次元で示すことは無理ですが、2次元への射影を表示することは可能なので、射影の2次元分布図を作成してみました。3つ試しましたが、どれも、2次元の場合と比べると、重なりが大きいです。これは、射影した影響と思います。4次元空間では、それなりに分離している可能性があります。

1次元目と2次元目の平面への射影

f:id:Itsukara:20160905230748p:plain

1次元目と3次元目の平面への射影

f:id:Itsukara:20160905230818p:plain

2次元目と3次元目の平面への射影

f:id:Itsukara:20160905230833p:plain

中間層が30次元の場合

再現性の比較のため、中間層が30次元の場合も再度計算してみました。下記のように、非常に良い再現性が得られています。数字によっては、入力画像よりも高精細に見え(例えば4)、どちらが入力画像であるか区別するのが難しくなっています。
f:id:Itsukara:20160905225419p:plain

中間層の値の分布

中間層の値の分布は下記のようになり、30次元のうち、11次元のみが活性しています。このように活性化次元が少ないのは、VAEの特性によるものではないかとのご意見を山田さんから頂いています。
f:id:Itsukara:20160905225500p:plain

30次元でも、4次元の時と同様に、2つの次元からなる平面に射影してみました。2つの次元としては、上記グラフを目で見て、数字毎の違いが感じられる次元を選択しました。それでも、数字毎の分布の重なりが非常多く、上手く分離できていないことが分かります。

17次元目と18次元目の平面への射影

f:id:Itsukara:20160905225648p:plain

18次元目と20次元目の平面への射影

f:id:Itsukara:20160905225941p:plain

0次元目と4次元目の平面への射影

比較のために、目で見てグラフの違いが少ない次元への射影も作図してみました。上記2つよりも更に重なりが大きく、全体が混在しているように見えることが分かります。
f:id:Itsukara:20160905230032p:plain

30次元空間での各数字の分離状況

前回と同様に、中間層の30次元内での各画像間の距離を考えることで、各数字の画像の分布の分離状況を計算してみました。

Ratio of neaest members from center 0: is_nearest1=98.4%, is_nearest2=15.9%, is_nearest=98.6%
Ratio of neaest members from center 1: is_nearest1=98.8%, is_nearest2=38.9%, is_nearest=98.9%
Ratio of neaest members from center 2: is_nearest1=93.8%, is_nearest2=38.8%, is_nearest=94.2%
Ratio of neaest members from center 3: is_nearest1=93.8%, is_nearest2=50.8%, is_nearest=95.2%
Ratio of neaest members from center 4: is_nearest1=94.4%, is_nearest2=48.4%, is_nearest=95.0%
Ratio of neaest members from center 5: is_nearest1=92.8%, is_nearest2=48.9%, is_nearest=93.2%
Ratio of neaest members from center 6: is_nearest1=93.5%, is_nearest2=46.4%, is_nearest=95.0%
Ratio of neaest members from center 7: is_nearest1=95.4%, is_nearest2=16.4%, is_nearest=96.2%
Ratio of neaest members from center 8: is_nearest1=88.6%, is_nearest2=44.5%, is_nearest=90.5%
Ratio of neaest members from center 9: is_nearest1=90.2%, is_nearest2=49.0%, is_nearest=91.6%

上記で"is_nearest1"は、各数字毎にその数字の画像全体の重心(平均)を計算し、その重心が他の数字の重心よりも近いデータの比率を求めたものです。この数字が大きいということは、各数字の画像データが30次元中間層内の球内に分布していると仮定して、他の数字の球と分離できるということです。0では、98.4%が他の数字の球と分離できていることが分かります。分離が一番悪い8でも、88.6%は、他の数字の球と分離できています。

"is_nearest2"は、30次元空間で、同じ数字のデータは原点から見て同じ方向にあるのではないかと推察して計算したものです。つまり、各画像データに対して、原点から画像データへのベクトルと、重心から重心へのベクトルのなす角を求め(実際は角度のCosine)、角度が一番小さい重心が近くにあると考えて分類したものです。残念ながら、これだけでは、分類精度は15.9%〜50.8%と、非常に悪い結果になってしまいました。ただ、距離が近いか、角度が近いか、いずれかが成立する重心と近いと考えると、少しだけ精度が上がります("is_nearest")。

あとがき

もしかしたら、各数字の画像全体を含む多様体として、比較的単純な形である楕円体で近似すると、上手く分離できているかもしれないと思ったのですが、各数字の画像全体に対して楕円体を計算する手段が思いつかなかったので、それ以上は進んでおりません。