[Codeforces] Max GEQ Sum(Stack + Segment Tree)

2022. 6. 2. 17:30TIL💡/Algorithms

 

이번에는 D문제를 풀었다.

DP 식으로 풀었는데 역시나 시간 초과가 발생했다.

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;

int main() {
    int t;
    cin >> t;
    while(t--) {
        int n;
        bool answer = true;
        cin >> n;
        
        vector<vector<int>> sum(n, vector<int>(n, 0));
        vector<vector<int>> maximum(n, vector<int>(n, 0));
        for(int i = 0;i < n;i++) {
            cin >> sum[i][i];
            maximum[i][i] = sum[i][i];
        }
        
        for(int l = 1;l < n;l++) {
            if(answer == false) {
                break;
            }
            for(int i = 0;i + l < n;i++) {
                int j = i + l;
                sum[i][j] = sum[i][j - 1] + sum[j][j];
                maximum[i][j] = max(maximum[i][j - 1], maximum[j][j]);
                if(maximum[i][j] < sum[i][j]) {
                    answer = false;
                    break;
                }
            }
        }
        
        if(answer) {
            cout << "YES" << '\n';
        }
        else {
            cout << "NO" << '\n';
        }
    }
    return 0;
}

 

1번 시도. Dynamic Programming → 효율성 문제 발생

메모리 차지도 심하고, 시간복잡도도 높다.

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;

int main() {
    int t;
    cin >> t;
    while(t--) {
        int n;
        bool answer = true;
        cin >> n;
        
        vector<vector<int>> sum(n, vector<int>(n, 0));
        vector<vector<int>> maximum(n, vector<int>(n, 0));
        for(int i = 0;i < n;i++) {
            cin >> sum[i][i];
            maximum[i][i] = sum[i][i];
        }
        
        for(int l = 1;l < n;l++) {
            if(answer == false) {
                break;
            }
            for(int i = 0;i + l < n;i++) {
                int j = i + l;
                sum[i][j] = sum[i][j - 1] + sum[j][j];
                maximum[i][j] = max(maximum[i][j - 1], maximum[j][j]);
                if(maximum[i][j] < sum[i][j]) {
                    answer = false;
                    break;
                }
            }
        }
        
        if(answer) {
            cout << "YES" << '\n';
        }
        else {
            cout << "NO" << '\n';
        }
    }
    return 0;
}

2번 시도. Divide Conquer → 오답

메모이제이션을 제거하였기에 메모리 차지도 줄이고, 시간 복잡도를 줄였는데 금나큼 놓친 부분이 생겼다.

binary search로만 sum과 max를 체크하기 때문에 일부 i와 j의 조합을 놓친다..ㅠ

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
using namespace std;
bool answer = true;
vector<int> seq(20000);
pair<long long, int> divide(int i, int j) {
    if(i == j) {
        return make_pair(seq[i], seq[i]);
    }
    // cout << i << " " << j << endl;
    int mid = (i + j) / 2;
    auto p1 = divide(i, mid);
    auto p2 = divide(mid + 1, j);
    
    long long sum1 = p1.first;
    long long sum2 = p2.first;
    
    int max1 = p1.second;
    int max2 = p2.second;
    
    if(sum1 + sum2 > max(max1, max2)) {
        answer = false;
    }

    return make_pair(sum1 + sum2, max(max1, max2));
}
int main() {
    int t;
    cin >> t;
    while(t--) {
        int n;
        answer = true;
        seq.clear();

        cin >> n;
        for(int i = 0;i < n;i++) {
            cin >> seq[i];
        }
        
        divide(0, n - 1);
    
        if(answer) {
            cout << "YES" << '\n';
        }
        else {
            cout << "NO" << '\n';
        }
    }
    return 0;
}

 

3번 시도. 세그먼트 트리

수학적인 연산을 빼놓고 생각하니 답이 없다.

 

Hint 1.

If we have a list of subarrays where the element at index $i$ is the max, which subarrays should we check to be sufficient?

 

만약 i에 위치한 원소가 최댓값인 부분배열의 목록을 가질 때, 어떤 부분배열들을 봐야 탐색에 충분할까?

 

Hint 2.

Checking subarrays which end or start at index $i$ is sufficient, so we can optimize our solution with this observation as the basis.

 

i에서 시작하거나 끝나는 부분 배열을 확인하는 것이 충분할 때, 우리는 우리의 솔루션을 이 관찰을 기반으로 최적화할 수 있다.

 

$a_i$가 maximum인 부분배열의 합이 $a_i$를 넘는지(exceeds)가 궁금하다.

인덱스가 $i$인 element보다 이전의 더 큰 값, 이후의 더 큰 값을 찾는 행위를 포함한다.

이는 스택을 활용해 $O(n)$이라는 시간 복잡도를 가진다.

$x_i, y_i$라는 인덱스들을 설정하자.

매 인덱스를 연산한 이후, 우리는 $a_i$가 $[x_i + 1, i]$와 $[i, y_i - 1]$로 구성된 부분 배열의 최댓값이라는 것을 알게 된다.

 

