メインコンテンツへスキップ

ベイズ推論による多次元ガウス分布の学習

·2698 文字·6 分
目次

はじめに
#

「ベイズ推論による機械学習入門」を読んだので、ベイズ推論(ベイズ推定)への理解を深めるため、多次元ガウス分布の学習をPythonで実装した。 参考にしたのは、講談社 機械学習スタートアップシリーズの「ベイズ推論による機械学習入門」(須山敦志 著)。3.4節「多次元ガウス分布の学習と予測」から、平均と精度(分散共分散行列)が共に未知の場合における学習について実装した。 また、学習したパラメータを用いて、未観測データを予測するための分布(予測分布)も構築した。

なお、以下のブログに離散確率分布(ベルヌーイ分布・カテゴリ分布・ポアソン分布)と1次元ガウス分布の学習の実装例があったため、併せて参考にさせて頂いた。

「ベイズ推論による機械学習入門」を読んだので実験してみた (その1)

環境
#

ソフトウェア バージョン
python 3.6.5
numpy 1.14.3
scipy 1.1.0
matplotlib 2.2.2

以降では、各ライブラリを以下のようにインポートしていることを前提とする。

import math
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt

ベイズ学習について
#

ベイズ学習は、観測データと未知パラメータに対する同時確率分布を構築し、観測データが得られたときの未知パラメータの事後分布を求める手法である。 ここでは、多次元ガウス分布の平均と精度が未知パラメータとなる。

多次元ガウス分布のベイズ推論
#

\(D\)次元の多次元ガウス分布は、次式で表される。

$$ \mathcal{N}(x|\mu, \Sigma) = \frac{1}{\sqrt{(2\pi)^D|\Sigma|}} \exp \biggl( -\frac{1}{2} (x-\mu)^\top \Sigma^{-1} (x-\mu) \biggl) $$

ここで、\( \mu \in \mathbb{R}^D\) は平均、\( \Sigma \in \mathbb{R}^{D \times D}\)は分散共分散行列である。 ただし、\(\Sigma\)は正定値行列(固有値が全て非負)でなければならない。 後々の数式を簡単にするため、精度行列\(\Lambda = \Sigma^{-1}\)を導入する。 \(\mu, \Lambda\)が推定したいパラメータになる。

\(\mu, \Lambda\)の確率分布を表現する共役事前分布は、ガウス・ウィシャート分布となる。

$$ \begin{array}{rl} p(\mu, \Lambda) &=& NW(\mu, \Lambda | m, \beta, \nu, W) \\\ &=& \mathcal{N}(\mu | m, (\beta \Lambda)^{-1}) \mathcal{W}(\Lambda | \nu, W) \end{array} $$

ここで、\(m, \beta, \nu, W\)はガウス・ウィシャート分布のパラメータである。初期値は以下の条件を満たすように適当に与える。

  • \(m \in \mathcal{R}^{D}\): 実数ベクトル
  • \(\beta \in \mathcal{R}\): 実数
  • \(\nu \in \mathcal{R}\): \(\nu > D-1\)を満たす実数
  • \(W \in \mathcal{R}^{D \times D}\): 正定値行列(固有値が全て非負)

事後分布を計算すると、ガウス・ウィシャート分布のパラメータはそれぞれ以下のように与えられる(詳細は本を参照)。

$$ \hat{\beta} = N + \beta $$

$$ \hat{m} = \frac{1}{\hat{\beta}} \left( \sum_{n=1}^N x_n + \beta m \right) $$

$$ \hat{W}^{-1} = \sum_{n=1}^N x_n x_n^{\top} + \beta mm^{\top} - \hat{\beta} \hat{m} \hat{m}^{\top} + W^{-1} $$

$$ \hat{\nu} = N + \nu $$

学習したガウス・ウィシャート分布のパラメータを使って、未観測のデータ\(x\)を予測する。予測分布は\(x\in \mathbb{R}^D\)上の多次元版のスチューデントのt分布となる。

$$ \mathrm{St} (x|\mu_s, \Lambda_s, \nu_s) = \frac{\Gamma( \frac{\nu_s+D}{2}) }{\Gamma( \frac{\nu_s}{2})} \frac{|\Lambda_s|^{\frac{1}{2}}}{(\pi \nu_s)^{\frac{D}{2}}} \biggl( 1+\frac{1}{\nu_s} (x-\mu_s)^{\top} \Lambda_s (x-\mu_s) \biggl)^{-\frac{\nu_s +D}{2} } $$

