Donald

Reputation: 31

Data structure for easily finding the maximum length of a non-decreasing sequence on a certain interval

I'm trying to find a solution for this challenge:

We are given n, the size of an array initially filled with ones. Two types of operations are possible on this array:

  • Increase (x, y, m): add m to all array elements in the interval from position x to y (inclusive), where 1 <= x <= y <= n.
  • Give (x, y): return the maximum length of a non-decreasing contiguous sequence within the interval from x to y.

Example:

Input: n = 5

Operations:

  1. Increase 1 2 3 (x = 1, y = 2, m = 3)

    Now our array is [4, 4, 1, 1, 1]

  2. Give 2 4 (x = 2, y = 4)

    The maximum length is 2, because the longest non-decreasing sequence in [4, 1, 1] is 1, 1.

I'm looking for a solution in which every operation runs in O(log(n)) time.

My approach

I've noticed that we can store this array as a difference array, which starts out as all zeros and where each element represents how much greater it is than the previous one. For example, instead of [1, 4, 2, 5] we store [0, 3, -2, 3]. Non-decreasing sequences are then easy to find by looking at the negative entries, since each negative entry marks the start of a new non-decreasing run. I've tried going that way and optimizing the search for negative numbers (e.g. by using a set or a tree), but in some situations the Give operation still takes O(n) time, which is not what I want.
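A minimal sketch of the difference-array idea described above, assuming 0-based storage and that the first entry is always 0 (the function names `toDifferences` and `runStarts` are hypothetical, chosen only for illustration). It converts [1, 4, 2, 5] into [0, 3, -2, 3] and lists the positions where a new non-decreasing run begins.

```cpp
#include <cstddef>
#include <vector>

// Difference form of a (0-based here): d[0] = 0 and d[i] = a[i] - a[i - 1] for i >= 1,
// so [1, 4, 2, 5] becomes [0, 3, -2, 3].
std::vector<long long> toDifferences(const std::vector<long long>& a) {
    std::vector<long long> d(a.size(), 0);
    for (std::size_t i = 1; i < a.size(); ++i) d[i] = a[i] - a[i - 1];
    return d;
}

// A negative d[i] means a[i] < a[i - 1], so position i starts a new
// non-decreasing run; position 0 always starts one.
std::vector<std::size_t> runStarts(const std::vector<long long>& d) {
    std::vector<std::size_t> starts;
    for (std::size_t i = 0; i < d.size(); ++i)
        if (i == 0 || d[i] < 0) starts.push_back(i);
    return starts;
}
```

For [1, 4, 2, 5] the runs start at positions 0 and 2, i.e. the runs are [1, 4] and [2, 5].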

Here is how my algorithm worked:
Notice that with the difference array, an increase operation changes only two entries: arr[x] += m and arr[y + 1] -= m (I assume that arr is 1-based). I start with an empty set. During an increase operation I perform those two steps and then:

  1. If arr[x] is no longer less than 0 after the "+= m" operation, remove x from the set.

  2. If arr[y + 1] is no longer greater than or equal to 0 after the "-= m" operation (and y + 1 <= n), add y + 1 to the set.

After all increase operations, the set holds the positions of all negative entries, and as I wrote above, these are the places where new non-decreasing intervals begin. For example, if n = 6 and set = {2, 6}, there are 3 non-decreasing intervals: (1, 1), (2, 5), (6, 6). So a Give operation takes O(set.size()) time, which is not good enough.
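A sketch of the set-based approach described above, assuming 1-based positions and a sentinel slot at n + 1 (the `RunSet` struct and its members are hypothetical names). One small generalization: instead of the two one-directional rules, `refresh` re-checks the sign of each changed entry, which also keeps the set correct if m can be negative.

```cpp
#include <algorithm>
#include <set>
#include <vector>

// diff[i] = a[i] - a[i - 1] for 2 <= i <= n; diff[n + 1] is a sentinel that absorbs "-= m".
// neg holds exactly the positions i in [2, n] with diff[i] < 0, i.e. the starts of new runs.
struct RunSet {
    int n;
    std::vector<long long> diff;
    std::set<int> neg;

    explicit RunSet(int n) : n(n), diff(n + 2, 0) {}

    // Keep neg in sync with the sign of diff[i] (works for either sign of m).
    void refresh(int i) {
        if (i < 2 || i > n) return;
        if (diff[i] < 0) neg.insert(i); else neg.erase(i);
    }

    // Increase(x, y, m): two point changes on the difference array, O(log n).
    void increase(int x, int y, long long m) {
        diff[x] += m;
        refresh(x);
        diff[y + 1] -= m;
        refresh(y + 1);
    }

    // Give(x, y): visit every run boundary inside (x, y]; O(log n + number of boundaries).
    int give(int x, int y) const {
        int best = 0, start = x;
        for (auto it = neg.upper_bound(x); it != neg.end() && *it <= y; ++it) {
            best = std::max(best, *it - start);  // run [start, *it - 1]
            start = *it;
        }
        return std::max(best, y - start + 1);    // last run [start, y]
    }
};
```

With alternating increases the set can hold Θ(n) boundaries, so a single Give can still be linear in the interval length, which is the bottleneck the question describes.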

How can I execute each operation in O(log(n)) time?

Upvotes: 2

Views: 178

Answers (1)

Unmitigated

Reputation: 89404

This can be solved using a segment tree with lazy propagation (alternatively, the range increases can be handled separately in a binary indexed tree). In each node, store the sum of increases at both the left and right endpoints of the interval the node represents, as well as the lengths of three longest non-decreasing subarrays: the one that starts at the left endpoint, the one that ends at the right endpoint, and the overall longest one anywhere within the node's interval. For a leaf node, all three lengths are one, and every node starts with an increment of 0.

When merging two nodes, check whether the longest non-decreasing subarray ending at the right endpoint of the left node can be combined with the longest non-decreasing subarray starting at the left endpoint of the right node: this is possible if the element at the right endpoint of the former is not larger than the element at the left endpoint of the latter, after applying increases (recall that the sum of increase operations on the endpoints is already stored in each node). The longest non-decreasing subarray starting at the left endpoint of the merged result is the same as that of the left node, unless that subarray covers the entire left node, in which case it is the combined result from the left and right nodes (if combining is possible). The case for the longest non-decreasing subarray ending at the right endpoint is symmetric. The overall longest subarray of the merged result is the maximum of the two nodes' overall values and, if combining is possible, the combined length.
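A minimal, self-contained sketch of the merge rule above. To keep it short it stores the actual endpoint values instead of accumulated increments, and the names `Summary`, `leaf`, `combine` and `summarize` are hypothetical, not taken from the answer's code. Because the merge is associative, folding left to right gives the same summary as the balanced merges a segment tree performs.

```cpp
#include <algorithm>
#include <vector>

// Summary of a contiguous block: first/last are the endpoint values,
// pre/suf/best are the longest non-decreasing runs starting at the left end,
// ending at the right end, and anywhere in the block.
struct Summary {
    long long first, last;
    int len, pre, suf, best;
};

Summary leaf(long long v) { return {v, v, 1, 1, 1, 1}; }

// Merge block l with the block r that immediately follows it.
Summary combine(const Summary& l, const Summary& r) {
    bool join = l.last <= r.first;  // can l's suffix run continue into r's prefix run?
    Summary s;
    s.first = l.first;
    s.last = r.last;
    s.len = l.len + r.len;
    s.pre = (join && l.pre == l.len) ? l.len + r.pre : l.pre;
    s.suf = (join && r.suf == r.len) ? r.len + l.suf : r.suf;
    s.best = std::max({l.best, r.best, join ? l.suf + r.pre : 0});
    return s;
}

// Summarize the non-empty slice a[lo..hi] by folding leaves left to right.
Summary summarize(const std::vector<long long>& a, int lo, int hi) {
    Summary s = leaf(a[lo]);
    for (int i = lo + 1; i <= hi; ++i) s = combine(s, leaf(a[i]));
    return s;
}
```

For the question's array [4, 4, 1, 1, 1] (0-based here), the slice from index 1 to 3 is [4, 1, 1], whose `best` is 2.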

For an increase operation that covers an entire node, all non-decreasing subarrays remain non-decreasing, because all elements are incremented by the same amount, so only the stored endpoint increments change. It is therefore enough to record such increases as lazy updates (pushing them down to the children as soon as a node is visited while traversing the segment tree).

Both queries and updates can be done in O(log(n)), since each one visits O(log(n)) nodes.

Implementation in C++:

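#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>
// Each node covers the index range [left, right].
// leftIncr / rightIncr: total amount added so far to the elements at left and right
//   (every element starts at 1, so comparing increments is the same as comparing values).
// prefixMax / suffixMax: longest non-decreasing run starting at left / ending at right.
// overallMax: longest non-decreasing run anywhere in [left, right].
// lazy: pending increment that has not yet been applied to this node.
// Increments are long long to avoid overflow after many large increases.
struct Node {
    int left, right;
    long long leftIncr, rightIncr;
    int prefixMax, suffixMax, overallMax;
    long long lazy;
};
std::vector<Node> tree;
// Apply node n's pending increment and hand it down to its children.
void pushdown(int n) {
    tree[n].leftIncr += tree[n].lazy;
    tree[n].rightIncr += tree[n].lazy;
    if (tree[n].left != tree[n].right) {
        tree[n * 2].lazy += tree[n].lazy;
        tree[n * 2 + 1].lazy += tree[n].lazy;
    }
    tree[n].lazy = 0;
}
// Combine two adjacent nodes (l lies immediately to the left of r).
Node merge(const Node& l, const Node& r) {
    // The runs touching the boundary can be joined if the last element of l
    // is not larger than the first element of r.
    bool canCombine = l.rightIncr <= r.leftIncr;
    int lLen = l.right - l.left + 1, rLen = r.right - r.left + 1;
    return {
        l.left, r.right, l.leftIncr, r.rightIncr,
        l.prefixMax + (l.prefixMax == lLen && canCombine) * r.prefixMax,
        r.suffixMax + (r.suffixMax == rLen && canCombine) * l.suffixMax,
        std::max({l.overallMax, r.overallMax, canCombine * (l.suffixMax + r.prefixMax)}),
        0  // a freshly merged node has no pending increment
    };
}
// Build the subtree for [l, r]; every element starts at 1 (increment 0).
Node build(int n, int l, int r) {
    if (l != r) {
        int mid = std::midpoint(l, r);  // C++20; (l + r) / 2 works as well here
        tree[n] = merge(build(n * 2, l, mid), build(n * 2 + 1, mid + 1, r));
    } else {
        tree[n].left = tree[n].right = l;
        tree[n].prefixMax = tree[n].suffixMax = tree[n].overallMax = 1;
    }
    return tree[n];
}
// Add incr to every element of [updateLeft, updateRight], which lies inside node n's range.
// An empty range (updateLeft > updateRight) only pushes pending updates down.
void update(int n, int updateLeft, int updateRight, long long incr) {
    pushdown(n);
    if (updateLeft > updateRight) return;
    if (tree[n].left == updateLeft && tree[n].right == updateRight) {
        // The whole node is covered: run lengths stay the same, only the increments shift.
        tree[n].lazy = incr;
        pushdown(n);
        return;
    }
    update(n * 2, updateLeft, std::min(updateRight, tree[n * 2].right), incr);
    update(n * 2 + 1, std::max(tree[n * 2 + 1].left, updateLeft), updateRight, incr);
    tree[n] = merge(tree[n * 2], tree[n * 2 + 1]);
}
// Return a node describing exactly [queryLeft, queryRight], which lies inside node n's range.
Node query(int n, int queryLeft, int queryRight) {
    pushdown(n);
    if (tree[n].left == queryLeft && tree[n].right == queryRight) return tree[n];
    if (queryRight <= tree[n * 2].right) return query(n * 2, queryLeft, queryRight);
    if (queryLeft > tree[n * 2].right) return query(n * 2 + 1, queryLeft, queryRight);
    return merge(query(n * 2, queryLeft, tree[n * 2].right),
                 query(n * 2 + 1, tree[n * 2 + 1].left, queryRight));
}
int main() {
    int n;
    std::cin >> n;
    tree.resize(4 * n);
    build(1, 1, n);
    // Input: "1 x y" is a Give query; "0 x y m" is an Increase update.
    for (int type, l, r; std::cin >> type >> l >> r;) {
        if (type) {
            std::cout << query(1, l, r).overallMax << '\n';
        } else {
            long long incr;
            std::cin >> incr;
            update(1, l, r, incr);
        }
    }
}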
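To cross-check the program above on small inputs, a naive reference that performs each operation in O(n) may help. This is a testing aid under the question's definitions (1-based positions, all elements starting at 1), not part of the answer's solution; the `NaiveArray` class name is hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Naive O(n)-per-operation reference for Increase and Give.
class NaiveArray {
public:
    explicit NaiveArray(int n) : a(n + 1, 1) {}  // 1-based, all ones; a[0] is unused

    // Increase(x, y, m): add m to a[x..y].
    void increase(int x, int y, long long m) {
        for (int i = x; i <= y; ++i) a[i] += m;
    }

    // Give(x, y): length of the longest non-decreasing contiguous run inside [x, y].
    int give(int x, int y) const {
        int best = 1, cur = 1;
        for (int i = x + 1; i <= y; ++i) {
            cur = (a[i - 1] <= a[i]) ? cur + 1 : 1;  // extend the run or restart it
            best = std::max(best, cur);
        }
        return best;
    }

private:
    std::vector<long long> a;
};
```

Feeding the same random operations to both programs and comparing the Give answers is a quick way to catch off-by-one mistakes in the segment tree.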

Upvotes: 3
