david의 CS Blog 자세히보기

Algorithm

세그먼트 트리(Segment Tree)

david0506 2021. 2. 11. 10:38

세그먼트 트리란?

 

 

구간의 정보를 포함하고 있는 자료구조로 보통 완전 이진 트리의 모양이다.

 

 

예를 들어 다음과 같은 문제가 있다.

 

 

어떤 $N$개의 수가 주어져 있고 다음과 같은 두 가지 연산을 실행한다.

$1.$ $a$번째 수의 값에 $b$만큼 더한다.

$2.$ $a$번째 수부터 $b$번째 수까지의 합을 구한다.

 

 

연산의 실행 횟수가 최대 $100000$개라고 하자.

 

 

 

이 때 이 두 연산을 구현해보자

 

for(int i=1; i<=q; i++){
    int num;
    cin >> num;
    if(num==1){
        int a, b; cin >> a >> b;
        arr[a]+=b;
    }
    if(num==2){
        int a, b; cin >> a >> b;
        int sum=0;
        for(int j=a; j<=b; j++){
            sum+=arr[j];
        }
        cout << sum << "\n";
    }
}

 

 

1번 연산은 시간복잡도가 $O(1)$, 2번 연산은 시간복잡도가 $O(N)$이 나온다.

 

전체 시간복잡도는 $O(NQ)$($Q$: 연산의 실행 횟수)이므로 시간초과 날 것이다.

 

 

 

이제 세그먼트 트리를 이용해 시간복잡도를 $O(QlogN)$으로 만들어 볼 것이다.

 

예를 들어 데이터가 $[7, -3, 8, -2]$로 주어져 있을 때 세그먼트 트리를 구축하면 다음과 같다.

 

 

이 때 세그먼트 트리는 구간합을 정보로 저장하고 있는데 상황에 따라 구간의 최댓값, 최솟값 등을 저장할 수도 있다.

 

노드의 범위는 아래의 두 자손 노드의 범위를 합친 것이고 저장하고 있는 값은 두 자손 노드의 값의 합이다.

 

 

구간합 구하기

 

$1$번째부터 $4$번째까지의 값의 합을 구하고 싶으면 루트 노드에 있는 $10$이 답이 되고

 

$1$번째부터 $3$번째까지의 값의 합을 알고 싶다면

 

 

주황색 부분만 더해서 답이 $12$가 됨을 알 수 있다.

 

구간합을 구하는 코드는 다음과 같다.

 

int sum(int node, int s, int e, int l, int r)
{
    if(l>e || r<s){ return 0; } //범위를 벗어났다면
    if(l<=s && e<=r){ return tree[node]; } //점위 안에 들어왔다면
    int mid=(s+e)/2;
    return sum(node*2, s, mid, l, r)+sum(node*2+1, mid+1, e, l, r);
}

 

 

$1~3$까지의 합을 구하는 것을 예로 들어 이 코드를 설명하자면

 

$sum(1, 4)$에서 $sum(1, 2), sum(3, 4)$로 나누어지고

 

$sum(1, 2)$는 구하려는 범위 안에 속하므로 값을 리턴한다.

 

$sum(3, 4)$는 $sum(3, 3), sum(4, 4)$로 나누어지고

 

$sum(3, 3)$은 구하려는 범위 안에 속하므로 값을 리턴하고

 

$sum(4, 4)$는 범위 바깥에 있으므로 0을 리턴한다.

 

 

값 업데이트 하기

 

$3$번째 값에 $5$를 더하려고 하면

 

 

다음과 같이 주황색 노드들의 값에 $5$씩 더해질 것이다.

 

코드를 구현해보면

 

void update(int node, int s, int e, int idx, int val)
{
    if(idx<s || idx>e){ return; } //범위 바깥이라면
    tree[node]+=val;
    if(s==e){ return; } //리프 노드라면 더 이상 나눌 수 없다
    int mid=(s+e)/2;
    update(node*2, s, mid, idx, val);
    update(node*2+1, mid+1, e, idx, val);
    tree[node]=tree[node*2]+tree[node*2+1];
}

 

다음과 같이 되는데 이 역시 예시를 통해 이해해보자.

 

$2$번째 값에 $6$을 더한다고 했을 때

 

$update(1, 4)$에서 2는 1~4에 속하므로 1번째 노드 값에 6을 더해주고

 

$update(1, 4)$는 $update(1, 2), update(3, 4)$로 나누어진다.

 

$update(3, 4)$에서 2는 3~4에 속하지 않으므로 탐색할 필요가 없다.

 

$update(1, 2)$에서 2는 1~2에 속하므로 2번째 노드 값에 6을 더해주고

 

$update(1, 2)$는 $update(1, 1), update(2, 2)$로 나누어진다.

 

$update(1, 1)$도 범위에 속하지 않으므로 탐색할 필요가 없고

 

$update(2, 2)$는 범위에 속하므로 노드 값에 6을 더해준다.

 

전체 코드

 

#include <iostream>
#include <bits/stdc++.h>

#define MAX 1000000

using namespace std;

int n, q;
int tree[MAX*4];
int v[MAX+1];

int init(int node, int s, int e)
{
    if(s==e){ return tree[node]=v[s]; } //리프 노드라면
    int mid=(s+e)/2;
    return tree[node]=init(node*2, s, mid)+init(node*2+1, mid+1, e);
}
int sum(int node, int s, int e, int l, int r)
{
    if(l>e || r<s){ return 0; }
    if(l<=s && e<=r){ return tree[node]; }
    int mid=(s+e)/2;
    return sum(node*2, s, mid, l, r)+sum(node*2+1, mid+1, e, l, r);
}