ここで、スチューデントのt分布のパラメータは、ガウス・ウィシャート分布のパラメータを使って次式で与えられる。

$$ \mu_s = m $$

$$ \Lambda_s = \frac{(1-D+\nu)\beta}{1+\beta}W $$

$$ \nu_s = 1-D+\nu $$

また、\(\Gamma(\bullet)\)はガンマ関数と呼ばれる関数である。

学習が進むにつれて、スチューデントのt分布の形状は、元の多次元ガウス分布の形状に近づいていく。

実装
#

ガウス・ウィシャート分布のパラメータ推定
#

観測データXから、ガウス・ウィシャート分布のパラメータの推定値\(\hat{m}, \hat{\beta}, \hat{\nu}, \hat{W}\)を推定する関数を以下のように実装する。 ただし、計算効率は重視せず、数式通りに実装することを優先している。

def multivariate_normal_fit(X):
    N = X.shape[0] # Number of samples
    D = X.shape[1] # Dimension of sample
    
    beta  = 1
    m     = np.zeros(D)
    W_inv = np.linalg.inv(np.diag(np.ones(D)))
    nu    = D
    
    beta_hat = N + beta
    m_hat    = (X.sum(axis=0)+beta*m)/beta_hat
    
    X_sum = np.zeros([D, D])
    for i in range(N):
        X_sum += np.dot(X[i].reshape(-1,1), X[i].reshape(1,-1))
    
    W_hat_inv = X_sum + beta*np.dot(m.reshape(-1,1), m.reshape(1,-1)) \
            - beta_hat*np.dot(m_hat.reshape(-1,1), m_hat.reshape(1,-1)) + W_inv
    nu_hat = N + nu
    
    return m_hat, beta_hat, nu_hat, W_hat_inv

多次元版のスチューデントのt分布
#

学習後の確率分布を確認するため、多次元版のスチューデントのt分布をクラスとして実装する。 確率密度関数 (Probability Density Function, PDF) を求めるため、pdfメソッドを用意した。pdfメソッドに配列を引数として与えると、その配列に対応する確率を返す。

class multivariate_student_t():
    def __init__(self, mu, lam, nu):
        # mu: D size array, lam: DxD matrix, nu: scalar
        self.D   = mu.shape[0]
        self.mu  = mu
        self.lam = lam
        self.nu  = nu
        
    def pdf(self, x):
        temp1 = np.exp( math.lgamma((self.nu+self.D)/2) - math.lgamma(self.nu/2) )
        temp2 = np.sqrt(np.linalg.det(self.lam)) / (np.pi*self.nu)**(self.D/2) 
        
        if x.shape[0]==1:
            temp3 = 1 + np.dot(np.dot((x-self.mu).T, self.lam),  x-self.mu)/self.nu
        else:
            temp3 = []
            for a in x:
                temp3 += [1 + np.dot(np.dot((a-self.mu).T, self.lam),  a-self.mu)/self.nu]
        
        temp4 = -(self.nu+self.D)/2
        return temp1*temp2*(np.array(temp3)**temp4)

ここで、ガンマ関数の自然対数を返すmath.lgammaで実装した。 ガンマ関数math.gammaは大きな値を取り得ることがあり、以下のようにオーバーフローが生じる場合があるためである。

>>> math.gamma(200)
Traceback (most recent call last):

  File "<ipython-input-31-4fa9aaaad750>", line 1, in <module>
    math.gamma(200)

OverflowError: math range error

パラメータの学習
#

学習の結果を確認する。図示できるように、データの次元は\(D=2\)とする。 まず、多次元ガウス分布に従うサンプルデータを生成する。 ここで、データの平均は\((x_1, x_2)=(0, 1)\)であり、正の相関を持つ。

np.random.seed(0)

mean = np.array([0, 1])
cov  = np.array([[2, 1],
                 [1, 2]])
Ns   = 100                # Number of samples

X = np.random.multivariate_normal(mean, cov, Ns) # Sample data

サンプルデータを散布図にプロットする。