$(j, k)$는 j 인덱스로 시작하고 k 인덱스로 끝나는 부분 배열의 합(sum)을 표현한다. 이 때 $j$는 $[x_i + 1, i]$에 속하고, $k$는 $[i, y_i - 1]$에 속한다.

만약 $(j, k) > a_i$이라면, 

$ (j, k)  = (j, i - 1) + (i, i) + (i + 1, k) > a_i $이므로 $(j, i - 1) + (i + 1, k) > 0$이다.

따라서 $(j, i - 1) or (i + 1, k)$ 부분 배열의 적어도 하나는 0보다 커야 한다는 의미이므로

부분 배열인 (j, i) , (i, k) 둘 중 하나는 최댓값 $a_i$보다 커야 한다. 따라서 $i$를 기준으로 시작하거나 끝나는 배열만 확인하면 된다.

 

따라서 $(x_i + 1, i), (x_i + 2, i), ..., (i - 1, i)와 (i, i + 1), (i, i + 2), ... (i, y_i - 1)$ 부분 배열을 확인해야 한다.

우리는 그 사이에서 $a_i$를 넘어서는지를 확인하면 된다. 이는 prefix sums과 suffix sums으로 간단하게 바뀐다.(reduce to)

 

$max(i, y_i - 1) - prefix[i - 1] > a_i$

 

$max(i, y_i - 1)$은 해당 범위 내에서 최대 prefix sum을 리턴한다. 이 쿼리는 세그먼트 트리를 활용해서 $O(log n)$으로 수행될 수 있다. 이 쿼리 중 하나라도 참이라면, NO를 출력하고, 그것이 아니라면 YES로 출력한다.

 

최댓값을 탐색하는 것이 $O(n)$, 세그먼트 트리에 $O(log n)$을 활용해 전체 시간 복잡도는 $O(n log n)$이다.

 

원본

Let's look at the problem from the perspective of each 𝑎𝑖ai. We want to check whether the sum of the subarrays, where 𝑎𝑖ai is the maximum element, exceeds 𝑎𝑖ai or not.
Firstly, we must find out in which subarrays is 𝑎𝑖ai the maximum. This involves finding the previous greater element index and the next greater element index of 𝑖i, which can be done for all indices in 𝑂(𝑛)O(n) using stacks. Take these indices as 𝑥𝑖xi, 𝑦𝑖yi. After computing this for every index, we'll know that 𝑎𝑖ai is max in subarrays with starting index [𝑥𝑖+1,𝑖][xi+1,i] and ending index [𝑖,𝑦𝑖−1][i,yi−1].
Take (𝑗,𝑘)(j,k), which represents the sum of a subarray which starts at index 𝑗j and ends at index 𝑘k, where 𝑗∈[𝑥𝑖+1,𝑖]j∈[xi+1,i], 𝑘∈[𝑖,𝑦𝑖−1]k∈[i,yi−1]. If (𝑗,𝑘)>𝑎𝑖(j,k)>ai, then (𝑗,𝑖−1)+(𝑖,𝑖)+(𝑖+1,𝑘)>𝑎𝑖(j,i−1)+(i,i)+(i+1,k)>ai, giving us (𝑗,𝑖−1)+(𝑖+1,𝑘)>0(j,i−1)+(i+1,k)>0. Hence, at least one of the subarrays, (𝑗,𝑖−1)(j,i−1) or (𝑖+1,𝑘)(i+1,k) has a sum greater than 00, which implies that one of subarrays (𝑗,𝑖)(j,i), (𝑖,𝑘)(i,k) has sum greater than 𝑎𝑖ai, so only checking subarrays which start or end at index 𝑖i suffices.
Therefore, for an index 𝑖i, we need to check subarrays (𝑥𝑖+1,𝑖),(𝑥𝑖+2,𝑖),…,(𝑖−1,𝑖)(xi+1,i),(xi+2,i),…,(i−1,i), and subarrays (𝑖,𝑖+1),(𝑖,𝑖+2),…,(𝑖,𝑦𝑖−1)(i,i+1),(i,i+2),…,(i,yi−1). Since we just care if any one of them exceed 𝑎𝑖ai, finding the max of them is enough. This reduces to making a range query over the prefix sums and one over the suffix sums. The query on prefix sums would look like

max(𝑖,𝑦𝑖−1)−prefix[𝑖−1]>𝑎𝑖max(i,yi−1)−prefix[i−1]>ai
Where max(𝑖,𝑦𝑖−1)max(i,yi−1) returns the max prefix sum in the given range. This query can be done using a segment tree in 𝑂(log𝑛)O(log⁡n). If any of the queries is true, then we just have to output "NO", else output "YES".
With this we get the time complexity of the solution as 𝑂(𝑛log𝑛)O(nlog⁡n).

 

__builtin_popcount(x) : This function is used to count the number of one's(set bits) in an integer

