문제

문제 링크

두 점 사이를 어떻게 구할까

  • 그냥 두 점 사이의 거리를 DFS나 BFS로 찾는다??
  • 그 결과 시간 초과가 발생했다…ㅎㅎ 역시 이건 방법이 아니다.
  • 찾아보다가 알게 된 LCA 알고리즘! 다시금 정리하고 넘어간다.

LCA 알고리즘

  • LCA 알고리즘은 트리 내에 공통부모인 LCA(Lowest Common Ancestor)를 찾아주는 알고리즘이다.
  • 처음 트리를 생성할 때 각 노드의 parent와 level을 정리해놓은 후에
  • 두 노드에서 올라가는 부모 노드가 같아질 때까지 찾아가는 것을 말한다.

LCA를 어떻게 활용할 것인가

  • 이 문제에서는 1번 노드를 기준으로 트리를 만든다.
  • 이 때 1번으로 부터의 거리를 다 정리해서 어레이에 담아둔다.
  • 거리 찾는 함수 내부에서는 BFS로 1부터 각 노드의 거리를 업데이트 시켜준다.(나중에 거리 사용때 긴히 사용될 것이다.)
  • LCA를 찾아준다음에 미리 저장해둔 1부터 두 점 까지의 거리에 LCA까지 거리*2를 빼준다.

알고리즘 진행

  1. 일단 edge를 다 받아
  2. 받고 나서 트리를 만들어
  3. 1번 원소를 큐에 넣어
  4. 큐에 원소가 없을때까지 다음을 반복해
  5. 원소를 꺼내고 인접 원소의 값을 업데이트해줘
  6. root로부터의 거리를 기록하고
  7. parent와 level을 기록해줘
  8. parent가 0이면 넣어
  9. input에 따라 두 점의 level을 같게 해줘
  10. 두점으로부터 1번 원소까지 거리를 더하고 공통부모로부터 1번까지의 곱하기 2를 빼줘.

코드

#include <queue>
#include <vector>
#include <cstdio>
#define INF 1000000000

using namespace std;

typedef struct Point{
    int v, w;
    Point(){}
    Point(int i, int j){
        v = i;
        w = j;
    }
}Point;

char buf[1 << 17];

inline char read() {
    static int idx = 1 << 17;
    if (idx == 1 << 17) {
        fread(buf, 1, 1 << 17, stdin);
        idx = 0;
    }
    return buf[idx++];
}
inline int readInt() {
    int sum = 0;
    bool flg = 1;
    char now = read();

    while (now == 10 || now == 32) now = read();
    if (now == '-') flg = 0, now = read();
    while (now >= 48 && now <= 57) {
        sum = sum * 10 + now - 48;
        now = read();
    }

    return flg ? sum : -sum;
}

vector< vector< Point > > ll;
int parent[17][40020], level[40020], d[40020];
int n, m, v, w, u;



int LCA(int chd, int par){
    int diff = level[chd] - level[par];
    for(int i = 16; i >= 0; i--){
        if((1<<i) <= diff){
            chd = parent[i][chd];
            diff -= (1<<i);
        }
    }
    if(chd == par) return chd;
    for(int i = 15; i >= 0; i--){
        if(parent[i][chd] != parent[i][par]){
            chd = parent[i][chd];
            par = parent[i][par];
        } 
    } 
    return parent[0][chd];
}

void build_tree(){
    queue<int> q;
    level[1] = 1, d[1] = 0, parent[0][1] = 1;
    q.push(1);
    while(!q.empty()){
        u = q.front(); q.pop();
        for(int i = 0; i < ll[u].size(); i++){
            v = ll[u][i].v;
            if(level[v] == 0){
                level[v] = level[u] + 1;
                parent[0][v] = u;
                d[v] = d[u] + ll[u][i].w;
                q.push(v);
            }
        }
    }
}
void find_parent(){
    for(int i = 1; i < 17; i++){
        for(int j = 1; j < n + 1; j++)parent[i][j] = parent[i-1][parent[i-1][j]];
    }
}

int main(){
    n = readInt();
    // scanf("%d", &n);
    
    ll = vector< vector< Point > >(n + 1);
    for(int i = 1; i < n; i++){
        v = readInt(), u = readInt(), w = readInt();
        // scanf("%d %d %d", &v, &u, &w);
        ll[v].push_back(Point(u, w));
        ll[u].push_back(Point(v, w));
    }
    build_tree();
    find_parent();

    m = readInt();
    // scanf("%d", &m);
    while(m--){
        v = readInt(), u = readInt();
        // scanf("%d %d", &v, &u);
        printf("%d\n", d[v] + d[u] - d[level[v] > level[u] ? LCA(v, u) : LCA(u, v)]*2);       
    }
}

결과

코드 복기 후 제출한 결과이다.

사진

느낀점

  • LCA를 찾는 것도 대단했는데 LCA까지의 거리를 2를 곱해서 뺴주는 걸 생각해내는 것은 진짜 대단했다… 뭐 내 머리는 아직 그정도는 안되니 나중에는 수식을 이용해서 더 효율적으로 코딩하는 연습을 해봐야겠다.