fig, ax = plt.subplots(figsize=(8, 4))
ax.scatter(X[:,0], X[:,1])
ax.axis('square')
ax.set_xlim(-5,5)
ax.set_ylim(-5,5)
ax.grid()
ax.set_xlabel("x1")
ax.set_ylabel("x2")
fig.tight_layout()
plt.show()

scatter_2d_normal_dist
scatter_2d_normal_dist

次に、関数multivariate_normal_fitから、ガウス・ウィシャート分布のパラメータを求める。

m_hat, beta_hat, nu_hat, W_hat_inv = multivariate_normal_fit(X)

得られたパラメータをスチューデントのt分布のパラメータに変換し、 multivariate_student_tオブジェクトを作成する。

D       = m_hat.shape[0]
mu_hat  = m_hat
lam_hat = (1-D+nu_hat)*beta_hat*np.linalg.inv(W_hat_inv) / (1+beta_hat) 
nu_hat  = 1 - D + nu_hat

mt = multivariate_student_t(mu_hat, lam_hat, nu_hat)

最後に、元のガウス分布の形状と、推定したスチューデントのt分布の形状を比較する。 両確率分布の確率を、x1, x2とも-5~5の範囲で求める。

X1, X2 = np.meshgrid(np.arange(-5, 5, 0.1), np.arange(-5, 5, 0.1))
Y = np.vstack([X1.ravel(), X2.ravel()]).T

mn_pdf = scipy.stats.multivariate_normal.pdf(Y, mean=mean, cov=cov)
mn_pdf = mn_pdf.reshape(X1.shape[0], -1)

mt_pdf = mt.pdf(Y)
mt_pdf = mt_pdf.reshape(X1.shape[0], -1)

これらをヒートマップに表示する。色が濃いほど確率が高いことを表す。 このように、推定した確率密度関数と、元の確率密度関数はほぼ一致している。

fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
ax0 = ax[0].pcolor(X1, X2, mn_pdf, cmap="Blues", vmin=0, vmax=0.1)
ax1 = ax[1].pcolor(X1, X2, mt_pdf, cmap="Blues", vmin=0, vmax=0.1)
for i in range(2):
    ax[i].axis('equal')
    ax[i].grid()
    ax[i].set_xlabel("x1")
    ax[i].set_ylabel("x2")
ax[0].set_title("Original PDF")
ax[1].set_title("Inferred PDF")
plt.colorbar(ax=ax[0], mappable=ax0)
plt.colorbar(ax=ax[1], mappable=ax1)
fig.tight_layout()
plt.show()

pdf_heatmap
pdf_heatmap

以上をまとめたコードは以下の通り。

import math
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt

class multivariate_student_t():
    def __init__(self, mu, lam, nu):
        # mu: D size array, lam: DxD matrix, nu: scalar
        self.D   = mu.shape[0]
        self.mu  = mu
        self.lam = lam
        self.nu  = nu
        
    def pdf(self, x):
        temp1 = np.exp( math.lgamma((self.nu+self.D)/2) - math.lgamma(self.nu/2) )
        temp2 = np.sqrt(np.linalg.det(self.lam)) / (np.pi*self.nu)**(self.D/2) 
        
        if x.shape[0]==1:
            temp3 = 1 + np.dot(np.dot((x-self.mu).T, self.lam),  x-self.mu)/self.nu
        else:
            temp3 = []
            for a in x:
                temp3 += [1 + np.dot(np.dot((a-self.mu).T, self.lam),  a-self.mu)/self.nu]
        
        temp4 = -(self.nu+self.D)/2
        return temp1*temp2*(np.array(temp3)**temp4)

def multivariate_normal_fit(X):
    N = X.shape[0] # Number of samples
    D = X.shape[1] # Dimension of sample
    
    beta  = 1
    m     = np.zeros(D)
    W_inv = np.linalg.inv(np.diag(np.ones(D)))
    nu    = D
    
    beta_hat = N + beta
    m_hat    = (X.sum(axis=0)+beta*m)/beta_hat
    
    X_sum = np.zeros([D, D])
    for i in range(N):
        X_sum += np.dot(X[i].reshape(-1,1), X[i].reshape(1,-1))
    
    W_hat_inv = X_sum + beta*np.dot(m.reshape(-1,1), m.reshape(1,-1)) \
            - beta_hat*np.dot(m_hat.reshape(-1,1), m_hat.reshape(1,-1)) + W_inv
    nu_hat = N + nu
    
    return m_hat, beta_hat, nu_hat, W_hat_inv