void update(int node, int s, int e, int idx, int val)
{
    if(idx<s || idx>e){ return; }
    tree[node]+=val;
    if(s==e){ return; }
    int mid=(s+e)/2;
    update(node*2, s, mid, idx, val);
    update(node*2+1, mid+1, e, idx, val);
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin >> n >> q;
    for(int i=1; i<=n; i++){
        cin >> v[i];
    }
    init(1, 1, n);
    for(int i=1; i<=q; i++){
        int qry; cin >> qry;
        if(qry==1){
            int idx; long long val;
            cin >> idx >> val;
            update(1, 1, n, idx, val);
            v[idx]=val;
        }
        if(qry==2){
            int l, r;
            cin >> l >> r;
            cout << sum(1, 1, n, l, r) << "\n";
        }
    }
    return 0;
}

 

 

아까 설명에 없었던 init 함수가 있는데 init 함수는 세그먼트 트리를 구축하는 함수이다.

 

시간복잡도를 보면 update 함수와 sum 함수 모두 $O(logN)$이고, init 함수는 $O(N)$이다.

 

init 함수가 느리다고 생각할 수 있지만 입력을 받으면서 리프 노드를 update하면 $O(NlogN)$이므로 이것보다는 빠르다.

 

최종적으로 전체 시간복잡도는 $O(QlogN)$이다.

 

$init(1, 1, N)$은 $1$번 노드부터 $1$~$N$까지 구축하는 것이다.

$update(1, 1, N, idx, val)$은 $1$번 노드부터, $1~N$까지 범위 중 $idx$위치에 $val$만큼 더하는 것이다.

$sum(1, 1, N, l, r)$은 $1$번 노드부터 $1~N$까지 범위 중 $l~r$ 사이의 합을 구하는 것이다.

 

응용

 

$LIS(Longest Increasing Sequence)$를 구할 때도 세그먼트 트리를 사용할 수 있다.

 

이는 DP를 사용해 $O(N^{2})$으로 해결 가능한데

 

세그먼트 트리를 사용해서 $O(NlogN)$으로 답을 구할 수 있다.(이분 탐색으로 $O(NlogN)$으로 푸는 방법도 있다.)

 

구간합이 아닌 최댓값 세그먼트 트리를 이용할 것이다.

 

$1.$ 배열 $v$에 $x$라는 값들을 입력 받아서 각각 (값, 위치)의 구조체로 만들고 값 오름차순, 인덱스 내림차순으로 정렬한다.

 

$2.$ 작은 $x$ 값부터 반복문으로 돌면서

 

$x$로 끝나는 LIS 길이 = ($v[i]=x$인 $i$에 대해 구간 $[1, i]$에 존재하는 LIS 길이의 최댓값)+1

 

$3.$ $x$로 끝나는 LIS 길이를 세그먼트 트리의 $i$번 인덱스 값으로 업데이트한다.

 

$4.$ 최댓값을 구한다.

 

 

이때 정렬을 값을 오름차순으로 정렬했는데 값이 같다면 인덱스는 내림차순으로 정렬해야한다.

 

그 이유는 LIS는 최장 증가 부분 수열이기 때문에 증가하는 수열이여야한다.

 

값이 같은 것을 포함하면 안되기 때문이다.

 

 

#include <iostream>
#include <algorithm>
#include <vector>

using namespace std;

int n;
struct info{
    int val, idx;
};

bool cmp(const info &a, const info &b)
{
    return a.val==b.val ? a.idx>b.idx : a.val<b.val;
}

vector<info> v;
long long tree[4000001];

inline void update(int node, int s, int e, int idx, long long val)
{
    if(idx<s || idx>e){ return; }
    if(s==e){ tree[node]=max(val, tree[node]); return; }
    int mid=(s+e)/2;
    update(node*2, s, mid, idx, val); update(node*2+1, mid+1, e, idx, val);
    tree[node]=max(tree[node*2], tree[node*2+1]);
}

inline long long Max(int node, int s, int e, int l, int r)
{
    if(e<l || r<s){ return 0; }
    if(l<=s && e<=r){ return tree[node]; }
    int mid=(s+e)/2;
    return max(Max(node*2, s, mid, l, r), Max(node*2+1, mid+1, e, l, r));
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin >> n;
    for(int i=1; i<=n; i++){
        int val; cin >> val;
        v.push_back({val, i});
    }
    sort(v.begin(), v.end(), cmp);
    for(int i=0; i<n; i++){
        update(1, 1, n, v[i].idx, Max(1, 1, n, 1, v[i].idx)+1);
    }
    cout << Max(1, 1, n, 1, n) << "\n";
    return 0;
}

 

이를 통해 BOJ 12015 가장 긴 증가하는 부분 수열 2, BOJ 12738 가장 긴 증가하는 부분 수열 3을 풀 수 있다.

 

그 외에도 평면 스위핑이나 교차점 세기 등 세그먼트 트리는 정말 다양한 곳에 쓰인다.

 

'Algorithm' 카테고리의 다른 글

동적계획법(경로 역추적)  (0) 2021.02.11
Algorithm이란?  (0) 2021.02.11
그리디 알고리즘(greedy)  (0) 2021.02.11
분할 정복(Divide and Conquer)  (0) 2021.02.11
[ Graph ] 너비 우선 탐색 (BFS, Breadth-First-Search)  (1) 2021.02.11