Python 3:优化项目Euler问题#14

时间:2019-06-10 02:16:47

标签: python python-3.x memoization collatz

我正在尝试使用Python 3解决Hackerrank Project Euler Problem #14(最长Collat​​z序列)。以下是我的实现。

cache_limit = 5000001
lookup = [0] * cache_limit
lookup[1] = 1


def collatz(num):
    if num == 1:
        return 1
    elif num % 2 == 0:
        return num >> 1
    else:
        return (3 * num) + 1


def compute(start):
    global cache_limit
    global lookup
    cur = start
    count = 1

    while cur > 1:
        count += 1
        if cur < cache_limit:
            retrieved_count = lookup[cur]
            if retrieved_count > 0:
                count = count + retrieved_count - 2
                break
            else:
                cur = collatz(cur)
        else:
            cur = collatz(cur)

    if start < cache_limit:
        lookup[start] = count

    return count


def main(tc):
    test_cases = [int(input()) for _ in range(tc)]
    bound = max(test_cases)
    results = [0] * (bound + 1)

    start = 1
    maxCount = 1
    for i in range(1, bound + 1):
        count = compute(i)
        if count >= maxCount:
            maxCount = count
            start = i
        results[i] = start

    for tc in test_cases:
        print(results[tc])


if __name__ == "__main__":
    tc = int(input())
    main(tc)

有12个测试用例。上面的实现一直持续到测试用例#8,但由于以下原因而在测试用例#9至#12中失败。

Terminated due to timeout

我已经坚持了一段时间。不知道还能在这里做什么。

在这里还有什么可以优化的,这样我就不再超时了?

任何帮助将不胜感激:)

注意: 使用上面的实现,我可以解决实际的Project Euler问题#14。它只为hackerrank中的这4个测试用例提供超时。

5 个答案:

答案 0 :(得分:0)

是的,您可以对代码进行一些优化。但我认为,更重要的是,您需要考虑一个数学观察,这是问题的核心

whenever n is odd, then 3 * n + 1 is always even. 

鉴于此,总可以将(3 * n +1)除以2。这样可以节省一小段时间...

答案 1 :(得分:0)

这是一个改进(耗时1.6秒):无需计算每个数字的顺序。您可以创建字典并存储序列中元素的数量。如果已经出现一个数字,则该序列的计算方式为dic [original_number] = dic [n] + count-1。这样可以节省大量时间。

import time

start = time.time()

def main(n,dic):
    '''Counts the elements of the sequence starting at n and finishing at 1''' 
    count = 1
    original_number = n
    while True:
        if n < original_number:
            dic[original_number] = dic[n] + count - 1 #-1 because when n < original_number, n is counted twice otherwise
            break
        if n == 1:
            dic[original_number] = count
            break
        if (n % 2 == 0):
            n = n/2
        else:
            n = 3*n + 1
        count += 1
    return dic

limit = 10**6
dic = {n:0 for n in range(1,limit+1)}

if __name__ == '__main__':
    n = 1
    while n < limit:
        dic=main(n,dic)

        n += 1        
    print('Longest chain: ', max(dic.values()))
    print('Number that gives the longest chain: ', max(dic, key=dic.get))
    end = time.time()

    print('Time taken:', end-start)

答案 2 :(得分:0)

解决这个问题的技巧是只计算最大输入的答案,并将结果保存为所有较小输入的查找,而不是计算极端上限。

这是我通过所有测试用例的实现。(Python3)

MAX = int(5 * 1e6)
ans = [0]
steps = [0]*(MAX+1)
 
def solve(N):
    if N < MAX+1:
        if steps[N] != 0:
            return steps[N]
    if N == 1:
        return 0
    else:
        if N % 2 != 0:
            result = 1+ solve(3*N + 1) # This is recursion
        else:
            result = 1 + solve(N>>1) # This is recursion
        if N < MAX+1:    
            steps[N]=result # This is memoization
        return result
    
inputs = [int(input()) for _ in range(int(input()))]
largest = max(inputs)

mx = 0
collatz=1
for i in range(1,largest+1):
    curr_count=solve(i)
    if curr_count >= mx:
        mx = curr_count
        collatz = i
    ans.append(collatz)
    
for _ in inputs:
    print(ans[_])

答案 3 :(得分:-1)

这是我的蛮力:

'
#counter
C = 0
N = 0
for i in range(1,1000001):
n = i
c = 0
while n != 1:
    if n % 2 == 0:
        _next = n/2
    else:
        _next= 3*n+1
    c = c + 1
    n = _next
    if c > C:
    C = c
    N = i

 print(N,C)

答案 4 :(得分:-2)

这是我的实现方式(针对在Euler项目网站上专门提出的问题):

num = 1
limit = int(input())
seq_list = []
while num < limit:
    sequence_num = 0
    n = num
    if n == 1:
        sequence_num = 1
    else:
        while n != 1:
            if n % 2 == 0:
                n = n / 2
                sequence_num += 1
            else:
                n = 3 * n + 1
                sequence_num += 1

        sequence_num += 1
    seq_list.append(sequence_num)
    num += 1

k = seq_list.index(max(seq_list))
print(k + 1)