문제

문제 링크

문제 접근

  • 이 문제는 트리를 이해하고 디피를 활용하는 방법을 이해하고 있다면 풀 수 있는 문제이다.
  • 트리는 사이클이 존재하지 않음으로 하나의 점에서 부터 시작해서 이전 값을 이용해서 풀어나가는 dp를 사용할 수 있다.
  • 이 문제를 디피로 푸는 방법이 떠오르지 않는다면 백준 1520번 내리막길을 한번 풀어보길 바란다.

디피 방법

  • dfs를 이용해서 하나의 정점을 잡고 쭉 파고 들어간다.(굳이 트리를 만들 필요 없다.)
  • 이때 계속해서 자식 노드의 값을 발견하고 이를 이용해서 점화식으로 dp에 저장해나간다.
  • dp[i][0](i번째 노드를 포함시켰을 때의 값): 자식 노드들의 dp[i][1]의 합 + w[i]
  • dp[i][1](i번쨰 노드를 포함시키지 않았을 때의 값) : 자식 노드들의 dp[i][0] dp[i][1] 중 큰 값의 합
  • 결과적으로 1번 노드 dp에 저장된 값들 중에 큰 값이 정답이다.

코드

#include <cstdio>
#include <vector>
using namespace std;
int n, v1, v2, ans, dp[10002][2], w[10002], visited[10002];
vector<vector<int> > e;

void find(int x){
    for(int i = 0; i < e[x].size(); i++){
        int nx = e[x][i];
        if(visited[nx])continue;
        visited[nx] = 1;
        find(nx);
        dp[x][0] += dp[nx][1];
        dp[x][1] += max(dp[nx][0], dp[nx][1]);
    }
    dp[x][0] += w[x];
}

int main(){
    scanf("%d", &n);
    e = vector<vector<int> >(n + 1);
    for(int i = 1; i <= n; i++)scanf("%d", &w[i]);
    for(int i = 0; i < n - 1; i++){
        scanf("%d %d", &v1, &v2);
        e[v1].push_back(v2);
        e[v2].push_back(v1);
    }
    visited[1] = 1;
    find(1);
    ans = max(dp[1][0], dp[1][1]);
    printf("%d", ans);
}

느낀점

  • 호호호.. 트리에 디피라니 설렌다.
  • 트리의 독립 집합과 이 문제를 한주 동안 고민했는데 드디어 풀어냈다..ㅎㅎ 묵은 똥이 내려가는 것만 같다ㅎㅎㅎ