web-dev-qa-db-ja.com

numpyの対数確率行列を乗算する数値的に安定した方法

対数確率を含む2つのNumPy行列(または他の2次元配列)の行列積を取る必要があります。明白な理由から、単純な方法np.log(np.dot(np.exp(a), np.exp(b)))は好ましくありません。

使用する

_from scipy.misc import logsumexp
res = np.zeros((a.shape[0], b.shape[1]))
for n in range(b.shape[1]):
    # broadcast b[:,n] over rows of a, sum columns
    res[:, n] = logsumexp(a + b[:, n].T, axis=1) 
_

動作しますが、実行速度はnp.log(np.dot(np.exp(a), np.exp(b)))の約100倍です。

使用する

_logsumexp((tile(a, (b.shape[1],1)) + repeat(b.T, a.shape[0], axis=0)).reshape(b.shape[1],a.shape[0],a.shape[1]), 2).T
_

または、タイルと形状変更の他の組み合わせも機能しますが、現実的なサイズの入力行列に非常に大量のメモリが必要なため、上記のループよりも実行速度がさらに遅くなります。

私は現在、これを計算するためにCでNumPy拡張機能を作成することを検討していますが、もちろんそれは避けたいと思います。これを行うための確立された方法はありますか、またはこの計算を実行するためのメモリをあまり消費しない方法を知っている人はいますか?

編集:このソリューションを提供してくれたlarsmansに感謝します(派生については以下を参照):

_def logdot(a, b):
    max_a, max_b = np.max(a), np.max(b)
    exp_a, exp_b = a - max_a, b - max_b
    np.exp(exp_a, out=exp_a)
    np.exp(exp_b, out=exp_b)
    c = np.dot(exp_a, exp_b)
    np.log(c, out=c)
    c += max_a + max_b
    return c
_

このメソッドを、iPythonの魔法の_logdot_old_関数を使用して上記のメソッド(_%timeit_)と簡単に比較すると、次のようになります。

_In  [1] a = np.log(np.random.Rand(1000,2000))

In  [2] b = np.log(np.random.Rand(2000,1500))

In  [3] x = logdot(a, b)

In  [4] y = logdot_old(a, b) # this takes a while

In  [5] np.any(np.abs(x-y) > 1e-14)
Out [5] False

In  [6] %timeit logdot_old(a, b)
1 loops, best of 3: 1min 18s per loop

In  [6] %timeit logdot(a, b)
1 loops, best of 3: 264 ms per loop
_

明らかに、larsmansの方法は私のものを全滅させます!

38
mart

logsumexpは、方程式の右辺を評価することで機能します

log(∑ exp[a]) = max(a) + log(∑ exp[a - max(a)])

つまり、expのオーバーフローを防ぐために、合計を開始する前に最大値を引き出します。ベクトル内積を行う前に、同じことが適用できます。

log(exp[a] ⋅ exp[b])
 = log(∑ exp[a] × exp[b])
 = log(∑ exp[a + b])
 = max(a + b) + log(∑ exp[a + b - max(a + b)])     { this is logsumexp(a + b) }

しかし、導出において別の順番を取ることによって、私たちは得ます

log(∑ exp[a] × exp[b])
 = max(a) + max(b) + log(∑ exp[a - max(a)] × exp[b - max(b)])
 = max(a) + max(b) + log(exp[a - max(a)] ⋅ exp[b - max(b)])

最終的な形式では、内部にベクトル内積があります。また、行列の乗算にも容易に拡張できるため、アルゴリズムを取得します。

