david의 CS Blog 자세히보기

Algorithm

최소 공통 조상(LCA) 알고리즘

david0506 2021. 2. 11. 11:01

LCA 알고리즘이란?

 

트리에서 두 정점 $u$, $v$의 가장 가까운 조상을 찾는 알고리즘이다. 즉, 두 정점 $u$, $v$가 만날 때까지 루트 쪽으로 정점을 올려주면 두 정점이 만났을 때의 정점이 최소 공통 조상이 된다.

 

 

 

 

 

 

이 트리에서 $5$번 정점과 $7$번 정점의 최소 공통 조상은 $2$번이고, $4$번 정점과 $6$번 정점의 최소 공통 조상은 1번 정점이다. 최소 공통 조상을 구한다면 트리에서 두 정점 간의 최단 거리 등을 쉽게 구할 수 있고, 그 외에도 다양하게 사용할 수 있다. 이를 구하기 위해 DP를 이용해서 $O(logN)$으로 해결할 것이다.

 

 

구현

 

 

1. dfs를 통해 트리의 각 정점의 깊이와 정점들의 조상을 구한다.

 

 

이때 조상은 $2^{0}$, $2^{1}$, $2^{2}$, ...번째 조상을 구한다. 위 트리에서 $8$번 노드의 $2^{0}$번째 조상은 $6$, $2^{1}$번째 조상은 3이라고 할 수 있는데 2의 거듭제곱만큼 위에 있는 조상 정점들을 구하는 이유는 2의 거듭제곱 수들의 합으로 모든 양의 정수를 표현할 수 있기 때문이다. 예를 들어 어떤 정점 $u$보다 3 위에 있는 조상 정점을 찾고 싶다면 $u$의 $2^{0}$번째 조상의 $2^{1}$번째 조상을 찾으면 된다.

 

각 정점들의 조상들을 구하기 위해서 동적 계획법을 사용할 것이다.

 

$ac[i][j]$를 i번 노드의 $2^j$번째 조상이라고 정의하면 $ac[i][j]$ = $ac[ ac[i][j-1] ][ j-1 ]$가 성립한다. 이는 어떤 정점 $i$의 $2^{j}$번째 조상은 정점 $i$의 $2^{j-1}$번째 조상의 $2^{j-1}$번째 조상이라는 뜻이다.

 

int n, m;
vector<int> adj[50001];
int ac[50001][20], dep[50001];
//ac[i][j]: i번 노드의 2^j번째 조상
//따라서 ac[i][0]은 i번 노드의 부모 노드

void dfs(int x, int d)
{
    dep[x]=d;
    for(int i=1; i<20; i++){
        ac[x][i]=ac[ac[x][i-1]][i-1];
    }
    for(int nxt : adj[x]){
        if(ac[x][0]!=nxt){
            ac[nxt][0]=x;
            dfs(nxt, d+1);
        }
    }
}

 

 

 

2. 두 정점 중 깊이가 더 깊은 정점을 깊이가 더 낮은 정점까지 끌어올려 준다.

 

 

위 트리에서 $1$번과 $7$번의 LCA를 구하기 위해 계산하면 $7$번 정점을 $2^{1}$번째 조상인 $2$번 정점으로 올린다.

$1$번보다 높아지지 않으면서 올릴 수 있는 한 최대한 올려야 한다. 위에서처럼 ac 배열을 이용해서 2의 거듭제곱수 만큼 위로 올려줄 것이다. 그러면 두 정점의 깊이가 같아졌다.

 

//정점 a를 깊이가 더 큰 정점으로 한다
if(dep[a]<dep[b]){ swap(a, b); }
//두 정점의 깊이 차가 2^i보다 작아지면 a를 2^i만큼 올려준다
for(int i=19; i>=0; i--){
    if(dep[a]-dep[b]>=(1<<i)){ a=ac[a][i]; }
}

 

 

 

3. 두 정점을 같이 올려주면서 같아질때 까지 이를 반복한다.

 

두 정점의 부모 정점이 같아질 때까지 두 정점을 2의 거듭제곱수 만큼 올려준다. 그러면 시간복잡도 $O(logN)$으로 해결할 수 있다.

 

