言語解析で使うAttention型の深層学習がメタ学習を示す論文を読む

この論文は汎用翻訳モデルBertで使われるAttentionを使ったRNN型構造の深層学習が問題の構造に依らずメタ学習ができ、高次元のパターン認識強化学習でも驚異的な性能を示したとするICRL2018報告である。

 [1707.03141] A Simple Neural Attentive Meta-Learner

(1)SNAILモデル

 この手法[以下SNAIL:Simple neural attentive lerner)は下図の様なAttention層とCNN層を相互に挟んだRNN構造をしている。CNN層は一定幅のデータの特徴抽出をAttention層は可変長のデータを一定幅の特徴ベクトルに変換する役割を相互で行い、逐次的に投入される入力と出力データからパターンを学習して、入力を条件として出力を予測するものである。左図は教師ありデータの予測で右図は強化学習の行動予測で両者とも同じ構造で実現している。

 このSNAILの深層学習による学習は、全てのTask\tau_iの一連の入力x_tと出力a_tについて以下の損失関数\mathcal{L}_\thetaを最小にする様に調整している。

  min_\theta \mathbb{E}_{\tau_i \sim P(\tau)} [ \sum_{t=0}^{H_i}  \mathcal{L_\theta}_i(x_i,a_i) ] 

  但し

        s_i \sim P_i(x_t |x_{t-1},a_{t-1})  a_t \sim \pi(a_t|x_1,\dots,x_t;\theta)

   \tau_iはtask

           x_tは入力、

           a_tは出力

 ここで云うメタ学習とは同じモデルで異なる課題を学習できる事を示している。Taskの構造が異なっても、正解が判明している場合や学習済みの一連の入出力データをTaskとして生成し、これを逐次的にSNAILに通すことによってTaskが持つデータパターンを学習し、これを繰返す毎にその予測精度を向上させるものである。

f:id:mabonki0725:20190121215508p:plain


(1)アルゴリズム

    下記の①と②のプログラムは上図のネットワークの部分に対応している。ここでは言語解析で使う次の特徴抽出機能でデータを変換している。

 ①は入力を2の乗数毎に分割してCNNで特徴行列を作成している

 ②は自己注意(self-attension)を使って入力の特徴(Attension)を取出している。

     入力をk個で要素に分解した2つのデータに変換し、これらの内積(matmul)を計算して各要素の生起確率(Softmax)を計算した後、一定の数Vでの特徴量に変換している。

  (論文 Attension is All you need[1]参照)

f:id:mabonki0725:20190121215553p:plain

 (3-1) 教師ありデータの識別実験

 手書き文字OmniglotとImageNet画像の分類結果をSNAILに投入して驚異的な識別精度を示すことができている。両課題とも投入条件として、N個の分類毎にK個のデータでK×N個セットを1ブロックとしてランダムにSNAILに投入し繰り返し学習させている。 

・手書文字Omniglotの教師あり識別実験

 OmniglotはLakeが生成モデルによる文字の構造認識[2]に使った下図の様な手書き文字で、これの1200種を学習させ、その結果を教師データとしてSNAILに投入している。

f:id:mabonki0725:20190120175125p:plain

Omniglot

  SNAILによるOmniglotの文字種類の識別予測結果(N-way : Nは分類数 K-shot::Kは分類毎の学習データ数)では全てにおいて最良の結果が得られている。

f:id:mabonki0725:20190121123036p:plain

 ・ImageNetの画像の教師あり識別実験

  ImageNetは下図の様に分類別に画像をダウンロードできる。これもN分類毎にK個のデータでN×KのセットをSNAILへの投入を繰り返し学習させる。

f:id:mabonki0725:20190121132028p:plain

鳥の画像のダウロード画面

  ImageNetのSNAILでの識別結果(N-way : Nは分類数 K-shot::Kは分類毎の学習データ数)では全てにおいて最良の結果が得られている(±は95%信頼区間)。

 

                 f:id:mabonki0725:20190121131929p:plain

(3-2)強化学習での行動予測実験

 以下の4種類の課題について、何れも強化学習済みのTaskが生成する状況と行動をSNAILに投入して、次期の行動を予測する。この課題に使った強化学習は何れもTRPO/GAEモデルである。

    ①複数のスロットマシン(N腕バンディット問題)

          ②タブレット上の移動(省略)

    ③チータとアリの移動(省略)

    ④3D迷路での宝探し

 ①N腕バンディット問題

  K個のスロットマシンでそれぞれ当たる確率が異なる場合、定められた試行回数Nで最大の当たる回数(得点)を求める課題である。

f:id:mabonki0725:20190121145343p:plain

 ベイズモデルのGittenを無限回行うと理論解になる事が判明していて、SNAILでは試行回数が少ない場合はGittenより良い結果を出していることが分る。

f:id:mabonki0725:20190121145518p:plain

各モデルでの得点比較(N:試行回数 K:台数)

  ④3D迷路での宝探し

   下図の左橋図の様に自分視点(first person)の3D迷路と宝は強化学習用のツールであるVizDoomでランダムに生成されたものである。ここでは簡単な迷路と複雑な迷路の2種類でSNAILの性能を試している。ここで報酬は宝を得る(+1)、罰則は一歩毎(-0.01)と壁に当たると(-0.001)としている。

f:id:mabonki0725:20190121152722p:plain

   下表の迷路探索の結果は宝を見つけるまでの平均の歩数である。2つのエピソードの内epsortd2はepsortd1と同じ迷路での探索でepsort1の探索の学習を使っているので短く済んでいる。上図の赤線はepsord1で青線はepsort2で近道をしている事が分かる。

    f:id:mabonki0725:20190121152801p:plain

(4)感想

 全く異なった構造をしている教師付きラベルや強化学習の行動が同じ構造を持つSNAILで学習できる事は驚異的である。しかも高次元の画像でもよい性能を示している。人間の無意識での識別や行動決定は、課題別に解いてはエネルギーを費やすので、同じ様なメタモデルで経験的にパターンを学習して反応していると考えられる。著者達は最近このメタ学習モデルを実際のロボットに適用した報告している[3]。


[1]A.Vasawani et al. [1706.03762] Attention Is All You Need

[2]B.Lake et al. https://cims.nyu.edu/~brenden/LakeEtAl2011CogSci.pdf

[3]A.Nagabandi et al.[1803.11347] Learning to Adapt in Dynamic, Real-World Environments Through Meta-Reinforcement Learning


: