문제

문제 링크

문제 접근

  • 이 문제는 트리를 이해하고 디피를 활용하는 방법을 이해하고 있다면 풀 수 있는 문제이다.
  • 트리는 사이클이 존재하지 않음으로 하나의 점에서 부터 시작해서 이전 값을 이용해서 풀어나가는 dp를 사용할 수 있다.
  • 이 문제를 디피로 푸는 방법이 떠오르지 않는다면 백준 1520번 내리막길을 한번 풀어보길 바란다.

디피 방법

  • dfs를 이용해서 하나의 정점을 잡고 쭉 파고 들어간다.(굳이 트리를 만들 필요 없다.)
  • 이때 계속해서 자식 노드의 값을 발견하고 이를 이용해서 점화식으로 dp에 저장해나간다.
  • dp[i][0](i번째 노드를 포함시켰을 때의 값): 자식 노드들의 dp[i][1]의 합 + w[i]
  • dp[i][1](i번쨰 노드를 포함시키지 않았을 때의 값) : 자식 노드들의 dp[i][0] dp[i][1] 중 큰 값의 합
  • 결과적으로 1번 노드 dp에 저장된 값들 중에 큰 값이 정답이다.

선택한 노드 뽑아내기

  • 루트 노드부터 dp를 확인하며 뽑아낸다.
  • 해당 원소가 사용됐을 때의 dp 값이 해당 원소를 사용하지 않았을 때보다 크다면 그 원소는 포함되었을 것이다.
  • 그리고 부모 노드가 이미 포함되었다면 자식 노드는 포함될 수 없기 때문에 print 가능 여부도라는 정보도 같이 보내준다.
  • 결과적으로 print할 수 있으며 dp[i][0]가 dp[i][1]보다 큰 경우만 포함 시킨다.
  • 현재 노드를 포함 시킬 때는 자식 노드에게 0을 보내주고 포함 시키지 않을 때는 자식 노드에게 1을 보내준다.

코드

#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

int ans, n, v1, v2, w[10002], dp[10002][2], visited[10002];
vector<vector<int> > e; 
vector<int> answer;

void find(int x){
    for(int i = 0; i < e[x].size(); i++){
        int nx = e[x][i];
        if(visited[nx])continue;
        visited[nx] = 1;
        find(nx);
        dp[x][0] += dp[nx][1];
        dp[x][1] += max(dp[nx][0], dp[nx][1]);
    }
    dp[x][0] += w[x];
}

void get_ans(int x, int print){
    if(print && dp[x][0] > dp[x][1]){
        answer.push_back(x);
        print = 0;
    }
    else print = 1;
    for(int i = 0; i < e[x].size(); i++){
        int nx = e[x][i];
        if(visited[nx])continue;
        visited[nx] = 1;
        get_ans(nx, print);
    }
}

int main(){
    scanf("%d", &n);
    for(int i = 1; i < n + 1; i++)scanf("%d", &w[i]);
    e = vector<vector<int> >(n + 1); 

    for(int i = 0; i < n - 1; i++){
        scanf("%d %d", &v1, &v2);
        e[v1].push_back(v2);
        e[v2].push_back(v1);
    }
    visited[1] = 1;
    find(1);
    for(int i = 2; i < n + 1; i++)visited[i] = 0;
    get_ans(1, 1);
    sort(answer.begin(), answer.end());
    printf("%d\n", max(dp[1][0], dp[1][1]));
    for(int i = 0; i < answer.size(); i++)printf("%d ", answer[i]);
}

느낀점

  • 드디어 풀었다… root node의 디피 값 0과 1만 확인하면 되는데 모든 노드의 dp 0번 값만 비교해서 결과값을 내서 계속 틀렸다.
  • 디피는 최종적인 결과만 확인하면 되는 것을 참 바보같다.
  • 그래도 풀었으니 만조쿠다ㅎㅎ