mamba.jlとMCMCを巡る旅

2017年Julia Advent Calendar 9日目の記事です。

Juliaとベイズ統計モデリング、mamba.jlに同時入門した先日。 先人の資料の助けもあってPkd.add("Mamba")に次いで、jupyterとijulia上でusing Mambaがスムーズに動いたとき、 それまでWindows上のPyStan(+Visual Studio)の構築などで散々イライラしていた私は、安堵感と共に、 「お、これは幸先良いで。締切まで1週間以上あるし、信頼性工学を切り口にLife Data Analysisの信頼性区間くらいは、 さっさと出して、CalenderではTest Designの話でも書くかな」なんて呑気なことを考えながら手始めに、 線形回帰でのTutorialに取り組んでいました。

そう、これが今回の始まりです…。そして、非常に恐縮ながら途中まで…という状況で、旅路の記録になっています。

旅の抜け道をご教示いただけるなら、それも助かります!!!

環境

まず、動作環境をご紹介します. Forkが使えないWindowsを使い続けているのはお約束です。

  • Windows 10 - 64bit
  • Julia 0.6.1
  • mamba 0.11.1

なぜWindowsにこだわるのか、なぜWindows subsystem for Linuxを使わないのか、そこは今日も触れないでおきましょう。

線形回帰のおさらい

線形回帰が収束しない。これだけである。

アヒル本の第4章のデータRStanBook/data-salary.txt では、$ y \sim Normal(a+bx, sigma) $ のモデルにおいて、aとbの事後平均は、それぞれ、-118, 22ほどです。 これは大体最尤法の結果と同じ。そして、sigmaは85程度。

これに対して、前回の記事では、a,b,sigma は、それぞれ、-113, 22, 33で、a,bは許せるもののsigmaはプロットを見ても収束していませんでした。

元のIteration数が2000, Burnin(warm-up)が1000と小さめでしたので(とはいうものの、この例題でアヒル本は2000程度で上記の結果)、Samplerの性能の違いなどへの想像をしながらIterationを8000, Burninを3000へ伸ばしました。

Iterations = 3001:8000
Thinning interval = 1
Chains = 1,2,3,4
Samples per chain = 5000

Gelman, Rubin, and Brooks Diagnostic:
            PSRF 97.5%
        s2 1.000 1.000
        b 1.001 1.003
        a 1.002 1.003
Multivariate 1.001   NaN

収束は、していそうですね…

プロットはそこそこ雰囲気よさそうです

しかし… describe(sim).

Empirical Posterior Estimates:
    Mean         SD         Naive SE       MCSE         ESS   
s2 6986.028670 2592.8494023 18.3342139498 29.570409665 5000.00000
b   21.064966    1.2315484  0.0087083624  0.038741347 1010.54010
a  -80.859994   53.7228503  0.3798779177  1.789205605  901.56614

あれ?a、小さくないですか?!a,b,sigma、それぞれ、-80, 21, 84ですよ?

真実を求めて

Rとの比較

Iterationや初期値、modelの定義をとっかえひっかえしたりするも、事態が好転しないうちに、ひょっとしてmamba.jlのIssueに…と思いはじめました。

調べます。

ありました。linear regression example not giving reasonable results · Issue #120 · brian-j-smith/Mamba.jl

しかも、今年の7月から、まだOpenです。

After some digging it looks like the NUTS epsilon value is not getting set correctly in some cases. I will investigate further.

初期値でなんとかなる?全然納得できません。だいたい、単回帰程度のモデルで初期値に過敏に反応するようでは実務上お話になりません。NUTSの実装で $ \epsilon $ が悪い?よくわからないっす。

ここで、とっかえひっかえした経験から代表的な結果を挙げます。「おかしい」と感じる数字に*を打ってあります。mamba大丈夫ですか…?

条件 a(切片) b(傾き) sigma(残差偏差)
R glm(Ground Truth扱い) -119.7 21.9 79.1
R Stan(細かい設定なし⇔[初期値不明; Sampler NUTS; Iteration 2000]) -119 22 85
mamba.jl - 初期値~ Normal(Ground Truth, 1.0); Sampler Scheme1; Iteration 2000 -80* 20 83
mamba.jl - 初期値 a,b~Normal(0, 1), sigma~Uniform(80,85); Sampler Scheme2; Iteration 8000 -81* 21 84
mamba.jl - はじめてのmamba -113 22 33*