if __name__=="__main__":
    np.random.seed(0)
    
    mean = np.array([0, 1])
    cov  = np.array([[2, 1],
                     [1, 2]])
    Ns   = 100                # Number of samples
    
    X = np.random.multivariate_normal(mean, cov, Ns) # Sample data
    
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.scatter(X[:,0], X[:,1])
    ax.axis('square')
    ax.set_xlim(-5,5)
    ax.set_ylim(-5,5)
    ax.grid()
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
    fig.tight_layout()
    plt.show()
    
    m_hat, beta_hat, nu_hat, W_hat_inv = multivariate_normal_fit(X)
    
    D       = m_hat.shape[0]
    mu_hat  = m_hat
    lam_hat = (1-D+nu_hat)*beta_hat*np.linalg.inv(W_hat_inv) / (1+beta_hat) 
    nu_hat  = 1 - D + nu_hat
    
    mt = multivariate_student_t(mu_hat, lam_hat, nu_hat)
    
    X1, X2 = np.meshgrid(np.arange(-5, 5, 0.1), np.arange(-5, 5, 0.1))
    Y = np.vstack([X1.ravel(), X2.ravel()]).T
    
    mn_pdf = scipy.stats.multivariate_normal.pdf(Y, mean=mean, cov=cov)
    mn_pdf = mn_pdf.reshape(X1.shape[0], -1)
    
    mt_pdf = mt.pdf(Y)
    mt_pdf = mt_pdf.reshape(X1.shape[0], -1)
    
    fig, ax = plt.subplots(ncols=2, figsize=(10, 4))
    ax0 = ax[0].pcolor(X1, X2, mn_pdf, cmap="Blues", vmin=0, vmax=0.1)
    ax1 = ax[1].pcolor(X1, X2, mt_pdf, cmap="Blues", vmin=0, vmax=0.1)
    for i in range(2):
        ax[i].axis('equal')
        ax[i].grid()
        ax[i].set_xlabel("x1")
        ax[i].set_ylabel("x2")
    ax[0].set_title("Original PDF")
    ax[1].set_title("Inferred PDF")
    plt.colorbar(ax=ax[0], mappable=ax0)
    plt.colorbar(ax=ax[1], mappable=ax1)
    fig.tight_layout()
    plt.show()

また、学習データのサンプル数Nsを5, 10, 100と変えて、推定精度に与える影響を調べる。下図のように、Nsが増えるほど、元の確率密度分布(左上)に近づいている。

pdf_heatmap_2
pdf_heatmap_2

Helve
著者
Helve
関西在住、電機メーカ勤務のエンジニア。X(旧Twitter)で新着記事を配信中です

関連記事

ChainerのIteratorクラスによる学習用ミニバッチ作成
·2129 文字·5 分
データセットから学習用ミニバッチを作成してくれるIteratorクラスの動作を確認する。
Chainer入門 最小限のニューラルネットワーク実装
·1600 文字·4 分
ディープラーニング用のライブラリChainerの使い方を理解するため、ChainerのChainクラスとOptimizerを使って最小限のニューラルネットワーク (NN) を実装する。
SciPyを使ったFIRフィルタによる波形整形
·1247 文字·3 分
SciPyを使って、FIR (Finite Impulse Response, 有限インパルス応答) フィルタによる離散信号の波形を整形する。ローパス、ハイパス、バンドパス、バンドエリミネイトの各フィルタの設計から、信号への適用まで行う。
NumPyのemptyで空(長さ0)の配列を作る
·1419 文字·3 分
NumPyのempty関数を用いて、空の(要素を持たない)任意の次元の配列を作成できる。本記事では、空の配列の作り方、使い方について簡単に考察する。
NumPyを使った高速フーリエ変換による周波数解析
·1276 文字·3 分
NumPyのfftパッケージを使って、FFT (Fast Fourier Transform, 高速フーリエ変換) による離散信号の周波数解析を行い、信号の振幅を求める。
Matplotlib スタイルによるグラフの見た目の変更
·990 文字·2 分
matplotlibにはスタイルと呼ばれる、グラフの見た目を変更できるプリセットが用意されている。スタイルを活用することで、グラフの見た目を簡単に変更できる。