【Rで機械学習】ニューラルネットワークの結果を精査してみる。

wanko-sato.hatenablog.com

前回、Rで簡単なニューラルネットワークを自作してみました。参考にした書籍はこちら。


ニューラルネットワーク自作入門

ニューラルネットワーク自作入門

本書はPythonでの実装例が示されていますが、前回の記事ではそれをRに書き直してみよう、というものでした。
さて、本書では、MINSTを学習データとして使用しており、興味深い実験として、「ある出力結果を与えたらどのような入力データでありうるのか?」というback queryを行っています。それ自体、面白い結果だと思っているのですが、いかんせん入力データが大きいため、実際のところどうなの?というのがいまいちわかりづらい結果でした。

前回の記事では、三次元の正規分布を使った座標データを分類する、という非常に単純な課題をニューラルネットワークで学習させました。入力データは3つの座標、出力データは3つの確率です。精査していくのには手ごろな大きさなので、改めて、前回学習させたニューラルネットワークで何が起こっているのか、精査していきたいと思います。

そもそもの動機

機械学習において、「適切に分類できれば中身はどうでも良い」という考え方がひとつあります。精度が求められるタスク、たとえば自動運転や疾患の診断等では、どんな手段であっても精度が高い方が良いでしょう。
ただ、データ分析を日々の業務としている自分としては、「なぜそのように分類されたのか?」の方に興味があります。それはつまるところ、分類のための基準を機械学習モデルはどのように選び取っているのか、ということであり、学習アルゴリズムそのものへの興味でもあります。分類の基準がわかれば、それをもとに「人間が」考えるためのひとつの手がかりになります。それは、データを見る視点を新たに手に入れる、ということでもあります。
そういう意味で、ニューラルネットワークの中で何が起こっているのかをきちんと考えるのはとても有益なことだと考えています。

ニューラルネットワークの中では何が起こっているのか

おさらい

前回作成したニューラルネットワークは次のようなものでした。

  • 入力層、隠れ層、出力層の3層からなる
  • 各層のノードは3つずつで、すべてのノードが接続されている
  • 隠れ層および出力層から出力される値はロジスティック関数を通す

つまり、最終的に出力される値は1~3のどのグループに分類されるか、の確率を示したものになります。
このとき、入力-隠れの重み行列をW_{hi}、隠れ-出力の重み行列をW_{oh}とし、入力ベクトルをX=(x_{1},x_{2},x_{3})、隠れ層からの出力ベクトルをH=(h_{1},h_{2},h_{3})、出力層からの出力ベクトルをO=(o_{1},o_{2},o_{3})とします。
W_{hi}の重み行列は

W_{hi}=\begin{pmatrix}
whi_{11}&whi_{12}&whi_{13}\\
whi_{21}&whi_{22}&whi_{23}\\
whi_{31}&whi_{32}&whi_{33}\\
\end{pmatrix}

となっており、添え字の一桁目が隠れ層のノード番号、二桁目が入力層のノード番号になっています。つまり、whi_{12}は入力層の2番目のノードから隠れ層の1番目のノードへの重み、という意味になります。

どんな計算を行っているのか

ごくごく簡単に言えば、入力データのベクトルに重み行列をかけた積をロジスティック関数に通す、ということを繰り返しているわけです。
入力層から隠れ層へのデータは

W_{hi}X=\begin{pmatrix}
whi_{11}&whi_{12}&whi_{13}\\
whi_{21}&whi_{22}&whi_{23}\\
whi_{31}&whi_{32}&whi_{33}\\
\end{pmatrix}
\begin{pmatrix}
x_{1}\\
x_{2}\\
x_{3}\\
\end{pmatrix}

と計算されますが、要するにこれって