int lca(int a, int b)
{
    if(dep[a]<dep[b]){ swap(a, b); }
    for(int i=19; i>=0; i--){
        if(dep[a]-dep[b]>=(1<<i)){ a=ac[a][i]; }
    }
    
    //이미 두 정점이 같아졌다면 종료
    if(a==b){ return a; }
    
    for(int i=19; i>=0; i--){
        if(ac[a][i]!=ac[b][i]){ a=ac[a][i]; b=ac[b][i]; }
    }
    return ac[a][0];
}

 

 

 

연습 문제

 

1. BOJ 11438 LCA2

더보기

LCA를 구현하는 문제이다.

 

#include <iostream>
#include <bits/stdc++.h>

using namespace std;

int N;
vector<int> adj[100001];

int ac[100001][20];
int dep[100001];

void dfs(int x)
{
    for(int i=1; i<20; i++){
        ac[x][i]=ac[ac[x][i-1]][i-1];
    }
    for(int nxt : adj[x]){
        if(nxt==ac[x][0]){ continue; }
        ac[nxt][0]=x;
        dep[nxt]=dep[x]+1;
        dfs(nxt);
    }
}
 int LCA(int a, int b)
 {
     if(dep[a]<dep[b]){ swap(a, b); }
     for(int i=19; i>=0; i--){
        if(dep[a]-dep[b]>=(1<<i)){
            a=ac[a][i];
        }
     }
     if(a==b){ return a; }
     for(int i=19; i>=0; i--){
        if(ac[a][i]!=ac[b][i]){
            a=ac[a][i]; b=ac[b][i];
        }
     }
     return ac[a][0];
 }

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin >> N;
    for(int i=1; i<=N-1; i++){
        int a, b; cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    dfs(1);
    int M; cin >> M;
    while(M--){
        int a, b;
        cin >> a >> b;
        cout << LCA(a, b) << "\n";
    }
    return 0;
}

 

2. BOJ 1761 정점들의 거리

더보기

트리에서 두 정점 간의 최단 거리를 구하는 것은 일반적인 그래프에서보다 단순하다. 두 정점의 최소 공통 조상을 찾고 정점 $u$에서 LCA까지의 거리+$v$에서 LCA까지의 거리를 구하면 그것이 최단 거리가 될 것이다. 만약 LCA가 아닌 루트나 그 이외의 높이가 더 높은 정점을 거쳐서 간다면 거리가 더 크게 나올 것이다.

 

dist[i]: 루트와 정점 i 사이의 거리

라고 정의하면 우리가 구할 것은 (dist[u]-dist[LCA])+(dist[v]-dist[LCA])이다.

 

 

#include <iostream>
#include <bits/stdc++.h>

using namespace std;

typedef long long ll;

int N;
vector< pair<int, ll> > adj[40001];

int ac[40001][20], dep[40001];
ll dist[40001];

void dfs(int x)
{
    for(int i=1; i<20; i++){
        ac[x][i]=ac[ac[x][i-1]][i-1];
    }
    for(auto &i : adj[x]){
        int nxt=i.first;
        ll cost=i.second;
        if(ac[x][0]==nxt){ continue; }
        ac[nxt][0]=x;
        dist[nxt]=dist[x]+cost;
        dep[nxt]=dep[x]+1;
        dfs(nxt);
    }
}

int LCA(int a, int b)
{
    if(dep[a]<dep[b]){ swap(a, b); }
    for(int i=19; i>=0; i--){
        if((dep[a]-dep[b])>=(1<<i)){
            a=ac[a][i];
        }
    }
    if(a==b){ return a; }

    for(int i=19; i>=0; i--){
        if(ac[a][i]!=ac[b][i]){ a=ac[a][i]; b=ac[b][i]; }
    }
    return ac[a][0];
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0); cout.tie(0);
    cin >> N;
    for(int i=1; i<=N-1; i++){
        int a, b; ll c;
        cin >> a >> b >> c;
        adj[a].push_back({b, c});
        adj[b].push_back({a, c});
    }
    dfs(1);
    int m;
    cin >> m;
    while(m--){
        int a, b; cin >> a >> b;
        int ac=LCA(a, b);
        cout << dist[a]+dist[b]-2*dist[ac] << "\n";
    }
    return 0;
}