web-dev-qa-db-ja.com

ストライドを使用した効率的な移動平均フィルター

私は最近 strides について この投稿への回答 で学び、それらを使用して、移動平均フィルターを私が提案したものよりも効率的に計算する方法を不思議に思っていました この投稿 (畳み込みフィルタを使用)。

これが今のところです。元の配列のビューを取得し、必要な量だけそれをロールし、カーネル値を合計して平均を計算します。エッジが正しく処理されていないことは承知していますが、後で対処できます...より速く、より良い方法はありますか?目的は、サイズが最大5000x5000 x 16レイヤーの大きな浮動小数点配列をフィルターすることです。このタスクはscipy.ndimage.filters.convolveはかなり遅いです。

8隣接の接続を探していることに注意してください。つまり、3x3フィルターは、9ピクセルの平均(焦点ピクセルの周囲に8)を取り、その値を新しい画像のピクセルに割り当てます。

import numpy, scipy

filtsize = 3
a = numpy.arange(100).reshape((10,10))
b = numpy.lib.stride_tricks.as_strided(a, shape=(a.size,filtsize), strides=(a.itemsize, a.itemsize))
for i in range(0, filtsize-1):
    if i > 0:
        b += numpy.roll(b, -(pow(filtsize,2)+1)*i, 0)
filtered = (numpy.sum(b, 1) / pow(filtsize,2)).reshape((a.shape[0],a.shape[1]))
scipy.misc.imsave("average.jpg", filtered)

これがどのように機能するかについての説明を編集してください:

現在のコード:

  1. stride_tricksを使用して、[[0,1,2]、[1,2,3]、[2,3,4] ...]のような配列を生成します。これは、フィルターカーネルの最上行に対応しています。
  2. 垂直軸に沿って回転し、カーネル[[10,11,12]、[11,12,13]、[13,14,15] ...]の中央の行を取得し、取得した配列に追加します1)
  3. カーネルの最下行[[20,21,22]、[21,22,23]、[22,23,24] ...]を取得するまで繰り返します。この時点で、各行の合計を取り、それをフィルターの要素数で除算して、各ピクセルの平均を取得します(1行と1列でシフトし、エッジの周りにいくつかの奇妙さを持ちますが、後で注意してください)。

私が望んでいたのは、stride_tricksをよりよく使用して、配列全体の9つの値またはカーネル要素の合計を直接取得することです。

29
Benjamin

価値のあるものについては、次のように「ファンシー」なストライドトリックを使用して実行します。昨日投稿する予定でしたが、実際の仕事に気を取られてしまいました! :)

@Paulと@eatはどちらも、これを行う他のさまざまな方法を使用したNice実装を持っています。先ほどの質問から続けるために、私はN次元の同等のものを投稿すると思いました。

ただし、1Dを超える配列の_scipy.ndimage_関数を大幅に超えることはできません。 (_scipy.ndimage.uniform_filter_は_scipy.ndimage.convolve_に勝るはずです)

さらに、多次元の移動ウィンドウを取得しようとしている場合、誤って配列のコピーを作成すると、メモリ使用量が爆発する危険があります。最初の「ローリング」配列は元の配列のメモリへのビューにすぎませんが、配列をコピーする中間ステップは、元の配列よりも桁の大きさ大きいコピーを作成します(つまり、 100x100の元の配列で作業していると言います...そのビュー((3,3)のフィルターサイズの場合)は98x98x3x3になりますが、元のメモリと同じメモリを使用します。 full 98x98x3x3配列がするメモリの容量!!)

基本的に、クレイジーストライドトリックの使用は、ndarrayの単一軸で移動ウィンドウ操作をベクトル化する場合に最適です。オーバーヘッドが非常に少ない移動標準偏差などを簡単に計算できます。これを複数の軸に沿って開始したい場合は可能ですが、通常はより専門的な関数を使用する方がよいでしょう。 (_scipy.ndimage_など)

とにかく、以下がその方法です。

_import numpy as np

def rolling_window_lastaxis(a, window):
    """Directly taken from Erik Rigtorp's post to numpy-discussion.
    <http://www.mail-archive.com/[email protected]/msg29450.html>"""
    if window < 1:
       raise ValueError, "`window` must be at least 1."
    if window > a.shape[-1]:
       raise ValueError, "`window` is too long."
    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

def rolling_window(a, window):
    if not hasattr(window, '__iter__'):
        return rolling_window_lastaxis(a, window)
    for i, win in enumerate(window):
        if win > 1:
            a = a.swapaxes(i, -1)
            a = rolling_window_lastaxis(a, win)
            a = a.swapaxes(-2, i)
    return a

filtsize = (3, 3)
a = np.zeros((10,10), dtype=np.float)
a[5:7,5] = 1

b = rolling_window(a, filtsize)
blurred = b.mean(axis=-1).mean(axis=-1)
_

したがって、b = rolling_window(a, filtsize)を実行すると、8x8x3x3配列が得られます。これは、実際には元の10x10配列と同じメモリを表示したものです。異なる軸に沿って簡単に異なるフィルターサイズを使用したり、N次元配列の選択した軸に沿ってのみ操作したりすることができます(つまり、4次元配列のfiltsize = (0,3,0,3)は6次元ビューを提供します)。

次に、任意の関数を最後の軸に繰り返し適用して、移動するウィンドウで効率的に計算できます。

ただし、mean(またはstdなど)の各ステップで元の配列よりもはるかに大きい一時配列を格納しているため、これはメモリ効率が悪いです。また、それもひどく速くなることはありません。

ndimageに相当するものは次のとおりです。

_blurred = scipy.ndimage.uniform_filter(a, filtsize, output=a)
_