예를 들어 4를 파라미터로 전달하면, 4는 이진수로 100이므로 1의 비트 수인 1을 리턴한다.

 

위 내용을 조금 더 이해하기 쉽게...

 

어떤 부분 배열의 최댓값이 $i$에 위치한 $a_i$일 때, 이보다 이전 위치에서 해당 값보다 더 큰 값은 prevGreater로, 이보다 다음 위치에서 해당 값보다 더 큰 값은 nextGreater로 구한다.

이를 통해 알아보니 $a_i$보다 작은 값은 $p_i$에 위치하고, 큰 값은 $b_i$에 위치한다는 것을 파악하였다.

그러면 $a_i$에 위치한 원소가 최댓값이 되기 위해서는 부분 배열의 starting index j : [$p_i + 1, i$], ending indexk : [$i, b_i - 1$]

 

$x + y + a_i > a_i$ 이면 $x + y > 0$, 이를 위해서는 $x > 0 || y > 0$이라는 점을 활용해

$ (j, k)  = (j, i - 1) + (i, i) + (i + 1, k) > a_i $에서 $(j, i - 1) + (i + 1, k) > 0$ , 그리고 $(j, i - 1) > 0 || (i + 1, k) > 0$도출

 

여기서 구간합을 만들 때 prefix sum, suffix sum을 활용한다.

 

그리고 보통 세그먼트 트리를 구간합으로 사용하는데, 여기서는 일반적인 구간합이 아니라 내부의 구간합 중 최대 구간합을 가져온다.

예를 들어 prefixSum이 [-1, 0, 0, 1]이라면, prefixTree는 인덱스 1부터 인덱스 2 * n는 아래와 같은 트리를 나타낸다.

#include <vector>
#include <iostream>
#include <stack>
typedef long long ll;
using namespace std;
 
const ll ninf = -1e15;
 
vector<int> nextGreater(vector<ll>& arr, int n) {
    stack<int> s;
        vector<int> result(n, n);
    for (int i = 0; i < n; i++) {
        while (!s.empty() && arr[s.top()] < arr[i]) {
            result[s.top()] = i;
            s.pop();
        }
        s.push(i);
    }
        return result;
}
 
vector<int> prevGreater(vector<ll>& arr, int n) {
    stack<int> s;
        vector<int> result(n, -1);
    for (int i = n - 1; i >= 0; i--) {
        while (!s.empty() && arr[s.top()] < arr[i]) {
            result[s.top()] = i;
            s.pop();
        }
        s.push(i);
    }
        return result;
}
 
ll query(vector<ll> &tree, int node, int ns, int ne, int qs, int qe) {
    if (qe < ns || qs > ne) return ninf;
    if (qs <= ns && ne <= qe) return tree[node];
 
    int mid = ns + (ne - ns) / 2;
    ll leftQuery = query(tree, 2 * node, ns, mid, qs, qe);
    ll rightQuery = query(tree, 2 * node + 1, mid + 1, ne, qs, qe);
    return max(leftQuery, rightQuery);
}
 
int main() {
   int t;
   cin >> t;
   while (t--) {
        int n, _n;
        cin >> n;
        vector<ll> arr(n, 0);
        for (auto& a : arr)
            cin >> a;
        
        // Round off n to next power of 2
        _n = n;
        while (__builtin_popcount(_n) != 1) _n++;
 
 
        // Prefix sums
        vector<ll> prefixSum(n, 0), suffixSum(n, 0);
        prefixSum[0] = arr[0];
        for (int i = 1; i < n; i++) {
            prefixSum[i] = prefixSum[i - 1] + arr[i];
        }
        suffixSum[n - 1] = arr[n - 1];
        for (int i = n - 2; i >= 0; i--) {
            suffixSum[i] = suffixSum[i + 1] + arr[i];
        }
        
        // Two max-segtress, one on the prefix sums, one on the suffix sums
        vector<ll> prefixTree(2 * _n, ninf), suffixTree(2 * _n, ninf);
 
        for (int i = 0; i < n; i++) {
            prefixTree[_n + i] = prefixSum[i];
            suffixTree[_n + i] = suffixSum[i];
        }
 
        for (int i = _n - 1; i >= 1; i--) {
            prefixTree[i] = max(prefixTree[2 * i], prefixTree[2 * i + 1]);
            suffixTree[i] = max(suffixTree[2 * i], suffixTree[2 * i + 1]);
        }
        vector<int> ng = nextGreater(arr, n);
        vector<int> pg = prevGreater(arr, n);
        bool flag = true;
       
       cout << endl;
 
        for (int i = 0; i < n; i++) {
            ll rightMax = query(prefixTree, 1, 0, _n - 1, i + 1, ng[i] - 1) - prefixSum[i];
            ll leftMax = query(suffixTree, 1, 0, _n - 1, pg[i] + 1, i - 1) - suffixSum[i];
        
            
            if (max(leftMax, rightMax) > 0) {
                flag = false;
                break;
            }
        }
        if (flag)
            cout << "YES\n";
        else
            cout << "NO\n";
   }
}