web-dev-qa-db-ja.com

効率に関する考慮事項:ネストされたループと再帰

私は自分を中間のPythonプログラマーと見なします。私の最近の課題の1つは、与えられた カウントダウン 問題に対するすべての可能な解決策のリストを作成することでした。詳細にあまり掘り下げずに、私は次の方法で問題に取り組みました:

  • 最初に、RPNを使用して可能なすべての数値演算子配置のリストを生成します

  • そして、すべての可能な配置のすべての可能な順列番号/演算子を総当たりして、答えを与えるパターンを記録します。

完全なコードリストはさらに下にあります。

これはまったく非効率的であり、私のプログラムは完了するまでに5〜10分かかることを認識しています。

私は別のアプローチ here に出くわしました。これは再帰とジェネレーターを使用し、30秒のスケールでかなり速く終了します。 Pythonについての私の理解レベルでは、見つけたコードを読んでニュアンスを完全に理解することはできません。

可能なすべての順列を使用して分岐式を再帰的に作成し、正しい結果が得られるまでそれらを評価することを理解しています。これは、本質的に私がやっていることの別の見方です。私はわかりませんそのコードが私のものよりも桁違いに速い理由を理解しています。

運用面では、より速いコードは500万回の試行という規模で行われ、私の場合は1,500万回の試行が行われますが、それでも実行時間の違いとは一致しません。

私の質問:クラス/再帰アプローチについて正確に何がこれを私の単純なアプローチよりもはるかに効率的にするかについてのポインタに非常に感謝します基本的に同じ方法に。


ネストされたループでさまざまなモジュールのスイッチをオフにいじくり回した後、それを絞り込んだと思います。非常に残念なことに、最も遅い部分はRPN式を評価する方法だと思います。

私がしたこと:

  • result = RPN_eval(...)result = [0]に置き換えました。これにより、プログラムは9秒未満で完了します。

  • 次に、その行を元に戻し、RPN_eval(...)関数を呼び出しました。代わりに、attempt文字列生成を削除し、固定の2 2 +に置き換えました-このバージョンは69秒以内に終了しました...

  • 最後に、attempt2 2 + 2 +に修正すると、実行時間が120秒に増加しました。

式の数値と演算子を追加するごとに、この結果を(大まかに)外挿すると、プログラム時間は約1.7倍増加します。合計実行時間は10〜11分です。これは、私のプログラムが通常の状態で示しているものです。

私の新しい質問:したがって、RPN_eval関数の中で、ぎこちなくて遅いように見える部分は何ですか? さらに調査し、これを実際の個別の質問に形式化しますが、ここでは関係ありません


私は何かに夢中になっていると思います-RPNパターン式を(恐ろしい)ラムダ関数に動的に変換しようとしています。協力したらここにコードを追加します...

私のコードリスト:

import itertools as it
import random
import time
operators = ["+", "-", "/", "*"]
count = 0

def RPN_eval(expression, answer): #a standard stack approach to evaluating RPN expressions
    explist = expression.split(" ")
    explist.pop(-1)
    stack = []

    for char in explist:

        if not char in operators:
            stack.append(int(char))
        else:
            if char == "+":
                num1 = stack.pop()
                num2 = stack.pop()

                if num1 > num2:
                    return[-1]

                result = num1 + num2
                stack.append(result)

            if char == "-":
                num1 = stack.pop()
                num2 = stack.pop()
                result = -num1 + num2
                stack.append(result)

            if char == "*":
                num1 = stack.pop()
                num2 = stack.pop()

                if num1 > num2:
                    return [-1]

                result = num1 * num2
                stack.append(result)

            if char == "/":
                divisor = stack.pop()
                divident = stack.pop()

                try:
                    result = divident / divisor
                except:
                    return [-1]

                stack.append(result)

            if result<=0 or result != int(result):
                return [-1]

    return stack

################### This part runs once and generates 37 possible RPN patterns for 6 numbers and 5 operators
def generate_patterns(number_of_numbers): 
#generates RPN patterns in the form NNoNNoo where N is number and o is operator

    patterns = ["N "]

    for pattern1 in patterns:
        for pattern2 in patterns:
            new_pattern = pattern1 + pattern2 + "o "
            if new_pattern.count("N")<=number_of_numbers and new_pattern not in patterns:
                patterns.append(new_pattern)

    return patterns
#######################################


