david의 CS Blog 자세히보기

Algorithm

[ Query ] Sqrt Decomposition (제곱근 분할법)

david0506 2021. 2. 25. 11:45

Sqrt Decomposition

 

 

구간 쿼리를 Segment Tree를 이용해서 처리하면 시간복잡도가 $O(logN)$이다. 이는 트리의 깊이에 비례하는데, 각 노드의 자식 노드의 수를 밑으로 가지는 로그의 시간복잡도를 가지는 것이다. 예를 들어 세그먼트 트리는 자식 노드가 2개이므로 엄밀한 시간복잡도 식은 $f(x) = log_{2}(x)$이다. 그렇다면 다음과 같이 생각할 수 있다.

 

 

 

"만약 세그먼트 트리의 자식 노드의 개수가 3개라면 시간복잡도는 어떻게 될까?

 

 

 

 

그러면 트리의 높이가 log_3(N)이므로 시간복잡도 식도 이와 같을 것이다. 그러면 공간복잡도를 배제하고 생각해보자. 시간복잡도를 최적화하려면 트리에서 자식 노드의 개수가 많을수록 유리할 것이다.

하지만 우리는 마냥 공간복잡도를 배제할 수는 없기 때문에 공간복잡도도 생각해보자. 자식 노드가 M개라고 하면 노드의 개수는 $1 + M + M^{2} + M^{3} + ...... + M^{N}$ = $\frac{M^{N}-1}{M-1}$이므로 적당히 √N개의 노드를 가지다면 메모리와 시간 모두 적절하지 않을까?하는 생각을 가질 수 있다. 하지만 그래도 메모리가 부담되기 때문에 구간을 √N개로 분할한다는 성질만을 이용해 다음 알고리즘을 생각해볼 수 있다.

 

 

 

 

 

제곱근 분할법은 특정 구간 쿼리를 $O(√N)$에 처리하는 알고리즘으로, 응용되어서 Mo's Algorithm의 기반이 되는 알고리즘이다. 가장 큰 특징은 구간을 $√N$개로 분할하여 관리한다는 것이다.

 

 

 

이를 이용하여 구간의 최댓값, 최솟값, 구간합 등을 구할 수 있는데 여기서는 구간합 구하기 코드를 구현해볼 것이다.

 

구간합을 구하는 쿼리에서 9개의 원소를 분할했다.

 

 

 

 

  • 초기화 작업

 

먼저 $√N$을 구해야하고 초기 상태의 √N개의 원소들로 이루어진 그룹들을 만들어야한다.

 

 

 

1
2
3
4
5
6
7
8
9
10
11
int n, m, k, sqrtN;
long long arr[1000001];
long long bucket[4000001]; //구간합을 저장해놓을 곳
 
void init()
{
    sqrtN=sqrt(n);
    for(int i=1; i<=n; i++){
        bucket[i/sqrtN]+=arr[i];
    }
}
cs

 

 

 

 

값 업데이트 하기

 

위의 그림에서 9번째 원소의 값을 6으로 바꾼다고 하면

 

 

 

다음과 같이 3번째 그룹만 업데이트하면 된다. 이를 구현하면 다음과 같다.

 

 

 

1
2
3
4
5
6
7
8
void update(int x, long long val)
{
    arr[x]=val;
    int idx=x/sqrtN;
    int s=idx*sqrtN, e=s+sqrtN;
    bucket[idx]=0;
    for(int i=s; i<e; i++){ bucket[idx]+=arr[i]; }
}
cs

 

 

 

먼저 $arr[x]$의 값을 바꿔주고 그 다음에 구간합을 다시 구해준다. 다른 방법으로는, 원소 하나만 바뀌었으므로 다음과 같이 해도 된다. 이 경우에 시간복잡도가 $O(1)$이 된다.

 

 

 

1
2
3
4
5
6
7
void update(int x, long long val)
{
    int idx=x/sqrtN;
    bucket[idx]-=arr[x];
    bucket[idx]+=val;
    arr[x]=val;
}
cs

 

 

 

구간합 구하기

 

 

구간 [$3$ ~ $8$]의 구간합을 구한다고 하면

 

 

 

구간 [$4$ ~ $6$]은 그냥 2번째 그룹으로 대체할 수 있다.

 

 

 

즉, 직접 구해야하는 것은 아래 그림과 같이 왼쪽과 오른쪽에 남아있는 원소들이다.

 

 

 

 

 

 

이는 그냥 반복문을 통해 구할 수 있다.

 

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
long long sum(int l, int r)
{
    long long ret=0;
    
    //1. 왼쪽과 오른쪽에 남은 원소들 먼저 더해준다.
    while(l%sqrtN!=0 && l<=r){ ret+=arr[l]; l++; }
    while((r+1)%sqrtN!=0 && l<=r){ ret+=arr[r]; r--; }
    
    //2. 구간 내에 있는 그룹들의 합을 구한다.
    while(l<=r){
        ret+=bucket[l/sqrtN];
        l+=sqrtN;
    }
    return ret;
}
cs

 

 

 

전체 코드

 

 

 

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <iostream>
#include <algorithm>
#include <vector>
#include <cmath>
 
using namespace std;
 
int n, m, k, sqrtN;
long long arr[1000001];
long long bucket[4000001];
 
void init()
{
    sqrtN=sqrt(n);
    for(int i=1; i<=n; i++){
        bucket[i/sqrtN]+=arr[i];
    }
}
 
void update(int x, long long val)
{
    arr[x]=val;
    int idx=x/sqrtN;
    int s=idx*sqrtN, e=s+sqrtN;
    bucket[idx]=0;
    for(int i=s; i<e; i++){ bucket[idx]+=arr[i]; }
}
 
long long sum(int l, int r)
{
    long long ret=0;
    while(l%sqrtN!=0 && l<=r){ ret+=arr[l]; l++; }
    while((r+1)%sqrtN!=0 && l<=r){ ret+=arr[r]; r--; }
    while(l<=r){
        ret+=bucket[l/sqrtN];
        l+=sqrtN;
    }
    return ret;
}
 
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin >> n >> m >> k;
    for(int i=1; i<=n; i++){ cin >> arr[i]; }
    init();
    for(int i=1; i<=m+k; i++){
        int qry; cin >> qry;
        if(qry==1){
            int idx; long long val;
            cin >> idx >> val;
            update(idx, val);
        }
        else{
            int l, r;
            cin >> l >> r;
            cout << sum(l, r) << "\n";
        }
    }
    return 0;
}
 
cs