Sum More

o59_may11_sum

ในข้อนี้หากเราพิจารณาค่า mm ใด ๆ โดยนิยามให้ F(m)F(m) หมายถึงฟังก์ชันแทนความถี่ของ mm จะสังเกตว่า F(m)F(m+1)F(m) \geq F(m+1) เสมอ เพราะถ้าคู่อันดับ (i,j)(i,j) ใด ๆ ถูกนับรวมใน F(m+1)F(m+1) แล้วแปลว่า S(i,j)m+1S(i,j) \geq m+1 แล้วคู่อันดับนั้นก็ย่อมถูกนับรวมใน F(m)F(m) ด้วย เพราะ S(i,j)m+1mS(i,j) \geq m+1 \geq m จากเหตุผลดังกล่าวจะได้ว่า ฟังก์ชัน F(m)F(m) เป็นฟังก์ชัน monotone ดังนั้นการหาค่า mm ที่มากที่สุดที่เป็นคำตอบสามารถหาได้ด้วยการ binary search ค่า mm โดยตรง

ระหว่างการทำ binary search หากกำลังพิจารณาคำตอบในช่วง [l,r][l,r] อยู่โดย mid = (l+r)/2 จะมีกรณีที่ต้องพิจารณาเพื่อเปลี่ยนช่วงคำตอบ 2 กรณีดังนี้

  • F(mid)kF(mid) \geq k นั่นหมายความว่าคำตอบที่ต้องการนั้นมีค่ามากกว่าหรือเท่ากับ mid แน่นอน จึงเปลี่ยนช่วงไปพิจารณาช่วง [mid,r][mid,r] ต่อไป
  • F(mid)<kF(mid) < k นั่นหมายความว่าคำตอบที่ต้องการนั้นน้อยกว่า mid แน่นอน จึงเปลี่ยนช่วงไปพิจารณาช่วง [l,mid1][l,mid-1] ต่อไป

วิธีในการคำนวณ F(m)F(m) สามารถทำได้ตรง ๆ ด้วยการไล่นับคู่อันดับ (i,j)(i,j) ที่ S(i,j)mS(i,j) \geq m ได้แต่วิธีนี้มี Time complexity O(N2)\mathcal{O}(N^2) ต่อการคำนวณ F(m)F(m) 1 ครั้ง และเมื่อรวมจำนวนครั้งที่คิด F(m)F(m) โดยจากการ binary search จำเป็นต้องคิด F(m)F(m) ทั้งหมด logN\log N ครั้ง ดังนั้นจึงได้ Time complexity O(N2logN)\mathcal{O}(N^2\log N) ซึ่งไม่ทันเวลา วิธีการ optimize นั้นเริ่มได้จากการ preprocess prefix sum ของทุก ๆ index ii ที่ ini \leq n โดยนิยามให้ pref(i)=pref(i) = ผลรวมของ A(1),A(2),...,A(i)A(1),A(2),...,A(i) จากการ preprocess prefix sum จะได้ว่า S(i,j)=pref(j)pref(i1)S(i,j) = pref(j) - pref(i-1) ดังนั้นคู่อันดับ (i,j)(i,j) ที่เราต้องการนั้นจะมีคุณสมบัติคือ mpref(j)pref(i1)m \leq pref(j) - pref(i-1) หรือก็คือ pref(i1)pref(j)mpref(i-1) \leq pref(j) - m