{\displaystyle
\begin{eqnarray}
  \left\{
    \begin{array}{1}
      whi_{11}x_{1} + whi_{12}x_{2} + whi_{13}x_{3}\\
      whi_{21}x_{1} + whi_{22}x_{2} + whi_{23}x_{3}\\
      whi_{31}x_{1} + whi_{32}x_{2} + whi_{33}x_{3}\\
    \end{array}
  \right.
\end{eqnarray}}

という3つの式を計算しているのと同じことなわけです。で、シグモイド関数

sigm(x)=\displaystyle \frac{1}{1+e^{-x}}

とロジスティック関数として定義すれば、

{\displaystyle
\begin{eqnarray}
  \left\{
    \begin{array}{1}
      h_{1} = sigm(whi_{11}x_{1} + whi_{12}x_{2} + whi_{13}x_{3})\\
      h_{2} = sigm(whi_{21}x_{1} + whi_{22}x_{2} + whi_{23}x_{3})\\
      h_{3} = sigm(whi_{31}x_{1} + whi_{32}x_{2} + whi_{33}x_{3})\\
    \end{array}
  \right.
\end{eqnarray}}

という計算を行っていることになるわけです。隠れ-出力でも同じ計算ですから、行列の積で表現すると

W_{oh}H=\begin{pmatrix}
woh_{11}&woh_{12}&woh_{13}\\
woh_{21}&woh_{22}&woh_{23}\\
woh_{31}&woh_{32}&woh_{33}\\
\end{pmatrix}
\begin{pmatrix}
h_{1}\\
h_{2}\\
h_{3}\\
\end{pmatrix}

こうなり、方程式の形にするのであれば

{\displaystyle
\begin{eqnarray}
  \left\{
    \begin{array}{1}
      o_{1} = sigm(woh_{11}h_{1} + woh_{12}h_{2} + woh_{13}h_{3})\\
      o_{2} = sigm(woh_{21}h_{1} + woh_{22}h_{2} + woh_{23}h_{3})\\
      o_{3} = sigm(woh_{31}h_{1} + woh_{32}h_{2} + woh_{33}h_{3})\\
    \end{array}
  \right.
\end{eqnarray}}

となります。で、出力に対しては教師データベクトルT=(t_{1},t_{2},t_{3})があり、そこから誤差ベクトルE=(e{1},e{2},e{3})が計算できる、という流れになっているわけです。ニューラルネットワークはこの誤差ベクトルを最小にするような重み行列W_{hi}W_{oh}を求めることである、と言い換えることもできます。

back query

であれば、すでに学習済みの重み行列を用いて、任意の出力がどのような入力から得られるか、逆方向の計算も可能であると考えられます。重み行列とデータベクトルの計算は

\begin{eqnarray}
  H &=& W_{hi}X\\
  W_{hi}^{-1}H &=& W_{hi}^{-1}W_{hi}X\\
  W_{hi}^{-1}H &=& IX\\
  W_{hi}^{-1}H &=& X\\
\end{eqnarray}

とこのように重み行列の逆行列を左側から乗じてあげることで、隠れ層のデータから入力層のデータを計算することができます。出力層から隠れ層のデータも同じです。
さらに、シグモイド関数として用いたロジスティック関数の逆関数であるロジット関数は

logit(x) = ln(\displaystyle \frac{x}{1-x})

と定義されますから、出力層のデータから隠れ層、隠れ層のデータから入力層を計算するには、


H = W_{oh}^{-1}logit(O)\\
X = W_{hi}^{-1}logit(H)

の二つの式で計算できることになるわけです。
ただ、ここで注意しなければならない点があります。ロジット関数の形をみると、

  • x1であってはならない(ゼロ割になるため)
  • \displaystyle \frac{x}{1-x}<0になってはならない

ということがわかります。コーディングにあたってはこの点に注意していく必要があります。

※ちなみに、書籍の方では逆行列ではなく転置行列を乗じて、得られたベクトルを0.01~0.99にスケーリングする、という方法をとっていました。

結果を振り返る

どんな結果だったか

f:id:wanko_sato:20171001133439p:plain

3パターンの正規乱数から発生させた三次元空間上の座標データを学習データとし、同じように発生させた少数のデータをテストデータとしました。学習の結果、

f:id:wanko_sato:20171001141436p:plain

精度96%となりました。グループ1(黒)とグループ3(緑)はほぼ正確に分類できているようですが、グループ2(赤)がちょっと難しそうだ、という結果となっています。
最初の方にも書いた通り、各データがどのグループに分類されているか、は出力層の各ノードに確率として出力されます。その中で確率最大のノードをその入力データの分類グループとして割り当てています。
なのですが、そういえば、単に「確率最大」といっているけれども、そもそもどういう出力になっているのかきちんと見ていなかったな、と思い立ち、各出力がどのような値になっているのかをプロットしてみました。

f:id:wanko_sato:20171104110847p:plain

おや、なにやら面白い形になっていますねぇ。
ちなみにX軸が出力層ノード1の確率、Y軸がノード2の確率、Z軸がノード3の確率になっています。ノード番号はグループ番号に対応しています。
これ、よくよく見てみると、Y軸つまりノード2の確率が0.2~0.4という狭い範囲で分布していることに気づきます。また、ノード1の確率が1付近の場合はノード2の確率が~0.3、ノード3の確率が1付近の場合はノード2の確率が0.35~になっていることもわかります。したがって、ノード2の確率がノード1、ノード3よりも高くなる範囲が非常に狭い、という形になっているのです。どういうことかというと、

f:id:wanko_sato:20171104111340p:plain

ノード1とノード2の確率をプロット。

f:id:wanko_sato:20171104111404p:plain

ノード2とノード3の確率をプロット。

f:id:wanko_sato:20171104111426p:plain

ノード1とノード3の確率をプロット。

このように、ノード2の出力値のとり得る範囲が非常に狭くなっており、かつノード1とノード3が中間的な確率をとるケースが非常に少ない、ということが見えてきます。ということは、back queryを行うにあたって、「任意の値の組み合わせ」を使うとその出力値を出す入力値にうまく戻せないケースがたくさん出てくるのじゃないのかな、と想像することができるわけです。

学習モデルの係数行列

では、ニューラルネットワークのキモである係数行列はどのような値になっているでしょうか。

> Whi
             [,1]       [,2]       [,3]
[1,]  0.929102250  1.8228347  0.3424140
[2,] -1.958903344 -0.3506342 -0.8565186
[3,]  0.001166766 -1.0417063 -2.2413307
> Woh
           [,1]       [,2]       [,3]
[1,] -8.8318889  16.343451  16.377595
[2,] -0.2582631  -1.045473  -1.541251
[3,]  5.5468657 -19.782169 -19.860237

有効桁数を2桁にして、行列として表現するとこんな感じになります。

W_{hi}=\begin{pmatrix}
0.93&1.82&0.34\\-1.06&-0.35&-0.86\\0.00&-1.04&-2.24\\
\end{pmatrix}

W_{oh}=\begin{pmatrix}-8.83&16.34&16.38\\-0.26&-1.05&-1.54\\5.55&-19.78&-19.86\\
\end{pmatrix}

これを眺めてみると、隠れ-出力の係数行列W_{oh}の2行目の値が1行目、3行目に比べて非常に小さいことがわかります。2行目は出力層のノード2に入っていくデータの係数になりますから、ノード2がとる値は小さくなる、ということがわかります。
また、W_{oh}の1行目は、出力層のノード2と3の値をプラス方向で大きく評価し、3行目はマイナス方向で大きく評価しています。
これでも非常に特徴的な結果が得られていると思われます。

back queryの実装

ということで、学習済みの係数行列を使い、任意の出力ベクトルを算出するための入力ベクトルを求めるback queryをRで実装してみます。学習済みの係数行列W_{hi}W_{oh}はすでにあるものとします。もしまだない場合は前回の記事を参考に、係数行列を作成しておいてください。

wanko-sato.hatenablog.com

ロジット関数と逆行列

すでに示した通り、back queryにはロジスティック関数の逆関数であるロジット関数と、係数行列の逆関数が必要です。
ロジット関数は

logit(x) = ln(\displaystyle \frac{x}{1-x})

であり、Rにおける自然対数は引数なしのlog()関数でいけますから、

logit <- function(x){
  log(x/(1-x))
}

とすればOKです。関数の入力にベクトルをとることもできます。

また、逆行列はsolve()関数で得ることができます。

> solve(Whi)
            [,1]       [,2]        [,3]
[1,]  0.01436237 -0.5035618  0.19462897
[2,]  0.59305101  0.2812723 -0.01688551
[3,] -0.27562564 -0.1309895 -0.43821430

これが本当に逆行列になっているかどうかは

> solve(Whi) %*% Whi
             [,1]          [,2]         [,3]
[1,] 1.000000e+00  0.000000e+00 5.551115e-17
[2,] 8.018014e-17  1.000000e+00 2.775558e-17
[3,] 2.894820e-17 -5.551115e-17 1.000000e+00

で確認できます。非対角要素が微妙に0になってないですが、まぁ大丈夫でしょう。

テスト計算

念のため、


H = W_{oh}^{-1}logit(O)\\
X = W_{hi}^{-1}logit(H)

で本当に出力ベクトルから入力ベクトルが再現できるのかを確認します。
使用する入力データはこれで、

> testData[1,]
[1] 0.2850674 0.1898153 0.2989618 1.0000000

そのときの出力が

> testOut[1,]
[1] 0.9759900310 0.2818141988 0.0003535542

これでした。
で、

# test
outIn <- testOut[1,1:3]
outIn <- try(logit(outIn))
outIn <- solve(Woh) %*% outIn
outIn <- logit(outIn)
outIn <- solve(Whi) %*% outIn

とすると、

> outIn
          [,1]
[1,] 0.2850674
[2,] 0.1898153
[3,] 0.2989618

ちゃんと入力ベクトルに戻すことができています。

実装

実装自体は簡単なのですが、前述した通り、ロジット関数がうまく動かないケースを想定してコーディングする必要があります。
また、念のためきちんと計算できているかどうかを確認するため、計算の途中結果も出力できるようにします。

# create "output" data
vecX <- seq(0.001,0.999,0.001)
inDF <- as.matrix(expand.grid(vecX,vecX,vecX))

「出力データ」は0.001~0.999まで、0.001きざみの確率とし、3変数にそのすべての組み合わせを入れ込みます。

# function for back query
f <- function(x){
  outIn1 <- as.vector(x)
  outIn2 <- logit(outIn1)
  outIn3 <- solve(Woh) %*% outIn2
  if(all(!(outIn<0)) & all(!(outIn3==1))){
    outIn4 <- logit(outIn3)
    outIn5 <- solve(Whi) %*% outIn4
  }else{
    outIn4 <- c(NaN,NaN,NaN)
    outIn5 <- c(NaN,NaN,NaN)
  }
  if(all(is.nan(outIn5))|all(outIn==0)){
    outIn6 <- 0
  }else{
    outIn6 <- 1
  }
  outIn7 <- min(which(outIn1==max(outIn1)))
  return(c(outIn1,outIn2,outIn3,outIn4,outIn5,outIn6,outIn7))
}

all()関数でロジット関数への入力値が0未満でないか、かつ入力値が1でないかをチェックし、問題なければ計算を実行、そうでなければNaNを返すようにしています。わかりやすいように計算が最後まで実行できたものは1、そうでないものは0のフラグを立てるようにしました。

# clear data
test <- c()

# loop function by lapply
test <- lapply(seq_len(nrow(inDF)),function(x){
  x <- inDF[x,]
  out <- f(x)
  return(out)
})

# convert data.frame
test <- do.call(rbind,test)

これでOKです。

> head(test)
     [,1] [,2] [,3]      [,4]     [,5]     [,6]     [,7]      [,8]     [,9] [,10] [,11] [,12] [,13] [,14] [,15] [,16] [,17]
[1,] 0.01 0.01 0.01 -4.595120 -4.59512 -4.59512 1.927025 -5.944878 6.691091   NaN   NaN   NaN   NaN   NaN   NaN     0     1
[2,] 0.02 0.01 0.01 -3.891820 -4.59512 -4.59512 1.760176 -6.179530 6.878220   NaN   NaN   NaN   NaN   NaN   NaN     0     1
[3,] 0.03 0.01 0.01 -3.476099 -4.59512 -4.59512 1.661551 -6.318232 6.988832   NaN   NaN   NaN   NaN   NaN   NaN     0     1
[4,] 0.04 0.01 0.01 -3.178054 -4.59512 -4.59512 1.590844 -6.417673 7.068134   NaN   NaN   NaN   NaN   NaN   NaN     0     1
[5,] 0.05 0.01 0.01 -2.944439 -4.59512 -4.59512 1.535422 -6.495617 7.130292   NaN   NaN   NaN   NaN   NaN   NaN     0     1
[6,] 0.06 0.01 0.01 -2.751535 -4.59512 -4.59512 1.489658 -6.559978 7.181619   NaN   NaN   NaN   NaN   NaN   NaN     0     1

こんな感じのデータになっています。

> c(nrow(test),sum(test[,16]))
[1] 970299   7681

こうしてみると、全970,299通りの組み合わせのうち、最後まで計算できた組み合わせが7,681通り、全体の0.8%しかありませんでした。すでに述べた通り、ノード2が出力する値の範囲が極めて狭いこと、ノード1とノード3のとりうる組み合わせが非常に少ないことが原因として考えられます。

結果の可視化

数字とにらめっこしていてもよくわからないので、可視化してみましょう。
まずは、計算が最後まで可能だった入力データがどんなものだったか見てみましょう。

outPlot <- test[test[,16]==1&!is.infinite(test[,15]),]
rgl::plot3d(outPlot[,1:3],xlim=c(0,1),ylim=c(0,1),zlim=c(0,1),col=outPlot[,17])

※一か所だけinfiniteになっていたところがあったのでそれを除いています。

f:id:wanko_sato:20171104130031p:plain

色分けは、黒がノード1の確率が一番高いとき、赤がノード2、緑がノード3です。確率が同じ出会った場合、ノード番号の一番小さいものをグループ番号としました。
形状としてはちょっとまがった八つ橋みたいな三角形をしていて、厚みのない分布になっています。計算可能な範囲が非常に限られていることがここからもわかります。
次に、復元された入力データをプロットしてみます。

rgl::plot3d(outPlot[,13:15],col=outPlot[,17])

f:id:wanko_sato:20171104130640p:plain

プロットがそもそも想定しているよりも広範囲にわたってしまっているので、範囲を限定して、

rgl::plot3d(outPlot[,13:15],xlim=c(0,1),ylim=c(0,1),zlim=c(0,1),col=outPlot[,17])

f:id:wanko_sato:20171104130843p:plain

これ、ちょっとわかりずらいですが、(1,1,0)と(0,0,1)を結ぶ、やや厚みのある対角面がグループ2の領域になっており、それを挟んで下側がグループ1、上側がグループ3という領域分割になっています。
境界をとるにはグループ1と2またはグループ2と3の確率が等しいときのデータをプロットすれば良いので、

outBound <- outPlot[outPlot[,1]==outPlot[,2]|outPlot[,2]==outPlot[,3],]
rgl::plot3d(outBound[,13:15],xlim=c(0,1),ylim=c(0,1),zlim=c(0,1),col=outBound[,17])

とすれば良く、

f:id:wanko_sato:20171104131454p:plain

このように、決定境界らしきものが描けました。これって要するに出力ベクトルO=(o_{1},o_{2},o_{3})の中で

{\displaystyle
\begin{eqnarray}
  \left\{
    \begin{array}{1}
      o_{1} = o_{2}\\
      o_{2} = o_{3}\\
    \end{array}
  \right.
\end{eqnarray}}

のいずれかを満たす入力ベクトルX=(x_{1},x_{2},x_{3})なわけで、式として書き下すと

{\displaystyle
\begin{eqnarray}
  \left\{
    \begin{array}{1}
      sigm(woh_{11}h_{1} + woh_{12}h_{2} + woh_{13}h_{3}) = sigm(woh_{21}h_{1} + woh_{22}h_{2} + woh_{23}h_{3})\\
      sigm(woh_{21}h_{1} + woh_{22}h_{2} + woh_{23}h_{3}) = sigm(woh_{31}h_{1} + woh_{32}h_{2} + woh_{33}h_{3})\\
    \end{array}
  \right.
\end{eqnarray}}

ただし、

{\displaystyle
\begin{eqnarray}
  \left\{
    \begin{array}{1}
      h_{1} = sigm(whi_{11}x_{1} + whi_{12}x_{2} + whi_{13}x_{3})\\
      h_{2} = sigm(whi_{21}x_{1} + whi_{22}x_{2} + whi_{23}x_{3})\\
      h_{3} = sigm(whi_{31}x_{1} + whi_{32}x_{2} + whi_{33}x_{3})\\
    \end{array}
  \right.
\end{eqnarray}}

ということになるわけです。
さすがにこれを満たすX=(x_{1},x_{2},x_{3})を探索するコードは面倒なので書きませんが、おそらくそれによって決定境界面を描くことができるでしょう。

まとめ

というわけで、至極単純なニューラルネットワークを題材に、その仕組みを再検討してback queryの実装まで行ってみました。非常に単純なモデルとはいえ、興味深い結果が得られたのではないかと思っています。よくよく考えたらニューラルネットワークといえども結局のところはn次元の決定境界を描くものですから、その良い面と悪い面をしっかり把握したうえで使っていければ良いのかな、と思う次第であります。

と、今回、行列だの連立方程式だの、texで書きまくったので、良い勉強になりました。

今回は以上です。