因果関係を捉える強化学習の論文を読む

因果推論では2つの流派(ルービンとパール)があり、同じ因果を扱っているが方法が異なるので混乱してしまう。ルービンはスコア法に代表される因果推論であり、パールはベイジアンネットに代表される非巡回有向図(DAG:Directed Acyclic Graph)を用いる因果推論となっている。下記の記事はさらに心理学(キャンベル)を追加した区分について記述している。

統計学における因果推論(ルービンの因果モデル) – 医療政策学×医療経済学

 

機械学習での因果推論では、データから因果を推定する次の様な方法があるが、介入という操作を使えず本当の因果は判定できないものである。

 (i) データの3次以上のモーメントを使う独立成分分析で因果方向を推定する方法 

       http://www.padoc.info/doc/kanoIca.pdf

  http://padoc.info/doc/sas2015_bn_structer.pdf

  但しデータの高次のモーメントを使うのでデータにノイズがあると正確に判断できない。

(ii)因果の向きをMDL(最小記述)尤度で判定してベイジアンネットを生成する方法

  http://padoc.info/doc/jsai2015BayesNetNoname.pdf

 次に示す論文のintervention(介入)実験に示した様に最尤法による因果の方向の推定は介入より劣っている事が示されている。ここにデータから生成するベイジアンネットの限界があると思われる。

 

今回読んだ論文はDAG上でエージェントが様々な介入を行ってその結果から効率的に因果関係を把握する学習についての論文である。

[1901.08162] Causal Reasoning from Meta-reinforcement Learning

強化学習についてはメタ学習Learn to Learn[1]を強化学習に適用したメタ強化学習[2]を使っている。学習はActor-Critic法だがbaselineやsoftmaxのパラメータはメタ学習で行っている。具体的な方法については論文中に殆ど記述がない。

 

(1)モデル

   (1-1) 対象ノード数

  対象となるDAGのノード数を5個に限定している。これはノード数が増えるとDAGグラフの数は爆発して計算不能になるためである。例えばノード3個の場合下図の様に11通りのDAGになる。ノードが10個だと4.8億通りDAGとなり事実上計算不能になる。

   f:id:mabonki0725:20190318093148p:plain

  従って本論文では5個のノードを有する58749通りのDAGがを対象としている。

  (このうち300個は試験用でその他の殆どは学習用にしている)

  (1-2) 生成されたDAGの各ノードの値

     ランダムに生成された5ノードのDAGでは、rootノードの情報は隠されているが、それ以外の親ノードを持たないノードは標準偏差0.1の乱数で値が設定されている。各矢印の重みw_{ij}には{-1,0,1}の何れかがランダムに設定されている。即ちノードの値は以下の乱数で設定されている。

 親ノードが無い場合

       p(X_i) = \mathcal{N}(\mu=0.0,\sigma=0.1)

    親ノードがある場合

  p(X_i|pa(X_i) ) =\mathcal{N} (\mu = \sum_j w_{ji} X_j,\sigma=0.1)  

     pa(X_i)はノードX_iの親ノード全部を示す

 

 (1-3) DAGからの情報

  rootノード以外の各ノードの値とノード間の相関係数はエージェントに伝えられる。しかし連結状態はエージェントには分からない。 

 

    (1-4) 介入(intervention)とは

  選択したノードに強制的に値を設定して、下図の様に介入したノードが因果の影響を与える全てのノードの値を変更する。下図には示されていないが変更されたノードはさらに因果の方向に変更を伝播させる。但し介入ノードは独立となり親ノードの影響を受けなくなる。

  

f:id:mabonki0725:20190318203618p:plain

介入の例(E=e)

 (1-5) DAGの学習方法

  この論文ではDAGの構造理解への学習フェーズは2過程に分かれている。

  (i) Information Phase (状況提供フェーズ)

        エージェントはノード数の数だけ介入しノードに値(5点)を設定することができる。この介入操作でDAGの各ノードの値は因果の方向に沿って伝播して変更されれ、エージェントはこの結果を収集することができる。

    (ii) Quiz Phase (問合せフェーズ)

   上記の情報フェーズ後、エージェントは影響が一番大きい(相関が一番大きい)ノードに(-5点)設定して、その負の介入によりDAG内で最大の値を持つノードを選択してこれを報酬とする。下図はDAG構造を推定して最大値(0.0)のノードを報酬とした例である。

  学習は上記の2フェーズをランダムに生成したDAGで報酬を得て、その報酬よりActor-CriticでQ関数を改善するが、パラメターはメタ学習で行う。

          f:id:mabonki0725:20190318153953p:plain

  (2) 実験

   以下の3種の実験を行い、どの様に観察するとエージェントが高い報酬を得られるか試行している。

   (i)  Information Phaseで介入操作が行われない場合(Observasional:観察のみ)

   (ii) Information Phaseで介入操作が行われる場合(Interventional:介入操作)

   (iii) Information Phase で介入が行われ、その値に摂動を与える場合(counterfactual:反証的操作)

 (2-1 ) Observasional実験

 Information Phaseで介入は行われないので、Quiz Phaseでの介入だけでの観察となる。

    この実験でのエージェントの観察方法

  Passive-Cond:5点のノードが見つかれば、そこから低い値の近傍へ因果の矢印があると解釈する

  Optimal-Assoc:互いに相関係数が高いノードが関連していると解釈する

        Obs.Map:全てのノードの因果方向について最尤となる構造を仮定する。

       これはデータからのベイジアンネットの推定と同じ方法

 図(c)の様にPassive-condの方が因果方向を捉えているのでQuiz Phaseではより高いノード(0点)を選択できている。また図(a)より最尤法の解釈は曖昧であることが分る。図(b)は親ノードが無い場合(Oprhan)は高い値を持つノードが判断しやすい事を示している。

 

f:id:mabonki0725:20190318210712p:plain

  (2-2)Interventional実験

 Information Phaseでは介入操作が行われ、その結果はエージェントが観測できる。

 この実験でのエージェントの観察方法

  Passive-Int:介入操作によって大きく値が変わったノードへ因果矢印があると解釈する

  int. Map:介入操作によって変化した状態を満たし最尤となる構造を推定する

 図(c)にある様に介入操作で矢印関係を正確に掴んでQuiz Phaseでノードを選択しているが少ない値を選んでいる。図(a)では介入による解釈が有意である事を示している。

f:id:mabonki0725:20190318211708p:plain

  (2-3) Counterfactual実験

  介入以外に反証的な操作として有り得ない値を設定している。
  この実験でのエージェントの観察方法

  Passive-CF:一連の介入と観察の最後にさらに+5を加点した介入をして観察して因果方向を解釈している

        Optimal-CF:反証的な介入での観察から最尤法によるDAGの解釈

 図(c)の左では介入だけでは同じ値で最大値を選択できていないが、反証的操作によって差ができて最大値ノードを選択できている。

f:id:mabonki0725:20190318212801p:plain

 

  (3) 感想

 一般に因果の方向を正しく推定するには介入や反証的介入が必要とされているが、項目数が多いと介入実験が複雑になってくる。この論文はこれをロボットに置き換えて自動化を試みたものといえる。しかしロボットを使った問題として介入毎に全ノードの観察データを取得してロボットに示さなければならず、またノードが増えた場合にはより膨大な処理時間が必要となると思われる。しかし正確な因果をロボットで試みるのは新しい考え方ではある。

  

[1][1606.04474] Learning to learn by gradient descent by gradient descent

[2][1611.05763] Learning to reinforcement learn