Colorful Garden

o61_may02_garden

กำหนดให้ cuc_u คือชนิดของดอกไม้ ณ โหนด uu

ในข้อนี้ โจทย์ให้กราฟใด ๆ มา โดยแต่ละโหนดจะมีดอกไม้ที่แตกต่างชนิดกัน ต้องการทราบว่ามีกี่วิธีในการเดินที่ผ่านเพียง cc โหนดเท่านั้น และต้องเดินผ่านแปลงดอกไม้ให้ครบ cc ชนิด นั่นคือการเดินนั้นจะผ่านดอกไม้แต่ละชนิดเพียง 1 ครั้งเท่านั้น ดังนั้น หากกำหนดให้ SS เป็นเซตของชนิดดอกไม้ที่เดินผ่านแล้ว เราจะได้ว่า S|S| เท่ากับจำนวนโหนดที่เดินผ่านแล้วด้วย

เราสามารถนับจำนวนวิธีในการเดินด้วยการใช้ Dynamic Programming โดยกำหนดให้ dp(u,S)dp(u, S) เป็นจำนวนวิธีในการเดินที่มาสิ้นสุดที่โหนด uu และเซตของชนิดของดอกไม้ที่เดินผ่านแล้วเป็น SS ดังนั้น Recurrence Formula จะเป็น

dp(u,S)=vdp(v,Scu)dp(u, S) = \sum \limits_{v} dp(v, S - {c_u})

เมื่อ vv คือโหนดที่ติดกันกับ uu

กำหนด dp(u,cu)=1dp(u, {c_u}) = 1 เป็น Base Case เนื่องจากการที่เริ่มเดินที่ uu จะถือว่าเดินผ่านดอกไม้ชนิด cuc_u ทันที นับเป็น 1 วิธี ส่วน dp(u,)dp(u, \emptyset) ไม่พิจารณา (มีค่าเท่ากับ 00) เพราะเป็นไปไม่ได้ที่จะอยู่ที่แปลงดอกไม้ uu แล้วไม่นับว่าเคยเจอดอกไม้ชนิด cuc_u

สำหรับขั้นตอนการทำ Dynamic Programming นั้น เราจะคำนวณจาก State ที่ S|S| เป็น 1,2,3,,c1, 2, 3, \dots, c ตามลำดับ เพราะจาก Recurrence Formula เราจะสามารถคำนวณค่าของ dp(u,S)dp(u, S) ได้ ก็ต่อเมื่อทราบค่าของ dp(v,Scu)dp(v, S - {c_u}) แล้ว ซึ่งจะสังเกตได้ว่า S=Scu+1|S| = |S - {c_u}| + 1 เสมอ

คำตอบของข้อนี้จะเป็น u=0n1dp(u,{0,1,,c1})\sum \limits_{u=0}^{n-1} dp(u, \{0, 1, \dots, c - 1\})

เราจะแทนเซต SS ด้วยการใช้ Bitmask หรือจำนวนเต็มที่เมื่อเขียนในเลขฐานสอง กำหนดให้ตำแหน่ง bit ที่ ii มีค่าเป็น 11 ก็ต่อเมื่อ iSi \in S

โค้ดตัวอย่างดังนี้

#include <bits/stdc++.h>

using namespace std;

const int N = 1e3 + 5;
const int M = 1e6 + 3;

int n, m, c, A[N];
long long dp[N][1 << 10];
vector<int> bit[N], g[N];

int main() {
  scanf("%d %d %d", &n, &m, &c);
  for (int i = 0; i < n; i++)
    scanf("%d", A + i), ++dp[i][1 << A[i]];
  for (int i = 1, a, b; i <= m; i++) {
    scanf("%d %d", &a, &b);
    g[a].emplace_back(b), g[b].emplace_back(a);
  }
  for (int i = 0; i < (1 << c); i++)
    bit[__builtin_popcount(i)].emplace_back(i);
  for (int i = 2; i <= c; i++)
    for (int j : bit[i])
      for (int u = 0; u < n; u++)
        if (j >> A[u] & 1)
          for (int v : g[u])
            dp[u][j] = (dp[u][j] + dp[v][j ^ (1 << A[u])]) % M;
  long long ans = 0;
  for (int i = 0; i < n; i++)
    ans = (ans + dp[i][(1 << c) - 1]) % M;
  printf("%lld\n", ans);

  return 0;
}

Time Complexity: O(n2c)\mathcal{O}(n \cdot 2^c)