Reputation: 95
My question is about this kata on Codewars. The function takes two sorted lists with distinct elements as arguments. These lists might or might not have common items. The task is to find the maximum path sum. While building the sum, whenever you reach an item that both lists have in common, you can choose to switch your path to the other list.
The given example is like this:
list1 = [0, 2, 3, 7, 10, 12]
list2 = [1, 5, 7, 8]
0->2->3->7->10->12 => 34
0->2->3->7->8 => 20
1->5->7->8 => 21
1->5->7->10->12 => 35 (maximum path)
I solved the kata, but my code doesn't meet the performance criteria, so I get "execution timed out". What can I do about it?
Here is my solution:
import itertools

def max_sum_path(l1:list, l2:list):
    common_items = list(set(l1).intersection(l2))
    if not common_items:
        return max(sum(l1), sum(l2))
    common_items.sort()
    s = 0
    new_start1 = 0
    new_start2 = 0
    s1 = 0
    s2 = 0
    for item in common_items:
        s1 = sum(itertools.islice(l1, new_start1, l1.index(item)))
        s2 = sum(itertools.islice(l2, new_start2, l2.index(item)))
        new_start1 = l1.index(item)
        new_start2 = l2.index(item)
        s += max(s1, s2)
    s1 = sum(itertools.islice(l1, new_start1, len(l1)))
    s2 = sum(itertools.islice(l2, new_start2, len(l2)))
    s += max(s1, s2)
    return s
Upvotes: 8
Views: 772
Reputation: 28636
Benchmarks
On the Discourse tab you can click "Show Kata Test Cases" (once you've solved the kata) to see their test case generator. I used that to benchmark the solutions posted so far, as well as one of my own. I ran a few dozen rounds, since the test cases are fairly random and cause big runtime fluctuations. In each round, all generated test cases were given to all solutions, so in each round every solution got the same test cases.
I also benchmarked Kelly Bundy's worst case for sorting the set of common values.
Code shall follow.
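In the meantime, here's a minimal sketch of the kind of harness described above. Note that generate_case is a hypothetical stand-in; the kata's actual generator isn't reproduced here:

import random
from timeit import default_timer as timer

def generate_case():
    # Hypothetical stand-in for the kata's generator: two sorted
    # lists of distinct values that may share some elements.
    pool = random.sample(range(10**6), 10_000)
    return (sorted(random.sample(pool, 6_000)),
            sorted(random.sample(pool, 6_000)))

def benchmark(solutions, rounds=30, cases_per_round=100):
    totals = dict.fromkeys(solutions, 0.0)
    for _ in range(rounds):
        # All solutions get the same test cases in each round.
        cases = [generate_case() for _ in range(cases_per_round)]
        for name, solve in solutions.items():
            start = timer()
            for l1, l2 in cases:
                solve(l1, l2)
            totals[name] += timer() - start
    for name, t in sorted(totals.items(), key=lambda kv: kv[1]):
        print(f'{t:8.3f}s  {name}')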
Upvotes: 3
Reputation: 10545
Once you know the items shared between the two lists, you can iterate over each list separately and sum up the items between consecutive shared items, constructing a list of partial sums. These lists of partial sums have the same length for both inputs, because both contain the same number of shared items.
The maximum path sum can then be found by taking, for each stretch between shared values, the maximum of the two partial sums:
def max_sum_path(l1, l2):
    shared_items = set(l1) & set(l2)

    def partial_sums(lst):
        result = []
        partial_sum = 0
        for item in lst:
            partial_sum += item
            if item in shared_items:
                result.append(partial_sum)
                partial_sum = 0
        result.append(partial_sum)
        return result

    return sum(map(max, partial_sums(l1),
                        partial_sums(l2)))
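For instance, on the example from the question, partial_sums produces [12, 22] for l1 (0+2+3+7 up to the shared 7, then 10+12) and [13, 8] for l2, so the result is max(12, 13) + max(22, 8) = 35:

l1 = [0, 2, 3, 7, 10, 12]
l2 = [1, 5, 7, 8]
print(max_sum_path(l1, l2))  # 35, the path 1->5->7->10->12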
Time complexity: we only iterate once over each list (the iteration over the shorter lists of partial sums is negligible here), so this code is linear in the length of the input lists. However, as you and Kelly Bundy have noted, your own algorithm actually has the same time complexity, except for sorting the common items, which does not appear to matter much for the given test cases.
So as a general conclusion: if your goal is just to make your code fast enough to pass certain test cases, it can be better to profile the execution and find the time sinks in the actual implementation than to worry about theoretical worst-case scenarios.
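For example, a quick profiling run with the standard library's cProfile (the input here is just illustrative) makes hot spots such as repeated list.index calls stand out immediately:

import cProfile
import random

# Illustrative random input; the kata's real test cases differ.
l1 = sorted(random.sample(range(100_000), 5_000))
l2 = sorted(random.sample(range(100_000), 5_000))

# Prints per-function call counts and cumulative times.
cProfile.run('max_sum_path(l1, l2)', sort='cumulative')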
Upvotes: 3
Reputation: 7385
This can be done in a single pass with O(n) runtime and O(1) space complexity. All you need is two pointers to traverse both arrays in parallel and two path values.
You increment the pointer to the smaller element and add its value to its path. When you find a common element, you add it to both paths and then set both paths to the max value.
def max_sum_path(l1, l2):
    path1 = 0
    path2 = 0
    i = 0
    j = 0
    while i < len(l1) and j < len(l2):
        if l1[i] < l2[j]:
            path1 += l1[i]
            i += 1
        elif l2[j] < l1[i]:
            path2 += l2[j]
            j += 1
        else:
            # Same element in both paths
            path1 += l1[i]
            path2 += l1[i]
            path1 = max(path1, path2)
            path2 = path1
            i += 1
            j += 1
    while i < len(l1):
        path1 += l1[i]
        i += 1
    while j < len(l2):
        path2 += l2[j]
        j += 1
    return max(path1, path2)
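On the question's example, the pointers meet at the shared 7 with path1 = 12 and path2 = 13, both paths become 13, and the first list's longer tail wins:

print(max_sum_path([0, 2, 3, 7, 10, 12], [1, 5, 7, 8]))  # 35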
Upvotes: 4
Reputation: 57175
The problem says "aim for linear time complexity", which is a pretty big hint that things like nested loops won't fly (the repeated index calls amount to nested O(n) loops here, and sort() is O(n log n) when there are many common values between the input lists).
This answer shows how you can cache the repeated .index calls and use start offsets from the last chunk to bring the complexity down.
As the linked answer also states, itertools.islice isn't appropriate here because it traverses from the start of the list. Instead, use native slicing. This, coupled with the modifications to index above, gives you linearithmic complexity overall, and linear on most inputs.
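To see the difference in isolation, here's a small timing sketch (the numbers vary by machine; the point is that islice walks from the start of the list while a native slice jumps straight to the window):

from itertools import islice
from timeit import timeit

l = list(range(1_000_000))
start = 900_000

# islice steps through the first 900,000 elements to reach the window.
t_islice = timeit(lambda: sum(islice(l, start, start + 10)), number=100)
# A native slice copies the 10 elements directly.
t_native = timeit(lambda: sum(l[start:start + 10]), number=100)

print(t_islice, t_native)  # islice is orders of magnitude slower here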
For context, here's my approach, which isn't that different from yours, although I cache indices and avoid sorting.
I started by formulating the problem as a directed acyclic graph with the idea of searching for the maximum path sum:
       +---> [0, 2, 3] ---+            +---> [10, 12]
[0] ---|                  |---> [7] ---|
       +---> [1, 5] ------+            +---> [8]
We might as well also sum the values of each node for clarity:
     +---> 5 ---+          +---> 22
0 ---|          |---> 7 ---|
     +---> 6 ---+          +---> 8
The diagram above reveals that a greedy solution will be optimal, given the uniqueness constraints. For example, starting from the root, we can only pick the 5 or 6 value path to get to 7. The larger of the two, 6, is guaranteed to be part of the maximum-weight path, so we take it.
Now, the question is only how to implement this logic. Going back to the lists, here's a more substantial input with formatting and annotations to help motivate an approach:
[1, 2, 4, 7, 8,            10, 14, 15    ]
[      4,    8, 9, 11, 12,         15, 90]
       ^     ^                     ^
       |     |                     |
This illustrates how the linked indices line up. Our goal is to iterate over each chunk between the links, taking the larger of the two sublist sums:
[1, 2, 4, 7, 8,            10, 14, 15    ]
[      4,    8, 9, 11, 12,         15, 90]
 ^~~^     ^     ^~~~~~~~~~~~~~~^      ^^
  0       1             2              3   <-- chunk number
The expected result for the above input should be 3 + 4 + 7 + 8 + 32 + 15 + 90 = 159, taking all of the link values plus the top list's sublist sum for chunks 0 and 1 and the bottom list for chunks 2 and 3.
Here's a rather verbose, but hopefully easy to understand, implementation; you can visit the thread to see more elegant solutions:
def max_sum_path(a, b):
    b_idxes = {k: i for i, k in enumerate(b)}
    link_to_a = {}
    link_to_b = {}
    for i, e in enumerate(a):
        if e in b_idxes:
            link_to_a[e] = i
            link_to_b[e] = b_idxes[e]
    total = 0
    start_a = 0
    start_b = 0
    for link in link_to_a:  # dicts keep insertion order (guaranteed in Python 3.7+)
        end_a = link_to_a[link]
        end_b = link_to_b[link]
        total += max(sum(a[start_a:end_a]), sum(b[start_b:end_b])) + link
        start_a = end_a + 1
        start_b = end_b + 1
    return total + max(sum(a[start_a:]), sum(b[start_b:]))
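Running it on the annotated input above reproduces the hand-computed result, and it also handles the question's original example:

a = [1, 2, 4, 7, 8, 10, 14, 15]
b = [4, 8, 9, 11, 12, 15, 90]
print(max_sum_path(a, b))                                 # 159
print(max_sum_path([0, 2, 3, 7, 10, 12], [1, 5, 7, 8]))  # 35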
Upvotes: 4
Reputation: 27629
Your algorithm is actually fast; it's just your implementation that is slow.
The two things that make it take overall O(n²) time:

- l1.index(item) always searches from the start of the list. It should be l1.index(item, new_start1).
- itertools.islice(l1, new_start1, ...) creates an iterator for l1 and iterates over the first new_start1 elements before it reaches the elements you want. Just use a normal list slice instead.

Then it's just O(n log n) for the sorting and O(n) for everything else. And the sorting's O(n log n) is fast; it might easily take less time than the O(n) part for any allowed input, and even for larger ones.
Here's the rewritten version, which gets accepted in about 6 seconds, just like the solutions from the other answers:
def max_sum_path(l1:list, l2:list):
    common_items = list(set(l1).intersection(l2))
    if not common_items:
        return max(sum(l1), sum(l2))
    common_items.sort()
    s = 0
    new_start1 = 0
    new_start2 = 0
    s1 = 0
    s2 = 0
    for item in common_items:
        next_start1 = l1.index(item, new_start1)  # changed
        next_start2 = l2.index(item, new_start2)  # changed
        s1 = sum(l1[new_start1 : next_start1])    # changed
        s2 = sum(l2[new_start2 : next_start2])    # changed
        new_start1 = next_start1                  # changed
        new_start2 = next_start2                  # changed
        s += max(s1, s2)
    s1 = sum(l1[new_start1:])  # changed
    s2 = sum(l2[new_start2:])  # changed
    s += max(s1, s2)
    return s
Or you could use iterators instead of indexes. Here's your solution rewritten to do that; it also gets accepted in about 6 seconds:
def max_sum_path(l1:list, l2:list):
    common_items = sorted(set(l1) & set(l2))
    s = 0
    it1 = iter(l1)
    it2 = iter(l2)
    for item in common_items:
        s1 = sum(iter(it1.__next__, item))
        s2 = sum(iter(it2.__next__, item))
        s += max(s1, s2) + item
    s1 = sum(it1)
    s2 = sum(it2)
    s += max(s1, s2)
    return s
I'd combine the last four lines into one; I just left them as you had them so it's easier to compare.
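The non-obvious part is the two-argument form of iter, which keeps calling it1.__next__ until the sentinel item shows up, consuming the sentinel without yielding it (which is why the loop adds item back separately). A standalone illustration:

it = iter([0, 2, 3, 7, 10, 12])
print(sum(iter(it.__next__, 7)))  # 5: sums 0 + 2 + 3, stops at the 7,
                                  # which is consumed but not included
print(sum(it))                    # 22: resumes after the 7 with 10 + 12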
Upvotes: 5