def logdotexp(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

これにより、2つのAサイズの一時的なものと2つのBサイズの一時的なものが作成されますが、それぞれの1つは次の方法で削除できます。

exp_A = A - max_A
np.exp(exp_A, out=exp_A)

Bについても同様です。 (入力行列が関数によって変更される可能性がある場合は、すべての一時行列を削除できます。)

24
Fred Foo

A.shape==(n,r)B.shape==(r,m)を想定します。行列積C=A*Bの計算では、実際にはn*mの合計があります。ログスペースで作業しているときに安定した結果を得るには、これらの各合計にlogsumexpトリックが必要です。幸い、AとBの行と列の安定性を個別に制御するのが非常に簡単なnumpyブロードキャストを使用しています。

コードは次のとおりです。

def logdotexp(A, B):
    max_A = np.max(A,1,keepdims=True)
    max_B = np.max(B,0,keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

注:

この背後にある理由はFredFooの答えに似ていますが、彼は各行列に単一の最大値を使用しました。彼はすべてのn*mの合計を考慮していなかったため、コメントの1つで述べたように、最終的な行列の一部の要素はまだ不安定である可能性があります。

@ identity-m反例を使用して現在受け入れられている回答と比較:

def logdotexp_less_stable(A, B):
    max_A = np.max(A)
    max_B = np.max(B)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

print('old method:')
print(logdotexp_less_stable([[0,0],[0,0]], [[-1000,0], [-1000,0]]))
print('new method:')
print(logdotexp([[0,0],[0,0]], [[-1000,0], [-1000,0]]))

印刷する

old method:
[[      -inf 0.69314718]
 [      -inf 0.69314718]]
new method:
[[-9.99306853e+02  6.93147181e-01]
 [-9.99306853e+02  6.93147181e-01]]
3
Hassan

resbの列にアクセスしていますが、 参照の局所性 が不十分です。試してみる1つのことは、これらを column-major order に格納することです。

1
Neil G

Fred Fooが現在受け入れている回答、およびHassanの回答は、数値的に不安定です(Hassanの回答の方が優れています)。ハッサンの答えが失敗する入力の例は後で提供されます。私の実装は次のとおりです。

import numpy as np
from scipy.special import logsumexp

def logmatmulexp(log_A: np.ndarray, log_B: np.ndarray) -> np.ndarray:
    """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                             
    (log_A.exp() @ log_B.exp()).log() in a numerically stable way.                                                                                                                                                                           
    Has O(ϴRI) time complexity and space complexity."""
    ϴ, R = log_A.shape
    I = log_B.shape[1]
    assert log_B.shape == (R, I)
    log_A_expanded = np.broadcast_to(np.expand_dims(log_A, 2), (ϴ, R, I))
    log_B_expanded = np.broadcast_to(np.expand_dims(log_B, 0), (ϴ, R, I))
    log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)                                                                                                                                                              
    return logsumexp(log_pairwise_products, axis=1)

Hassanの答えとFredFooの答えと同じように、私の答えには時間計算量O(ϴRI)があります。彼らの答えは空間の複雑さO(ϴR + RI)(私は実際にはこれについてはわかりません)を持っていますが、私のものは残念ながら空間の複雑さO(ϴRI)を持っています-これはnumpyがϴ×R行列をR×I行列で乗算できるためですサイズϴ×R×Iの追加の配列を割り当てます。 O(ϴRI)空間の複雑さを持つことは、私の方法の内在的な特性ではありません-サイクルを使用して書き出すと、この空間の複雑さを回避できると思いますが、残念ながら、ストックnumpy関数を使用してそれを行うことはできないと思います。

コードが実際に実行される時間を確認しました。通常の行列の乗算よりも20倍遅くなります。

私の答えが数値的に安定していることを知る方法は次のとおりです。

  1. 明らかに、リターンライン以外のすべてのラインは数値的に安定しています。
  2. logsumexp関数は数値的に安定していることが知られています。
  3. そのため、私のlogmatmulexp関数は数値的に安定しています。

私の実装には別のNiceプロパティがあります。 numpyを使用する代わりに、pytorchで同じコードを記述したり、自動微分を使用して別のライブラリを使用したりすると、数値的に安定した後方パスが自動的に取得されます。後方パスが数値的に安定することを知る方法は次のとおりです。

  1. 私のコードのすべての関数はどこでも微分可能です(np.maxとは異なります)
  2. 明らかに、リターンラインを除くすべてのラインを逆伝播するのは数値的に安定しています。なぜなら、そこではまったく奇妙なことが起こっていないからです。
  3. 通常、pytorchの開発者は自分たちが何をしているのかを知っています。したがって、数値的に安定した方法でlogsumexpのバックワードパスを実装したことを信頼するだけで十分です。
  4. 実際には、logsumexpの勾配はsoftmax関数です(参照用にgoogle "softmaxはlogsumexpの勾配です"または https://arxiv.org/abs/1704.00805 命題1を参照)。ソフトマックスは数値的に安定した方法で計算できることが知られています。したがって、pytorch開発者はおそらくそこでsoftmaxを使用するだけです(私は実際にはチェックしていません)。

以下は、pytorchの同じコードです(バックプロパゲーションが必要な場合)。 pytorchのバックプロパゲーションがどのように機能するかにより、フォワードパス中に、バックワードパスのlog_pairwise_productsテンソルが保存されます。このテンソルは大きいので、おそらく保存したくないでしょう。後方パス中にもう一度再計算するだけです。そのような場合は、チェックポイントを使用することをお勧めします-それは本当に簡単です-以下の2番目の関数を参照してください。

import torch
from torch.utils.checkpoint import checkpoint

def logmatmulexp(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                             
    (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
    ϴ, R = log_A.shape
    I = log_B.shape[1]
    assert log_B.shape == (R, I)
    log_A_expanded = log_A.unsqueeze(2).expand((ϴ, R, I))
    log_B_expanded = log_B.unsqueeze(0).expand((ϴ, R, I))
    log_pairwise_products = log_A_expanded + log_B_expanded  # shape: (ϴ, R, I)                                                                                                                                                              
    return torch.logsumexp(log_pairwise_products, dim=1)


def logmatmulexp_lowmem(log_A: torch.Tensor, log_B: torch.Tensor) -> torch.Tensor:
    """Same as logmatmulexp, but doesn't save a (ϴ, R, I)-shaped tensor for backward pass.                                                                                                                                                   

    Given matrix log_A of shape ϴ×R and matrix log_B of shape R×I, calculates                                                                                                                                                                
    (log_A.exp() @ log_B.exp()).log() and its backward in a numerically stable way."""
    return checkpoint(logmatmulexp, log_A, log_B)

Hassanの実装が失敗する入力がありますが、私の実装は正しい出力を提供します。

def logmatmulexp_hassan(A, B):
    max_A = np.max(A,1,keepdims=True)
    max_B = np.max(B,0,keepdims=True)
    C = np.dot(np.exp(A - max_A), np.exp(B - max_B))
    np.log(C, out=C)
    C += max_A + max_B
    return C

log_A = np.array([[-500., 900.]], dtype=np.float64)
log_B = np.array([[900.], [-500.]], dtype=np.float64)
print(logmatmulexp_hassan(log_A, log_B)) # prints -inf, while the correct answer is approximately 400.69.
0
CrabMan