Ground Truth扱いの値を出すRのコード

df<-read.csv("./data-salary.txt")
data <- list(N=nrow(df),X=df$X, Y=df$Y)
ans<-lm(data$Y~data$X)
ans
summary(ans)
sigma(ans)

Rも一応初期値いじってみました(やりかたあってます?)

fit <- stan(file="./stan45.stan",
            data=data,
            seed=1234,
            verbose=TRUE,
            init=function(){ list(a=rnorm(1,0,1), b=rnorm(1,0,1), sigma=rgamma(1,1))},
            chains=4, iter=2000, warmup=1000, thin=1
)
data {
    int N;
    real X[N];
    real Y[N];
}

parameters {
    real a;
    real b;
    real<lower=0> sigma;
}

model {
    for (n in 1:N) {
        Y[n] ~ normal(a + b*X[n], sigma);
    }
}

mamba.jlでのModelとSamplerの定義(s2ではなくてsigmaにしました)

model = Model(
    y = Stochastic(1,
        (a, b, x, sigma) -> MvNormal(a + b*x, sigma),
        false
    ),
    b = Stochastic(() -> Normal(0, 100)),
    a = Stochastic(() -> Normal(0, 100)),
    sigma = Stochastic(() -> InverseGamma(0.01, 0.01))
)
scheme1 = [NUTS([:a,:b]),Slice(:sigma,10)]
scheme2 = [NUTS([:a,:b,:sigma])]

Tutorialの確認

そして、疑いの目でMambaのTutorial — Mamba.jl 0.10.1 documentationも見てみます。

条件 a(切片) b(傾き) sigma(残差偏差)
R glm(Ground Truth扱い) 0.6 0.8 0.73
mamba.jl Tutorial(Iteration 5000の後) 0.6 0.8 1.08(⇔s2=1.18)

s2はカイ二乗分布に従うので自由度5くらいの例だと言いにくいですが、こちらの残差分散も怪しそうです。私が使い方を理解していない…ということ以上のことが起きているのが、いよいよ濃厚に感じられてきました。 Mamba.jlのMCMCに何か起きていそうです。ライブラリの中に潜るしかなさそうな様子です。

次ステップの逡巡

ここで、私にはいくつかの選択肢がありました。カレンダー締め切りが迫る中、妻の協力を得つつ、検証時間を確保した私は考えます。

ベイズ統計初学者、MCMCに至っては入門すらしていない私は、ろくな初期化もせず、2000やそこいらのIterationでGLMに近い答えを出してしまうRStanがすごすぎるのか、mamba.jlの熟成が足りておらず、何らかの間違いを含んでいるのか見極めないと何も言えません。 RStanには、SamplerにNUTSを採用していることが書かれており、 バージョンもいくつかありそうなことがわかります。 NUTSのパラメータをRStan, Mambaで比べてみると、若干違うことも気になります。

algorithm One of sampling algorithms that are implemented in Stan. Current options are "NUTS" (No-U-Turn sampler, Hoffman and Gelman 2011, Betancourt 2017), "HMC" (static HMC), or "Fixed_param". The default and preferred algorithm is "NUTS".

サーベイを軽くしながら、思い浮かんだことリスト。

  • とりあえず結果を…型
    • 事前分布をもっと広げて、うすーくしたり、事後分布とほぼ1:1にしてしまったりする
    • mamba.jlの他のSamplerをチューンして、とりあえずGLMに近い結果が出るまで、ガチャガチャを回す
    • mamba.jlで$ sigma = \frac{10000}{1+exp(-dummySigma)} $ を用意して下限値を設定する. またはそれに準じたmodel定義を.
  • 動作状況確認型
    • StanのSampler実装を読み込んで、mamba.jlのNUTS.jlと比べる.
    • いや、そもそもSamplerじゃないじゃないかもしれないじゃないか。ひょっとして事後確率分布の計算とかで違ってる?
    • NUTS(No-U-Turn Sampler)を自分で組んでみて比較する
      • mamba.jlのユーザー定義Samplerに自分で実装
      • MCMCも含めてjuliaで書いてみる?
  • 基礎テスト実行型
    • 単回帰よりさらに単純な条件での評価

MCMCへダイブ