######### Slowest part of program ################
def calculate_solutions(numbers, answer):
    global count
    patterns = generate_patterns(len(numbers)) #RPN symbolic patterns for a given number pool, runs once, takes less than 1 second
    random.shuffle(patterns) #not necessary, but yields answers to look at faster on average
    print(patterns)
    solutions = [] #this list will store answer strings of good solutions. This particular input produces 56 answers.

    for pattern in patterns:
        nn = pattern.count("N") #counts the number of numbers in a symbolic pattern to produce corresponding number group permutations
        no = pattern.count("o") #same for operators
        numpermut = it.permutations(numbers,nn) #all possible permutations of input numbers, is an itertools.permutations object, not a list. Takes 0 seconds to define.

        print(pattern)

        for np in numpermut:
            oppermut = it.product(["+","-","*","/"],repeat=no) #all possible permutations of operator order for a given pattern, itertools object, not a list. Takes 0 seconds to define
            for op in oppermut:
                attempt = ""
                ni = 0
                oi = 0
                for sym in pattern:
                    if "N" in sym:
                        attempt+=str(np[ni])+" " #replace Ns in pattern with corresponding numbers from permutations
                        ni+=1
                    if "o" in sym:
                        attempt+=str(op[oi])+" " #replace os in pattern with corresponding operators from permutations
                        oi+=1

                count+=1
                result = RPN_eval(attempt, answer) #evaluate attempt

                if result[0] == answer:
                    solutions.append(attempt) #if correct, append to list

                    print(solutions)
    return solutions
#####################################    




solns = calculate_solutions([50 , 8 , 3 , 7 , 2 , 10],556)
print(len(solns), count)

そしてより速いコードリスト:

class InvalidExpressionError(ValueError):
    pass

subtract = lambda x,y: x-y
def add(x,y):
    if x<=y: return x+y
    raise InvalidExpressionError
def multiply(x,y):
    if x<=y or x==1 or y==1: return x*y
    raise InvalidExpressionError
def divide(x,y):
    if not y or x%y or y==1:
        raise InvalidExpressionError
    return x/y

count = 0
add.display_string = '+'
multiply.display_string = '*'
subtract.display_string = '-'
divide.display_string = '/'

standard_operators = [ add, subtract, multiply, divide ]

class Expression(object): pass

class TerminalExpression(Expression):
    def __init__(self,value,remaining_sources):
        self.value = value
        self.remaining_sources = remaining_sources
    def __str__(self):
        return str(self.value)
    def __repr__(self):
        return str(self.value)

class BranchedExpression(Expression):
    def __init__(self,operator,lhs,rhs,remaining_sources):
        self.operator = operator
        self.lhs = lhs
        self.rhs = rhs
        self.value = operator(lhs.value,rhs.value)
        self.remaining_sources = remaining_sources
    def __str__(self):
        return '('+str(self.lhs)+self.operator.display_string+str(self.rhs)+')'
    def __repr__(self):
        return self.__str__()

def ValidExpressions(sources,operators=standard_operators,minimal_remaining_sources=0):
    global count
    for value, i in Zip(sources,range(len(sources))):
        yield TerminalExpression(value=value, remaining_sources=sources[:i]+sources[i+1:])
    if len(sources)>=2+minimal_remaining_sources:
        for lhs in ValidExpressions(sources,operators,minimal_remaining_sources+1):
            for rhs in ValidExpressions(lhs.remaining_sources, operators, minimal_remaining_sources):
                for f in operators:
                    try:
                        count+=1
                        yield BranchedExpression(operator=f, lhs=lhs, rhs=rhs, remaining_sources=rhs.remaining_sources)
                    except InvalidExpressionError: pass

def TargetExpressions(target,sources,operators=standard_operators):
    for expression in ValidExpressions(sources,operators):
        if expression.value==target:
            yield expression

def FindFirstTarget(target,sources,operators=standard_operators):
    for expression in ValidExpressions(sources,operators):
        if expression.value==target:
            return expression
    raise (IndexError, "No matching expressions found")

if __name__=='__main__':
    import time
    start_time = time.time()
    target_expressions = list(TargetExpressions(556,[50,8,3,7,2,10]))
    #target_expressions.sort(lambda x,y:len(str(x))-len(str(y)))
    print ("Found",len(target_expressions),"solutions, minimal string length was:")
    print (target_expressions[0],'=',target_expressions[0].value)
    print()
    print ("Took",time.time()-start_time,"seconds.")
    print(target_expressions)
    print(count)
2
IliaK

「高速」ソリューションのいくつかのステップを見てみましょう。

TargetExpressionが問題とともに呼び出されました。 2つのイテレータを生成するValidExpressionを呼び出します。最初のイテレータは、各TerminalExpressionを一度に1つずつ提供します。 TargetExpressionのループは、それが答えであるかどうかを確認するためにそれぞれをチェックし、答えがそうである場合は、メインプログラムに(イテレータを介して1つずつ)渡されます。これを実行すると、ネストされたイテレータを使用して可能な限り並べ替えることにより、候補式を1つずつ返す2番目のイテレータが生成されます。これらの値は、TargetExpressionのループによって一度に1つずつループされます。ネストされた各イテレータも、一度に1つの値のみを返します。

