// Source code (job #2034609)
//
// User: danny794 (Dan Danaila)   Date: 8 October 2017 09:39:14
// Problem: Lowest Common Ancestor   Score: 70
// Compiler: cpp   Status: done
// Round: Arhiva educationala   Size: 2.55 kb
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

/**
 * Lowest-common-ancestor oracle for a rooted tree whose root is node 1.
 *
 * LCA is reduced to a range-minimum query over the tree's Euler tour: the
 * LCA of a and b is the shallowest node visited between an occurrence of a
 * and an occurrence of b. The tour (2n - 1 entries for n nodes) is stored in
 * the leaves of a heap-ordered segment tree. The leaves occupy positions
 * [intervalSize, 2 * intervalSize). Each internal position i keeps the
 * minimum depth of its range in level[i] and a node with that depth in
 * euler[i].
 *
 * Usage: construct with n, build once with readData(n) or
 * buildFromParents(n, parent), then query with findLCA(a, b).
 */
class EulerTree {
 public:
  /// Allocates storage for a tree with nodes numbered 1..n.
  EulerTree(int n) {
    length = 0;
    // Smallest power of two strictly greater than the tour length 2n - 1.
    // Integer arithmetic avoids floating-point rounding in log() and stays
    // well defined for n <= 0.
    const int tourLength = 2 * n - 1;
    intervalSize = 1;
    while (intervalSize <= tourLength) {
      intervalSize <<= 1;
    }
    index.resize(n >= 0 ? n + 1 : 1);
    level.resize(intervalSize * 2);
    euler.resize(intervalSize * 2);
  }

  /// Returns the lowest common ancestor of nodes a and b (1-based ids).
  /// Precondition: the tree has been built and 1 <= a, b <= n.
  int findLCA(int a, int b) {
    // Heap positions of one Euler-tour occurrence of each node.
    int lo = index[a];
    int hi = index[b];
    if (lo > hi) {
      std::swap(lo, hi);
    }
    // Bottom-up range-minimum query over the half-open leaf range
    // [lo, hi + 1): O(log n) iterations, no recursion.
    int best = lo;
    for (int l = lo, r = hi + 1; l < r; l >>= 1, r >>= 1) {
      if (l & 1) {
        if (level[l] < level[best]) {
          best = l;
        }
        ++l;
      }
      if (r & 1) {
        --r;
        if (level[r] < level[best]) {
          best = r;
        }
      }
    }
    return euler[best];
  }

  /// Reads n - 1 integers from stdin (the parent of node i, for i = 2..n)
  /// and builds the structure. If input runs out early, the remaining
  /// parents stay 0 and those nodes are left unattached.
  void readData(int n) {
    // Zero-initialised so a failed read never leaves an indeterminate value.
    std::vector<int> parent(n >= 1 ? n + 1 : 1, 0);
    for (int i = 2; i <= n; i++) {
      if (scanf("%d", &parent[i]) != 1) {
        break;
      }
    }
    buildFromParents(n, parent);
  }

  /// Builds the Euler tour and the segment tree from a parent array.
  /// parent[i] is the parent of node i for 2 <= i <= n; entries 0 and 1 are
  /// ignored, and node 1 is the root. Preconditions: parent.size() > n, and
  /// n does not exceed the n given to the constructor. A parent id outside
  /// [1, n] is ignored, which leaves that node unattached.
  void buildFromParents(int n, const std::vector<int>& parent) {
    if (n < 1) {
      return;  // Empty tree: nothing to index.
    }
    std::vector<std::vector<int>> children(n + 1);
    for (int i = 2; i <= n; i++) {
      if (parent[i] >= 1 && parent[i] <= n) {
        children[parent[i]].emplace_back(i);
      }
    }
    length = 0;  // Start the tour at the first leaf.
    insert(1, 0, children);
    buildIntervalTree();
  }

 private:
  /// Appends the Euler tour of the subtree rooted at idx, where idx is at
  /// depth lvl. A node is emitted once before descending into each of its
  /// children and once more after its last child. An explicit stack
  /// replaces recursion, so a path-shaped tree (depth ~n) cannot overflow
  /// the call stack.
  void insert(int idx, int lvl, const std::vector<std::vector<int>>& children) {
    // Each frame is (node, index of the next child to visit).
    std::vector<std::pair<int, std::size_t>> stack;
    stack.emplace_back(idx, 0);
    while (!stack.empty()) {
      const int node = stack.back().first;
      const std::size_t next = stack.back().second;
      insert(node, lvl + static_cast<int>(stack.size()) - 1);
      if (next < children[node].size()) {
        // Advance the cursor before emplace_back, which may reallocate.
        stack.back().second = next + 1;
        stack.emplace_back(children[node][next], 0);
      } else {
        stack.pop_back();
      }
    }
  }

  /// Stores `node` at depth `lvl` in the next free Euler-tour leaf. It also
  /// records that leaf as the node's position; a later occurrence overwrites
  /// an earlier one.
  void insert(int node, int lvl) {
    const int pos = length + intervalSize;
    index[node] = pos;
    euler[pos] = node;
    level[pos] = lvl;
    length++;
  }

  /// Pads unused leaves with a depth (`length`) larger than any real depth.
  /// It then fills each internal position, bottom-up, from the shallower of
  /// its two children; the right child wins ties.
  void buildIntervalTree() {
    for (int i = length + intervalSize; i < 2 * intervalSize; i++) {
      level[i] = length;
    }
    // Children of i are 2i and 2i + 1 (both > i), so a descending sweep
    // always sees finished children.
    for (int i = intervalSize - 1; i >= 1; i--) {
      level[i] = level[2 * i + 1];
      euler[i] = euler[2 * i + 1];
      if (level[i] > level[2 * i]) {
        level[i] = level[2 * i];
        euler[i] = euler[2 * i];
      }
    }
  }

  int intervalSize;        // Number of segment-tree leaves (a power of two).
  int length;              // Euler-tour entries written so far.
  std::vector<int> euler;  // Node id at each heap position.
  std::vector<int> level;  // Depth (minimum depth for internal positions).
  std::vector<int> index;  // index[v]: heap position of an occurrence of v.
};

/// Judge entry point: reads "n m", then the n - 1 parent ids, then m query
/// pairs from lca.in, and writes one LCA per line to lca.out.
int main() {
  if (!freopen("lca.in", "r", stdin) || !freopen("lca.out", "w", stdout)) {
    return 1;  // Cannot open the judge files.
  }
  int n = 0;
  int tests = 0;
  if (scanf("%d %d", &n, &tests) != 2) {
    return 1;  // Malformed or empty header.
  }
  EulerTree tree(n);
  tree.readData(n);
  int a = 0;
  int b = 0;
  // Stop on a failed read instead of re-answering stale a/b values.
  while (tests-- > 0 && scanf("%d %d", &a, &b) == 2) {
    printf("%d\n", tree.findLCA(a, b));
  }
  return 0;
}