โดยวิธีในการคำนวณ F(m)F(m) ใด ๆ นั้นเริ่มได้จากการพิจารณาตั้งแต่ j=1,2,3,...,nj = 1,2,3,...,n โดยสำหรับ jj ค่าหนึ่งจะนิยามให้ G(j)G(j) หมายถึงฟังก์ชันแทนจำนวนคู่อันดับ (i,j)(i,j) ที่ iji \leq j และ pref(i1)pref(j)mpref(i-1) \leq pref(j) - m เราสามารถคำนวณ G(j)G(j) โดยเริ่มจากการกำหนด array XX โดยนิยามให้ X(i)X(i) หมายถึงจำนวนของค่า ii และในตอนเริ่มต้น X(i)=0X(i) = 0 ในทุก ๆ ii จากนั้น update array XX ใน index pref(1),pref(2),pref(3),pref(4),...,pref(j1)pref(1),pref(2),pref(3),pref(4),...,pref(j-1) ด้วยการเพิ่มค่าช่องนั้น ๆ ไป 1 การทำแบบนี้ทำให้เราสามารถคำนวณค่า G(j)G(j) ได้สำหรับทุก ๆ jj ด้วยการหาผลรวม ...+X(1)+X(0)+X(1)+...+X(mpref(j))...+X(-1)+X(0)+X(1)+...+X(m-pref(j)) ดังนั้น หากพิจารณา j=1,2,3,...nj = 1,2,3,...n แล้วนำคำตอบมารวมกัน จะได้เป็นความถี่ของ mm นี่เรากำลังค้นหาอยู่

เนื่องจากค่า pref(i)pref(i) นั้นอาจมีค่ามากหรือติดลบได้ เราจึงนำเทคนิค coordinate compression มาช่วย โดยเทคนิคนี้จะเก็บข้อมูลไว้ใน vector coord และใช้การ sort ข้อมูลทั้งหมด และนำค่าซ้ำออก เมื่อค่าใด ๆ ถูกเรียกใช้งานจะแทนค่าค่านั้นด้วย index ใน vector coord แทน ดังนั้นปัญหาค่าติดลบหรือมากเกินว่าจะนำไปเป็น index ของ array จะหมดไป

for (int i = 1; i <= n; i++)
  coord.emplace_back(pref[i]);
//Coordinate compression
coord.emplace_back(0); // base case
sort(coord.begin(), coord.end());
coord.resize(unique(coord.begin(), coord.end()) - coord.begin());

การ update ค่านั้นหากนำค่าไปใส่ในช่องตรง ๆ ด้วย time complexity O(1)\mathcal{O}(1) และ query ไล่ผลรวมตั้งแต่ 11 ถึง j1j-1 ด้วย time complexity O(N)\mathcal{O}(N) จะทำให้ time complexity ต่อการดำเนินการต่อ j ใด ๆ เป็น O(N)\mathcal{O}(N) และ time complexity ในการคิด F(m)F(m) 1 ครั้งเป็น O(N2)\mathcal{O}(N^2) และเมื่อคิดการคำนวณทุก ๆ ครั้งจะมี time complexity เป็น O(N2logN)\mathcal{O}(N^2\log N) เหมือนเดิม ด้วยการใช้โครงสร้างข้อมูล fenwick tree สามารถ optimize time complexity ทั้ง update และ query ต่อการดำเนินการ 1 ครั้งเป็น O(logN)\mathcal{O}(\log N) ในการคิด F(m)F(m) 1 ครั้งจะลด time complexity เหลือเพียง O(NlogN)\mathcal{O}(N\log N) และเมื่อรวมเวลาการคำนวณทุกครั้งได้ time complexity เป็น O(Nlog2N)\mathcal{O}(N\log^2 N)

void update(int x, long k) {
  for (int i = x; i < N; i += i & -i)
    t[i] += k;
}

long long query(int x, long ret = 0) {
  for (int i = x; i; i -= i & -i)
    ret += t[i];
  return ret;
}

int F(int mid) {
  long long now = 0;
  memset(t, 0, sizeof t);
  update(get(0), 1);
  for (int i = 1; i <= n; i++) {
    int pos = lower_bound(coord.begin(), coord.end(), pref[i] - mid) -
              coord.begin() + 1;
    if (coord[pos - 1] > pref[i] - mid)
      --pos;
    if (pos)
      now += query(pos);
    update(get(pref[i]), 1);
  }
  return now;
}