세그먼트 트리란?
구간의 정보를 포함하고 있는 자료구조로 보통 완전 이진 트리의 모양이다.
예를 들어 다음과 같은 문제가 있다.
어떤 $N$개의 수가 주어져 있고 다음과 같은 두 가지 연산을 실행한다.
$1.$ $a$번째 수의 값에 $b$만큼 더한다.
$2.$ $a$번째 수부터 $b$번째 수까지의 합을 구한다.
연산의 실행 횟수가 최대 $100000$개라고 하자.
이 때 이 두 연산을 구현해보자
for(int i=1; i<=q; i++){
int num;
cin >> num;
if(num==1){
int a, b; cin >> a >> b;
arr[a]+=b;
}
if(num==2){
int a, b; cin >> a >> b;
int sum=0;
for(int j=a; j<=b; j++){
sum+=arr[j];
}
cout << sum << "\n";
}
}
1번 연산은 시간복잡도가 $O(1)$, 2번 연산은 시간복잡도가 $O(N)$이 나온다.
전체 시간복잡도는 $O(NQ)$($Q$: 연산의 실행 횟수)이므로 시간초과가 날 것이다.
이제 세그먼트 트리를 이용해 시간복잡도를 $O(QlogN)$으로 만들어 볼 것이다.
예를 들어 데이터가 $[7, -3, 8, -2]$로 주어져 있을 때 세그먼트 트리를 구축하면 다음과 같다.
이 때 세그먼트 트리는 구간합을 정보로 저장하고 있는데 상황에 따라 구간의 최댓값, 최솟값 등을 저장할 수도 있다.
노드의 범위는 아래의 두 자손 노드의 범위를 합친 것이고 저장하고 있는 값은 두 자손 노드의 값의 합이다.
구간합 구하기
$1$번째부터 $4$번째까지의 값의 합을 구하고 싶으면 루트 노드에 있는 $10$이 답이 되고
$1$번째부터 $3$번째까지의 값의 합을 알고 싶다면
주황색 부분만 더해서 답이 $12$가 됨을 알 수 있다.
구간합을 구하는 코드는 다음과 같다.
int sum(int node, int s, int e, int l, int r)
{
if(l>e || r<s){ return 0; } //범위를 벗어났다면
if(l<=s && e<=r){ return tree[node]; } //점위 안에 들어왔다면
int mid=(s+e)/2;
return sum(node*2, s, mid, l, r)+sum(node*2+1, mid+1, e, l, r);
}
$1~3$까지의 합을 구하는 것을 예로 들어 이 코드를 설명하자면
$sum(1, 4)$에서 $sum(1, 2), sum(3, 4)$로 나누어지고
$sum(1, 2)$는 구하려는 범위 안에 속하므로 값을 리턴한다.
$sum(3, 4)$는 $sum(3, 3), sum(4, 4)$로 나누어지고
$sum(3, 3)$은 구하려는 범위 안에 속하므로 값을 리턴하고
$sum(4, 4)$는 범위 바깥에 있으므로 0을 리턴한다.
값 업데이트 하기
$3$번째 값에 $5$를 더하려고 하면
다음과 같이 주황색 노드들의 값에 $5$씩 더해질 것이다.
코드를 구현해보면
void update(int node, int s, int e, int idx, int val)
{
if(idx<s || idx>e){ return; } //범위 바깥이라면
tree[node]+=val;
if(s==e){ return; } //리프 노드라면 더 이상 나눌 수 없다
int mid=(s+e)/2;
update(node*2, s, mid, idx, val);
update(node*2+1, mid+1, e, idx, val);
tree[node]=tree[node*2]+tree[node*2+1];
}
다음과 같이 되는데 이 역시 예시를 통해 이해해보자.
$2$번째 값에 $6$을 더한다고 했을 때
$update(1, 4)$에서 2는 1~4에 속하므로 1번째 노드 값에 6을 더해주고
$update(1, 4)$는 $update(1, 2), update(3, 4)$로 나누어진다.
$update(3, 4)$에서 2는 3~4에 속하지 않으므로 탐색할 필요가 없다.
$update(1, 2)$에서 2는 1~2에 속하므로 2번째 노드 값에 6을 더해주고
$update(1, 2)$는 $update(1, 1), update(2, 2)$로 나누어진다.
$update(1, 1)$도 범위에 속하지 않으므로 탐색할 필요가 없고
$update(2, 2)$는 범위에 속하므로 노드 값에 6을 더해준다.
전체 코드
#include <iostream>
#include <bits/stdc++.h>
#define MAX 1000000
using namespace std;
int n, q;
int tree[MAX*4];
int v[MAX+1];
int init(int node, int s, int e)
{
if(s==e){ return tree[node]=v[s]; } //리프 노드라면
int mid=(s+e)/2;
return tree[node]=init(node*2, s, mid)+init(node*2+1, mid+1, e);
}
int sum(int node, int s, int e, int l, int r)
{
if(l>e || r<s){ return 0; }
if(l<=s && e<=r){ return tree[node]; }
int mid=(s+e)/2;
return sum(node*2, s, mid, l, r)+sum(node*2+1, mid+1, e, l, r);
}
void update(int node, int s, int e, int idx, int val)
{
if(idx<s || idx>e){ return; }
tree[node]+=val;
if(s==e){ return; }
int mid=(s+e)/2;
update(node*2, s, mid, idx, val);
update(node*2+1, mid+1, e, idx, val);
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n >> q;
for(int i=1; i<=n; i++){
cin >> v[i];
}
init(1, 1, n);
for(int i=1; i<=q; i++){
int qry; cin >> qry;
if(qry==1){
int idx; long long val;
cin >> idx >> val;
update(1, 1, n, idx, val);
v[idx]=val;
}
if(qry==2){
int l, r;
cin >> l >> r;
cout << sum(1, 1, n, l, r) << "\n";
}
}
return 0;
}
아까 설명에 없었던 init 함수가 있는데 init 함수는 세그먼트 트리를 구축하는 함수이다.
시간복잡도를 보면 update 함수와 sum 함수 모두 $O(logN)$이고, init 함수는 $O(N)$이다.
init 함수가 느리다고 생각할 수 있지만 입력을 받으면서 리프 노드를 update하면 $O(NlogN)$이므로 이것보다는 빠르다.
최종적으로 전체 시간복잡도는 $O(QlogN)$이다.
$init(1, 1, N)$은 $1$번 노드부터 $1$~$N$까지 구축하는 것이다.
$update(1, 1, N, idx, val)$은 $1$번 노드부터, $1~N$까지 범위 중 $idx$위치에 $val$만큼 더하는 것이다.
$sum(1, 1, N, l, r)$은 $1$번 노드부터 $1~N$까지 범위 중 $l~r$ 사이의 합을 구하는 것이다.
응용
$LIS(Longest Increasing Sequence)$를 구할 때도 세그먼트 트리를 사용할 수 있다.
이는 DP를 사용해 $O(N^{2})$으로 해결 가능한데
세그먼트 트리를 사용해서 $O(NlogN)$으로 답을 구할 수 있다.(이분 탐색으로 $O(NlogN)$으로 푸는 방법도 있다.)
구간합이 아닌 최댓값 세그먼트 트리를 이용할 것이다.
$1.$ 배열 $v$에 $x$라는 값들을 입력 받아서 각각 (값, 위치)의 구조체로 만들고 값 오름차순, 인덱스 내림차순으로 정렬한다.
$2.$ 작은 $x$ 값부터 반복문으로 돌면서
$x$로 끝나는 LIS 길이 = ($v[i]=x$인 $i$에 대해 구간 $[1, i]$에 존재하는 LIS 길이의 최댓값)+1
$3.$ $x$로 끝나는 LIS 길이를 세그먼트 트리의 $i$번 인덱스 값으로 업데이트한다.
$4.$ 최댓값을 구한다.
이때 정렬을 값을 오름차순으로 정렬했는데 값이 같다면 인덱스는 내림차순으로 정렬해야한다.
그 이유는 LIS는 최장 증가 부분 수열이기 때문에 증가하는 수열이여야한다.
값이 같은 것을 포함하면 안되기 때문이다.
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
int n;
struct info{
int val, idx;
};
bool cmp(const info &a, const info &b)
{
return a.val==b.val ? a.idx>b.idx : a.val<b.val;
}
vector<info> v;
long long tree[4000001];
inline void update(int node, int s, int e, int idx, long long val)
{
if(idx<s || idx>e){ return; }
if(s==e){ tree[node]=max(val, tree[node]); return; }
int mid=(s+e)/2;
update(node*2, s, mid, idx, val); update(node*2+1, mid+1, e, idx, val);
tree[node]=max(tree[node*2], tree[node*2+1]);
}
inline long long Max(int node, int s, int e, int l, int r)
{
if(e<l || r<s){ return 0; }
if(l<=s && e<=r){ return tree[node]; }
int mid=(s+e)/2;
return max(Max(node*2, s, mid, l, r), Max(node*2+1, mid+1, e, l, r));
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0); cout.tie(0);
cin >> n;
for(int i=1; i<=n; i++){
int val; cin >> val;
v.push_back({val, i});
}
sort(v.begin(), v.end(), cmp);
for(int i=0; i<n; i++){
update(1, 1, n, v[i].idx, Max(1, 1, n, 1, v[i].idx)+1);
}
cout << Max(1, 1, n, 1, n) << "\n";
return 0;
}
이를 통해 BOJ 12015 가장 긴 증가하는 부분 수열 2, BOJ 12738 가장 긴 증가하는 부분 수열 3을 풀 수 있다.
그 외에도 평면 스위핑이나 교차점 세기 등 세그먼트 트리는 정말 다양한 곳에 쓰인다.
'Algorithm' 카테고리의 다른 글
동적계획법(경로 역추적) (0) | 2021.02.11 |
---|---|
Algorithm이란? (0) | 2021.02.11 |
그리디 알고리즘(greedy) (0) | 2021.02.11 |
분할 정복(Divide and Conquer) (0) | 2021.02.11 |
[ Graph ] 너비 우선 탐색 (BFS, Breadth-First-Search) (1) | 2021.02.11 |