ここでの1つの違いは、「高速」バージョンでは計算が不要になるということです。つまり、オペランドが順不同(つまり、最初が2番目以上)である場合、それらのオペランドで始まる他の結果の調査は停止します。たとえば、それが50 + 8で始まる場合、高速バージョンはすぐに保釈され、別の開始ペアをチェックします。私が間違っていない場合、バージョンは50 + 8で始まるすべての順列をチェックします。それはそれらを無視しますが、一度に1つずつですが、「高速」バージョンはツリーのその部分全体を無視します。

少し絞り込んだ編集を終えた後、RPN_evalメソッドについていくつか考えてみましょう。

まず、簡単なもの。相互に排他的なifステートメントのセットがあります。これらのチェックの1つだけがtrueであっても、各ループで4つすべての演算子をチェックします。そのような場合は、チェーンに変更する必要があります。

if not char in operators:
  #...
Elif char == "+":
  #...
Elif char == "-":
  #...
Elif char == "*":
  #...
Elif char == "/":

最後のifステートメントが必要かどうかはわかりませんが、おそらく何か不足しています。これが大きな違いを生むとは思えませんが、意図がより正確に伝わり、コーディングエラーが発生しにくくなります。

このコードが高価に見える唯一のことは、継続的なプッシュとポップであり、それはそれほど高価ではありません。私が推測できる最高のことは、同じオペランドを合計した他のバージョンよりも多くチェックしているためです。これは、前述のプルーニングと同じ理由です。 「高速」バージョンは、8 + 50で始まるすべてのツリーをチェックするときに、その操作を1回実行します。このアプローチでは、8 + 50で始まる候補ツリーごとに2つのプッシュと2つのポップを実行します。その計算をする時間はありませんが、それは少数ではありません。計算する式の数と、同じルートで始まる式の数を数えてみてください。それはおそらく目を見張るでしょう。

2
JimmyJames

RPNパーサーを改善しました。解決策自体はまだ私の特定の問題にかなり狭く調整されていますが、関連するスキルは非常にPythonicであり、価値があると思います。最後に私がしたこと:

  • フォームの各RPNパターンを使用しました。 NNoNoを解析して複合ラムダ関数に変換しました
  • 複合関数は2つのリストを取ります:試行する数値の順列とシンボルの順列。
  • ラムダの引数リストの個々の要素は、ラムダの作成時に内部的に事前に割り当てられるため、各数値は適切な場所に配置されます。オペレーターも同じ
  • この定義された複合関数は、特定のパターンのすべての順列に使用され、その後、関数は新しいパターンに対して再定義されます。 37パターンしかないので、解析とリスト操作の時間を節約できます。

結果はまだ圧倒的ですが、引き続き調査します。私のデスクトップマシンでは、動的関数アプローチはすべての結果を120秒で提供しますが、前のアプローチでは155秒です。

追加された関数のみが以下にリストされています。

import operator

def safe_division(num1,num2):
    try:
        r = num1/num2
        if int(r) == r:
            return r
        else:
            raise ValueError
    except ZeroDivisionError:
        raise ValueError

def ordered_add(num1,num2):
    if num1>num2:
        return num1+num2
    else:
        raise ValueError

def ordered_times(num1,num2):
    if num1>num2:
        return num1*num2
    else:
        raise ValueError

def positive_subtract(num1,num2):
    if num1>=num2:
        return num1-num2
    else:
        raise ValueError

ARITHMETIC_OPERATORS = {
    '+':  ordered_add, '-':  operator.sub,
    '*':  ordered_times, '/':  safe_division, '%':  operator.mod,
    '**': operator.pow, '//': operator.floordiv,
}



def RPN_to_opexpression(pattern):
    val = 0 #keeping track of which element in the permutation list to pass to the lambda
    sym = 0 #same but for symbols
    stack = []

    for char in pattern:
        if char == "N":
            #if the character represents a Number, the "result" of this operation is a lambda that returns the number itself
            stack.append(lambda symlist,vallist, val=val : vallist[val]) 
            val+=1
        else:
            #if the character is an operator, the attached lambda is a composite of two previous stack contents composited over the operator
            rhs = stack.pop()
            lhs = stack.pop()
            stack.append(lambda symlist,vallist, sym=sym, lhs=lhs, rhs=rhs: ARITHMETIC_OPERATORS[symlist[sym]](lhs(symlist,vallist),rhs(symlist,vallist)))
            sym+=1
    superfunction = stack.pop()

    return superfunction

######################
#and in main body:
superfun = RPN_to_opexpression(pattern)
...
result = superfun(op,np)
0
IliaK