POJ 1330 Nearest Common Ancestors (LCA)

发布于 2019-05-23  20 次阅读


题目大意

输入包含多组数据。每组给出一棵有 n 个结点的有根树(n-1 条边,每条边按"父结点 子结点"的顺序给出),求两个指定结点的最近公共祖先(LCA)。


// Header names are missing from the source (the <...> parts appear to have been
// stripped during export); restored from what the code uses: scanf/printf and
// memset (via the fill macro below).
#include <cstdio>
#include <cstring>
using namespace std;
// Byte-fill a whole fixed-size array. x must be an actual array (not a pointer)
// because sizeof(x) supplies the byte count.
#define fill(x, y) memset(x, y, sizeof(x))
// Capacity of every array: node ids and edge indices must stay below this.
#define maxn 10001
// t: union-find parent links; n: node count of the current case;
// l, r: the query pair; root: the node that never appears as a child.
int t[maxn], n, l, r, root;
// Edge i goes from parent x to child y; next is the index of the previously
// added edge out of the same parent (0 terminates the list).
struct edge
{
    int x, y, next;
}e[maxn];
// ls[x]: index of the most recently added edge out of x (0 = no children).
int ls[maxn];
// v[x]: visited flag for the DFS; main() also reuses it to mark nodes that
// have a parent while it searches for the root.
bool v[maxn];
// Union-find lookup: returns the representative of x's set and, on the way
// back, points every node on the path directly at that representative.
int find(int x)
{
    return t[x] == x ? x : (t[x] = find(t[x]));
}
// Union: attaches the set containing x under the representative of the set
// containing y (no-op if they are already the same set). Always returns 0.
int insert(int x, int y)
{
    int rootX = find(x);
    int rootY = find(y);
    if (rootX != rootY)
        t[rootX] = rootY;
    return 0;
}
// Tarjan offline LCA traversal from x. Entering x marks it visited. If x is
// one query endpoint and the other endpoint was already visited, find() of
// that endpoint is the current-path node owning its set, i.e. LCA(l, r), and is
// printed. x is marked before the check, so l == r prints x itself.
// After each child's subtree is finished, it is merged into x's set.
void dfs(int x)
{
    v[x] = 1;
    if (x == l && v[r])
        printf("%d\n", find(r));
    else if (x == r && v[l])
        printf("%d\n", find(l));

    int edgeIdx = ls[x];
    while (edgeIdx)
    {
        int child = e[edgeIdx].y;
        dfs(child);
        insert(child, x);
        edgeIdx = e[edgeIdx].next;
    }
}
// Resets the union-find forest so every node 1..n is its own set, clears the
// visited flags, and runs the offline-LCA DFS starting at `root`.
void tarjan()
{
    for (int node = n; node >= 1; --node)
        t[node] = node;
    memset(v, 0, sizeof(v));
    dfs(root);
}
int main()
{
    int o;
    scanf("%d", &o);
    for (int p = 1; p <= o; p++)
    {
        fill(ls, 0);
        fill(v, 0);
        scanf("%d", &n);
        for (int i = 1; i < n; i++)
        {
            scanf("%d%d", &e[i].x, &e[i].y);
            e[i].next = ls[e[i].x];
            ls[e[i].x] = i;
            v[e[i].y] = 1;
        }
        scanf("%d%d", &l, &r);
        for (int i = 1; i <= n; i++)
            if (v[i] == 0)
            {
                root = i;
                break;
            }
        tarjan();
    }
}

倍增法(binary lifting)

// Header names are missing from the source (the <...> parts appear to have been
// stripped during export); restored from what the code uses: scanf/printf and
// memset (via the fill macro below).
#include <cstdio>
#include <cstring>
using namespace std;
// Capacity of every array: node ids and edge indices must stay below this.
#define maxn 10001
// Byte-fill a whole fixed-size array. x must be an actual array (not a pointer)
// because sizeof(x) supplies the byte count.
#define fill(x, y) memset(x, y, sizeof(x))
// Child adjacency list: edge i points to child `to`; `next` is the previous
// edge index out of the same parent (0 terminates the list).
struct edge
{
    int to, next;
}e[maxn];
// dep[x]: depth of x (root = 1, node 0 stays 0 as the "past the root" sentinel).
// ls[x]: head edge index of x's child list.
// f[x][j]: 2^j-th ancestor of x, 0 if it does not exist; 17 columns because
//          2^16 already exceeds maxn.
// v[x]: number of parents of x (0 identifies the root).
int dep[maxn], ls[maxn], f[maxn][17], v[maxn];
// Assigns depth d to `now` and depth d+1, d+2, ... to everything below it by
// walking the child adjacency lists recursively.
void dfs(int d, int now)
{
    dep[now] = d;
    int k = ls[now];
    while (k != 0)
    {
        dfs(d + 1, e[k].to);
        k = e[k].next;
    }
}
// Binary-lifting LCA query for nodes l and r.
// Uses dep[] and the ancestor table f[][] filled in by main(): f[x][j] is the
// 2^j-th ancestor of x, or 0 when that would go past the root. Real nodes have
// depth >= 1 while dep[0] stays 0, so a jump that lands on 0 is never accepted
// by the depth-equalizing loop.
int LCA(int l, int r)
{
    // Highest jump exponent, derived from f's column count so it cannot drift
    // out of sync with the table size.
    const int top = static_cast<int>(sizeof(f[0]) / sizeof(f[0][0])) - 1;
    if (dep[l] < dep[r])
    {
        // Make l the deeper node (plain swap instead of the XOR trick).
        int tmp = l;
        l = r;
        r = tmp;
    }
    // Lift l until it is at the same depth as r.
    for (int i = top; i >= 0; i--)
        if (dep[r] <= dep[f[l][i]])
            l = f[l][i];
    if (l == r)
        return l;
    // Lift both while their ancestors differ; they end as children of the LCA.
    for (int i = top; i >= 0; i--)
        if (f[l][i] != f[r][i])
        {
            l = f[l][i];
            r = f[r][i];
        }
    return f[l][0];
}
int main()
{
    int t;
    scanf("%d", &t);
    for (int o = 1; o <= t; o++)
    {
        fill(f, 0);
        fill(ls, 0);
        fill(dep, 0);
        fill(v, 0); 
        int n;
        scanf("%d", &n);
        for (int i = 1; i < n; i++)
        {
            int x, y;
            scanf("%d%d", &x, &y);
            e[i].to = y;
            e[i].next = ls[x];
            ls[x] = i;
            v[y]++; 
            f[y][0] = x;
        }
        for (int i = 1; i <= n; i++)
            if (!v[i])
            {
                dfs(1, i);
                break;
            }
        for (int j = 1; j <= 16; j++)
            for (int i = 1; i <= n; i++)
                f[i][j] = f[f[i][j-1]][j-1];
        int l, r;
        scanf("%d%d", &l, &r);
        printf("%d\n", LCA(l, r));
    }
}