[백준] 2042: 구간 합 구하기(Segment Tree)

2022. 6. 3. 19:22TIL💡/Algorithms

이전에 틀렸다가 이제서야 세그먼트 트리를 복습하고 성공한 문제이다.

그 때는 세그먼트 트리가 너무 어려운 개념이었는데 이제는 쉽게 이해된다.

 

세그먼트 트리는 저장된 자료를 트리 형식으로 저장하여 빠르게 구간합을 쿼리할 수 있도록 한다.

여기서 세그먼트 트리는 크게 3가지의 기능이 필요하다.

💚 세그먼트 트리 만들기

배열을 통해 트리를 구현한다.

인덱스 1부터 트리를 노드 간의 합으로 구성한다.

만약 구간 합이 아니라 구간의 최댓값, 최솟값 등 다양하게 응용 가능하다.

여기서 중요한 점은 트리 배열의 크기다. 단순히 리프 노드 뿐만 아니라 구간합을 표현하는 중간 노드들도 배열의 노드로 차지 하기 때문이다.

배열의 원소가 N이라면, N이 2의 제곱수라면 Full Binary Tree가 되므로 노드의 개수는 필요한 노드의 개수는 2 * N - 1이다.

하지만 2의 제곱수가 아니라면 조금 더 연산이 필요하다. 가장 깊은 깊이를 구하여서 필요한 노드 개수를 구하여야 한다.

세그먼트 트리의 깊이는 $\lceil log_2N \rceil$이다. 그래서 필요한 배열의 크기는 $2^{depth + 1} - 1$이다.

 

int depth = (int) ceil(log2(n));
int tree_size = (1 << (h + 1));
// arr: 초기 배열
// tree: 세그먼트 트리
// node: 세그먼트 트리 노드 번호
// node가 담당하는 합의 범위: [start, end]
long long init(vector<long long> &arr, vector<long long> &tree, long long node, long long start, long long end) {
    if(start == end) {
        return tree[node] = arr[start];
    }
    long long mid = (start + end) / 2;
    
    // 구간합 구하기
    return tree[node] = init(arr, tree, node * 2, start, mid) + init(arr, tree, node * 2 + 1, mid + 1, end);
    
}

💚 합 구하기

노드가 담당하는 구간이 [start, end]이고, 정확하게 합을 구하고 싶은 구간이 [left, right]이다.

// 노드가 담당하는 구간: [start, end]
// 합을 구하는 구간: [left, right]
long long sum(vector<long long> &tree, long long node, long long start, long long end, long long left, long long right) {
    if(left > end || right < start) return 0;
    if(left <= start && end <= right) return tree[node];
    
    long long mid = (start + end) / 2;
    return sum(tree, node * 2, start, mid, left, right) + sum(tree, node * 2 + 1, mid + 1, end, left, right);
}

 

💚 값 변경하기

값이 변경되면 해당 값이 포함된 구간의 노드를 모두 변경하여야 한다. 그런데 새롭게 계산하는 것이 아니라 이미 연산된 노드에 변경된 차이만 수정한다.

void update(vector<long long> &tree, long long node, long long start, long long end, long long index, long long diff) {
    if(index < start || index > end) return;
    tree[node] = tree[node] + diff;
    // 리프 노드가 아닌 경우 자식도 변경해야 한다.
    if(start != end) {
        long long mid = (start + end) / 2;
        update(tree, node * 2, start, mid, index, diff);
        update(tree, node * 2 + 1, mid + 1, end, index, diff);
    }
}

🕰 시간 복잡도

트리의 레벨에 비례하여 노드를 방문하기 때문에 $O(log N)$의 시간복잡도를 가진다.

 

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

전체 코드

수의 범위가 크기 때문에 합을 구할 때는 항상 자료형에 주의해야 한다.

#include <iostream>
#include <vector>
#include <cmath>
using namespace std;
// arr: 초기 배열
// tree: 세그먼트 트리
// node: 세그먼트 트리 노드 번호
// node가 담당하는 합의 범위: [start, end]
long long init(vector<long long> &arr, vector<long long> &tree, long long node, long long start, long long end) {
    if(start == end) {
        return tree[node] = arr[start];
    }
    long long mid = (start + end) / 2;
    
    // 구간합 구하기
    return tree[node] = init(arr, tree, node * 2, start, mid) + init(arr, tree, node * 2 + 1, mid + 1, end);
    
}
// 노드가 담당하는 구간: [start, end]
// 합을 구하는 구간: [left, right]
long long sum(vector<long long> &tree, long long node, long long start, long long end, long long left, long long right) {
    if(left > end || right < start) return 0;
    if(left <= start && end <= right) return tree[node];
    
    long long mid = (start + end) / 2;
    return sum(tree, node * 2, start, mid, left, right) + sum(tree, node * 2 + 1, mid + 1, end, left, right);
}

void update(vector<long long> &tree, long long node, long long start, long long end, long long index, long long diff) {
    if(index < start || index > end) return;
    tree[node] = tree[node] + diff;
    // 리프 노드가 아닌 경우 자식도 변경해야 한다.
    if(start != end) {
        long long mid = (start + end) / 2;
        update(tree, node * 2, start, mid, index, diff);
        update(tree, node * 2 + 1, mid + 1, end, index, diff);
    }
}
int main() {
    long long n, m, k;
    
    cin >> n >> m >> k;
    vector<long long> arr(n);
    for(long long i = 0;i < n;i++) {
        cin >> arr[i];
    }
    long long h = (long long)ceil(log2(n));
    vector<long long> tree(1 << (h + 1));
    init(arr, tree, 1, 0, n - 1);
    
    for(long long i = 0;i < m + k;i++) {
        long long a, b, c;
        cin >> a >> b >> c;
        if(a == 1) {
            long long diff = c - arr[b - 1];
            arr[b - 1] = c;
            update(tree, 1, 0, n - 1, b - 1, diff);
        }
        else {
            cout << sum(tree, 1, 0, n - 1, b - 1, c - 1) << '\n';
        }
    
    }
}

참고

https://eun-jeong.tistory.com/18

 

[자료구조] 세그먼트 트리 (Segment Tree) C++ 구현

Segment Tree 배열 A[1], ..., A[N]이 있을 때, 아래 문제를 생각해보자. [문제 1] 순서쌍 (i, j)에 대하여 A[i], ... ,A[j] 중 최솟값을 찾는 경우를 생각해보자. A[i]부터 A[j]까지 순회하면서 찾는 것이 가장..

eun-jeong.tistory.com

 

'TIL💡 > Algorithms' 카테고리의 다른 글

[백준] 9465: 스티커  (0) 2022.06.04
[백준] 1275: 커피숍2  (0) 2022.06.04
[Codeforces] Number of Groups(UnionFind, Greedy)  (0) 2022.06.03
[Codeforces] Max GEQ Sum(Stack + Segment Tree)  (0) 2022.06.02
[Codeforces] Sum of Substrings  (0) 2022.06.01