web-dev-qa-db-ja.com

numbaセーフバージョンのitertools.combinations?

_itertools.combinations_の大きなセットをループするコードがいくつかありますが、これは現在パフォーマンスのボトルネックになっています。 numba@jit(nopython=True)を使用して高速化しようとしていますが、いくつかの問題が発生しています。

まず、この小さな例のように、numbaは_itertools.combinations_自体を処理できないようです:

_import itertools
import numpy as np
from numba import jit

arr = [1, 2, 3]
c = 2

@jit(nopython=True)
def using_it(arr, c):
    return itertools.combinations(arr, c)

for i in using_it(arr, c):
    print(i)
_

エラーをスローします:numba.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend) Unknown attribute 'combinations' of type Module(<module 'itertools' (built-in)>)

いくつかグーグルした後、私は このgithubの問題 を発見しました。

_@jit(nopython=True)
def permutations(A, k):
    r = [[i for i in range(0)]]
    for i in range(k):
        r = [[a] + b for a in A for b in r if (a in b)==False]
    return r
_

それを利用して、組み合わせに簡単に絞り込むことができます。

_@jit(nopython=True)
def combinations(A, k):
    return [item for item in permutations(A, k) if sorted(item) == item]
_

これで、combinations関数をエラーなしで実行して、正しい結果を得ることができます。ただし、これは@jit(nopython=True)を使用すると、使用しない場合よりも劇的に遅くなります。このタイミングテストの実行:

_A = list(range(20))  # numba throws 'cannot determine numba type of range' w/o list
k = 2
start = pd.Timestamp.utcnow()
print(combinations(A, k))
print(f"took {pd.Timestamp.utcnow() - start}")
_

numba @jit(nopython=True)デコレータを使用して2.6秒でクロックインし、コメントアウトすると1/000秒未満になります。ですから、それも私にとって実際に実行可能な解決策ではありません。

4
Max Power

itertools.combinations がCで記述されているため、この場合、Numbaを使用して得るものはあまりありません。

ベンチマークしたい場合は、Numba/Python itertools.combinatiionsの機能の実装:

@jit(nopython=True)
def using_numba(pool, r):
    n = len(pool)
    indices = list(range(r))
    empty = not(n and (0 < r <= n))

    if not empty:
        result = [pool[i] for i in indices]
        yield result

    while not empty:
        i = r - 1
        while i >= 0 and indices[i] == i + n - r:
            i -= 1
        if i < 0:
            empty = True
        else:
            indices[i] += 1
            for j in range(i+1, r):
                indices[j] = indices[j-1] + 1

            result = [pool[i] for i in indices]
            yield result

私のマシンでは、これはitertools.combinationsよりも約15倍遅いです。順列の取得と組み合わせのフィルタリングは、確かにさらに遅くなります。

1
Jacques Gaudin