次第に苦しくなってきましたが、結局、NUTSを自分で組んで確認する方向を取ってみることにしました。CourseraでNg先生の機械学習で線形回帰の機械学習やNNを組んだように、 単回帰に特化すればMCMCとて意外と書けるかもしれません。

どこから始めるか

NUTS(No-U-Turn Sampler)は、HMC(Hamiltonian Monte Calro)を改良したもので、HMCは、Metropolis-Hasting Algorithmの提案関数を効率的にしたものということです。 幸いなことに、MH法はWikipediaもあるくらいですし、 ここから始めましょう。

ランダムウォークMH法

MH法の提案分布のうち、正規分布のような対称な分布を使ものを、Random-walk MH algorithm(ランダムウォークMH法)と呼ぶとのこと。日本語のWikipediaではメトロポリス・アルゴリズムと呼んでいるようです。 実装はMITの資料を見つけたのでこちらにならいます。

まず、概要は非常にスッキリ。こちらのランダムウォークの部分、そして、alphaの計算がNUTSでは複雑になっていくという心構えで見ていけば、よさそうな感触がつかめます。コメントを付けた通り、ランダムウォークMH方では、 提案分布が対称であるため、alphaの計算がシンプルにできます。

#Random-walk MH法
# モデルはy ~ N(a+b*x, sigma)
# a, b, sigmaの初期値; Chain 1つ分
Init = Real[-200, 1, 1] 
Iteration = 90000
BurnIn    = 10000
chain = Array[Init]

for i in 1:Iteration
    # 前回のa,b,sigmaの推定値からランダムウォークで次の推定値候補を引き出す
    # chain[i][1] : a
    # chain[i][2] : b
    # chain[i][3] : sigma
    # 戻り値は配列値で[aの候補, bの候補, sigmaの候補]
    proposal_draw   = proposal(chain[i][1], chain[i][2], chain[i][3])
    
    # 提案分布が対象であることから、MITの資料 数式(1)によってalphaの計算が非常に簡単になる
    # posteriorは目標分布の確率密度関数(の対数値)
    tmp1 = posterior(proposal_draw[1],proposal_draw[2],proposal_draw[3])
    tmp2 = posterior(chain[i][1], chain[i][2], chain[i][3])
    alpha = min(1,exp(tmp1-tmp2))
    
    # 採択するかどうか決める!
    u = rand(Uniform(0,1))
    if u < alpha
        push!(chain,proposal_draw)
    else
        push!(chain,chain[i])
    end
end

# BurnInは除外
chain = chain[BurnIn+1:end]
# 結果の取り出し
a = [chain[i][1] for i in 1:length(chain)]
b = [chain[i][2] for i in 1:length(chain)]
sigma = [chain[i][3] for i in 1:length(chain)]

println("MCMC complete")

初期値[a,b,sigma]=[-100, 1, 1]; Iteration = 200000; BurnIn = 90000など、いくつか特徴を見てみましたが、切片と残差偏差が、トレードオフのような関係にあってなかなかGround Truthに近づきません。 たとえば先述の条件では(a, b, sigma)=(-70, 20.9, 85.2)でした。 mamba.jlの方に近い!!

こちらのコードはGithubにてご紹介しています → Mamba.jl_Practice/Random-walk MH.ipynb at 4ce5b1fd3069cbf0ac6bc313489e62dad2ac0625 · Chachay/Mamba.jl_Practice

お詫びと感想

今回はMCMCおよびベイズ統計の勉強が足りず、みなさまにお見苦しい記事をお見せしてしまいました。

ただ、Random-walk MH法は是非動かしていただきたいのですが、Pythonでは考えられないほどのIteration速度で、ただ、ただ、驚くばかりです。これぞJuliaを使う素晴らしさでしょう。 しかしながら、PythonやRと違って、開発環境やパッケージの洗練、コミュニティはの厚みには、 もう少し時間がかかりそうなこともも感じました。 エラーの発生時も箇所の特定に少しコツがいりますし、mamba.jl 1つとっても使用例がやはり少ないです。 これから、みなさまとご一緒できればと思いますので、暖かく見守っていただけると幸いです。

続きについて

HMCとNUTSは続きでやっていきます。今回は時間切れで、オチなし…ということでした。完全に見積もりミスです。すみません。 次回は旅行記録ではなく、報告形式で!

続き

一旦のケリをつけたのでリンク。

No comments:

Post a Comment