Variational Autoencoderの20次元中間層内距離の確認結果
昨日(8/31)、全脳アーキテクチャ勉強会のオフ会に参加したところ、NTT研究所の山田さんがVariational Autoencoderの話をされ、MNISTのデータに対して20ノードの中間層まで圧縮できる話を伺いました。
これだけ圧縮すると、いわゆる「おばあちゃんニューロン」のように、20個のノードのそれぞれが0〜9の数字の1つを表すようになっているのではないかと思ったのですが、そうはならないだろうとのこと。
気になったので、一応、確認してみました。
結果、そのとおりでした(おばあちゃんニューロンにはなっていない)。
また、中間層が2個の時は下図のように各数字の図形データの中間層の値が、近くに固まる傾向がありましたが、20個(20次元)でも同様の傾向があることが「数値的」に確認できました(20次元は視覚化ができないので数値的に確認しました)。
各数字の図形データに対する中間層の値の分布
学習が終わった後で、各数字の図形データに対して中間層の値の分布を示したものが下記です。各1つの折れ線が1つの図形データに対する中間層の値の分布を表しており、数百個の図形データでの中間層の値の分布を示したものです。1つの数字に対して、中間層は様々な分布になることが分かります。
- 数字0に対する中間層の値の分布
- 数字1に対する中間層の値の分布
- 数字2に対する中間層の値の分布
なお、全ての数字に対する分布図を見ると、中間層では値が殆ど0のノードが多く、値が0意外なのは10個程度でした。この10という数字が偶然なのか、データの種類が10個であることに関係しているのか、不明です。
各数字の図形データの中間層の重心どうしの距離
各数字の図形データの中間層データ(群)に対して、重心(計算的上は平均)を求め、重心間の距離を求めた結果が下記です。重心どうしの距離は、最大でも12であり、結構近いことが分かります。ちなみに、i行j列の成分は、数字iの中間層データの重心と、数字jの中間層データの重心の距離です。
[[ 0 11 7 6 7 6 4 7 6 4] [11 0 5 7 7 7 6 6 6 6] [ 7 5 0 5 7 8 6 6 6 8] [ 6 7 5 0 12 3 7 6 5 7] [ 7 7 7 12 0 12 5 7 6 2] [ 6 7 8 3 12 0 6 7 6 6] [ 4 6 6 7 5 6 0 10 7 5] [ 7 6 6 6 7 7 10 0 8 3] [ 6 6 6 5 6 6 7 8 0 4] [ 4 6 8 7 2 6 5 3 4 0]]
ちなみに、テンソル間のデータの距離は、要素ごとの距離の2乗和として計算しています。
各数字の図形データの重心からの距離
各数字の図形データの重心からの距離を求めたのが下記です。重心間の距離の最大値である12よりも遠くにあるデータが結構あることが分かりました。これでは、正しい重心の近くにない可能性があると思い、次項で確認しました。
- 数字0に対する中間層データの重心からの距離
- 数字1に対する中間層データの重心からの距離
- 数字2に対する中間層データの重心からの距離
各数字の図形データの重心が一番近い重心である比率
各数字の図形データの各データに対して、そのデータが属する重心が他の重心よりも近い比率を求めました。その結果、約9割以上は、正しい重心の近くにあることが分かりました。つまり、各数字の中間層データは、それなりに近くに固まっている事が分かります。
Ratio of 'center of 0 is_nearest' = 99.0% Ratio of 'center of 1 is_nearest' = 97.0% Ratio of 'center of 2 is_nearest' = 91.3% Ratio of 'center of 3 is_nearest' = 93.3% Ratio of 'center of 4 is_nearest' = 90.9% Ratio of 'center of 5 is_nearest' = 92.4% Ratio of 'center of 6 is_nearest' = 91.9% Ratio of 'center of 7 is_nearest' = 93.3% Ratio of 'center of 8 is_nearest' = 89.2% Ratio of 'center of 9 is_nearest' = 88.1%
別の記事で、MNISTデータのK平均法での分類を行っていますが、正しく分類できる比率は71.5%でした。これに対して、上記では約90%が正しく分類されています。つまり、VAEによって次元が784次元から20次元に圧縮されるだけでなく、距離による分類という観点で基のデータより上手く分類されていることが分かります。
ソースコード
ここで公開されていたコード(jupyter IPython)を基に、上記の確認用のコードを追加しました。追加後のコードは下記にあります。なお、html化したファイルも置きましたので、結果だけを見るならば、htmlファイルをご覧ください。ご参考まで。
github.com
感想
疑問点を確認することが出来ましたが、「中間層の活性ノード数が、入力したデータの種類数とほぼ等しい」という理由が、新たな疑問点となりました。