これは、さまざまな境界条件を処理し、配列の一時的なコピーを必要とせずにインプレースで「ぼかし」を実行し、very高速になります。ストライドトリックは、関数をone軸に沿って移動するウィンドウに適用するための良い方法ですが、通常、複数の軸に沿ってそれを行うための良い方法ではありません。

とにかく私の$ 0.02だけ...

28
Joe Kington

私はPythonのコードを書くために十分に詳しくありませんが、畳み込みを高速化するための2つの最良の方法は、フィルターを分離するか、フーリエ変換を使用することです。

分離フィルター:畳み込みはO(M * N)です。ここで、MとNはそれぞれ画像とフィルターのピクセル数です。 3行3列のカーネルでの平均フィルタリングは、最初に3行1列のカーネルで、次に1行3列のカーネルでフィルタリングすることと同じであるため、_(3+3)/(3*3)_ =〜30%速度を連続して向上させることができます2つの1-dカーネルとの畳み込み(これは明らかにカーネルが大きくなるにつれて良くなります)。もちろん、ここでもストライドトリックを使用できる場合があります。

フーリエ変換conv(A,B)ifft(fft(A)*fft(B))と同等です。つまり、直接空間でのたたみ込みは、フーリエ空間での乗算になります。ここで、Aは、画像とBがフィルターです。フーリエ変換の(要素ごとの)乗算では、AとBが同じサイズである必要があるため、Bはsize(A)の配列であり、カーネルは画像の中心にあり、他の場所ではゼロです。配列の中央に3行3列のカーネルを配置するには、Aを奇数サイズに埋め込む必要がある場合があります。フーリエ変換の実装によっては、これは畳み込みよりもはるかに高速になる場合があります(同じフィルターを複数回適用する場合は、fft(B)を事前計算して、計算時間をさらに30%節約できます)。 。

8
Jonas

修正する必要があると私が確信していることの1つは、ビュー配列bです。

割り当てられていないメモリからのアイテムがいくつかあるため、クラッシュします。

アルゴリズムの新しい説明を考えると、修正が必要な最初のことは、aの割り当ての範囲外でストライドしているということです。

bshape = (a.size-filtsize+1, filtsize)
bstrides = (a.itemsize, a.itemsize)
b = numpy.lib.stride_tricks.as_strided(a, shape=bshape, strides=bstrides)

更新

私はまだ方法を完全には把握しておらず、問題を解決するためのより簡単な方法があるように思われるので、ここにこれを配置します:

A = numpy.arange(100).reshape((10,10))

shifts = [(-1,-1),(-1,0),(-1,1),(0,-1),(0,1),(1,-1),(1,0),(1,1)]
B = A[1:-1, 1:-1].copy()
for dx,dy in shifts:
    xstop = -1+dx or None
    ystop = -1+dy or None
    B += A[1+dx:xstop, 1+dy:ystop]
B /= 9

...これは単純なアプローチのようです。唯一の無関係な操作は、Bの割り当てと設定が1回だけであるということです。すべての追加、除算、および索引付けは、関係なく行う必要があります。 16バンドを実行している場合でも、画像を保存する場合は、Bを1回割り当てるだけで済みます。これが役に立たない場合でも、なぜ問題が理解できないのかを明確にするか、少なくとも他の方法のスピードアップの時間を測定するためのベンチマークとして機能します。これは、ラップトップでfloat64の5k x 5kアレイ上で2.6秒で実行されます。0.5はBの作成です

4
Paul

どれどれ:

あなたの質問はそれほど明確ではありませんが、このような平均化を大幅に改善したいと思います。

import numpy as np
from numpy.lib import stride_tricks as st

def mf(A, k_shape= (3, 3)):
    m= A.shape[0]- 2
    n= A.shape[1]- 2
    strides= A.strides+ A.strides
    new_shape= (m, n, k_shape[0], k_shape[1])
    A= st.as_strided(A, shape= new_shape, strides= strides)
    return np.sum(np.sum(A, -1), -1)/ np.prod(k_shape)

if __name__ == '__main__':
    A= np.arange(100).reshape((10, 10))
    print mf(A)

では、実際にどのようなパフォーマンスの向上が期待できますか?

更新:
まず第一に、警告:現在の状態のコードは「カーネル」の形状に適切に適合しません。しかし、それは今の私の主な関心事ではありません(とにかくアイデアはすでに適切に適応する方法がすでにあります)。

私は4D Aの新しい形状を直感的に選択したところです。私にとっては、2Dカーネルの中心を元の2D Aの各グリッド位置の中心に置くことを考えるのは本当に意味があります。

しかし、その4Dシェーピングは実際には「最良」のものではないかもしれません。ここでの本当の問題は、合計のパフォーマンスです。マシンのキャッシュアーキテクチャを完全に利用するには、(4D Aの)「最良の順序」を見つけることができる必要があります。ただし、その順序は、マシンキャッシュと「連携」する「小さな」配列と、そうでない(少なくともそれほど簡単ではない)大きな配列では同じでない場合があります。

更新2:
これはmfのわずかに変更されたバージョンです。明らかに、最初に3D配列に再形成し、それから合計するのではなく、単にドット積をとる方が良いです(これには、カーネルが任意であるという利点があります)。ただし、Paulsの更新された機能よりも(私のマシンでは)3倍遅くなります。

def mf(A):
    k_shape= (3, 3)
    k= np.prod(k_shape)
    m= A.shape[0]- 2
    n= A.shape[1]- 2
    strides= A.strides* 2
    new_shape= (m, n)+ k_shape
    A= st.as_strided(A, shape= new_shape, strides= strides)
    w= np.ones(k)/ k
    return np.dot(A.reshape((m, n, -1)), w)
4
eat