来自 SPOJ 的一道树相关题目:Gao On a Tree(GOT)

时间:2013-08-13 21:31:57

标签: algorithm tree graph-algorithm

有一棵树,每个顶点都分配了一个数字。对于每个查询 (a, b, c),需要回答:从 a 到 b 的路径上是否存在一个被分配了数字 c 的顶点。

数字可以重复,即多个节点可能被分配相同的数字。

我想到了一些基于 LCA 的解法,即把树上的节点转换为区间,但无法进一步优化,因此提交会超时。

任何人都可以帮忙解决这个问题吗?这是问题的链接: http://www.spoj.com/problems/GOT/

1 个答案:

答案 0(得分:-1)

#pragma GCC optimize ("O3")
#pragma GCC target ("sse4")
#pragma comment(linker, "/stack:200000000")
#pragma GCC optimize ("Ofast")
#pragma GCC target ("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,avx,tune=native")
#include<bits/stdc++.h>  
using namespace std;

// Shorthand macros used throughout the file.
// The bound argument of the loop macros is parenthesized so that expressions
// such as `loop(i, a & b)` or `rev(i, c ? x : y)` expand with the intended
// precedence.
#define ff          first
#define ss          second
#define pb          push_back
#define pf          push_front
#define pob         pop_back
#define pof         pop_front
#define mp          make_pair
#define ins         insert
#define ll          long long int
#define ld          long double
#define ull         unsigned long long
#define loop(i,n)   for(ll i=0;i<(n);i++)
#define loop1(i,n)  for(ll i=1;i<=(n);i++)
#define rev(i,n)    for(ll i=(n)-1;i>=0;i--)
#define rev1(i,n)   for(ll i=(n);i>=1;i--)
// Reads a test-case count into `t` from std::cin and loops that many times.
#define test(t)     int t;cin>>t;while(t--)
// Replaces endl with a plain newline (no flush). Caveat: after this point a
// qualified `std::endl` no longer compiles, since it expands to `std::'\n'`.
#define endl        '\n'
#define all(c)      (c).begin(), (c).end()
#define tr(c,it)    for(auto it=(c).begin(); it!=(c).end(); it++)
#define rtr(c,it)   for(auto it=(c).rbegin(); it!=(c).rend(); it++)
#define sz(c)       (c).size()
// Define mod as per requirement
const long long MOD =1000000007;
//const ll MOD= 998244353; 
typedef std::vector<int> vi;
typedef std::vector<long long> vll;
typedef std::vector<bool> vb;
typedef std::vector<std::vector<int> > vvi;
typedef std::vector<std::vector<long long> > vvll;
typedef std::pair<long long,long long> pll;
typedef std::pair<int,int> pi;
typedef std::vector<std::pair<long long,long long> > vpll;
typedef std::vector<std::vector<std::pair<long long,long long> > > vvpll;
// Order1: "less" by .second, ties broken by .first.
// As a priority_queue comparator it puts the LARGEST .second on top.
struct Order1{bool operator()(pll const& a,pll const& b)const{return (a.second < b.second || (a.second == b.second && a.first < b.first));}};
// Order2: "less" by .first, ties broken by .second (same order as std::pair's operator<).
struct Order2{bool operator()(pll const& a,pll const& b)const{return (a.first < b.first || (a.first == b.first && a.second < b.second));}};
// Order3: reverse of Order1 ("greater" by .second, then .first).
// As a priority_queue comparator it puts the SMALLEST .second on top.
struct Order3{bool operator()(pll const& a,pll const& b)const{return (a.second > b.second || (a.second == b.second && a.first > b.first));}};
typedef std::priority_queue <long long> maxpq;// maximum priority_queue
typedef std::priority_queue <long long, std::vector<long long>, std::greater<long long> > minpq; //minimum priority_queue
typedef std::priority_queue <pll> maxpq_pair_f;// maximum priority_queue for pairs with maximum ff
typedef std::priority_queue <pll,std::vector<pll>, Order1 > max_pq_s; // maximum priority_queue for pairs with maximum ss
typedef std::priority_queue< pll, std::vector<pll>, std::greater<pll> > minpq_pair_f; //minimum priority_queue for pairs with the smallest ff
typedef std::priority_queue< pll, std::vector<pll>, Order3 > minpq_pair_s; //minimum priority_queue for pairs with the smallest ss

// Smaller of two long longs; non-template so mixed int/long long calls resolve here.
long long min(long long a,long long b){
    if(a<b) return a;
    return b;
}
// Larger of two long longs; non-template so mixed int/long long calls resolve here.
long long max(long long a,long long b){
    if(b>a) return b;
    return a;
}
// Modular-arithmetic helpers over MOD.

// Returns (x + y) reduced into [0, MOD). Uses % instead of repeated
// subtraction/addition, so the cost does not grow with |x + y|.
long long add(long long x, long long y){
    x = (x + y) % MOD;
    if(x < 0) x += MOD;
    return x;
}
// Returns x * y mod MOD. Operands are reduced first, so arguments at or above
// MOD do not overflow the intermediate product. For negative operands the
// result keeps the sign of the product, as C++ % does.
long long multiply(long long x, long long y){return (x % MOD) * (y % MOD) % MOD;}
// Returns x^y mod MOD by binary exponentiation. The base is normalised into
// [0, MOD) first so squaring cannot overflow. Negative exponents are not
// supported and yield 1 (right-shifting a negative y would never reach 0).
long long power(long long x, long long y){
    x %= MOD;
    if(x < 0) x += MOD;
    long long z = 1;
    while(y > 0){
        if(y & 1) z = multiply(z, x);
        x = multiply(x, x);
        y >>= 1;
    }
    return z;
}
// Multiplicative inverse of x modulo MOD via Fermat's little theorem
// (valid because MOD is prime and x is not a multiple of MOD).
long long modInverse(long long x){return power(x, MOD - 2);}
// Returns x / y modulo MOD, i.e. x * y^{-1}.
long long divide(long long x, long long y){return multiply(x, modInverse(y));}
// Greatest common divisor by Euclid's algorithm; gcd(0, b) == b.
long long gcd(long long a,long long b) { if(a == 0)return b;return gcd(b % a, a);}
// Least common multiple. Returns 0 when either argument is 0; previously
// lcm(0, 0) divided by gcd(0, 0) == 0. Divides before multiplying to reduce
// overflow risk; this equals max(a,b)/g*min(a,b) because g divides both.
long long lcm(long long a,long long b){
    if(a == 0 || b == 0)
        return 0;
    return a / gcd(a, b) * b;
}
const long long N=200050; // can be changed as per constraints
long long fact[N]; // array to store values of factorial
// Fills fact[0..N-1] with i! mod MOD. Must run once before nCr() is used.
void cal_factorial(){
    fact[0] = 1;
    for(long long i = 1; i < N; i++)
        fact[i] = multiply(fact[i - 1], i);
}
// Returns the binomial coefficient C(n, k) mod MOD, or 0 when n < 0, k < 0 or
// k > n; for those arguments the lookup fact[n - k] would be out of bounds.
// Requires cal_factorial() to have been called and n < N.
long long nCr(long long n, long long k){
    if(n < 0 || k < 0 || k > n)
        return 0;
    return divide(fact[n], multiply(fact[k], fact[n - k]));
}
// Speeds up C++ iostreams: unties cin/cout from each other and stops
// synchronising them with C stdio.
void BOOST(){
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
    std::ios_base::sync_with_stdio(false);
}
// Sort comparators. They must be STRICT (false for equal elements): std::sort
// requires a strict weak ordering, and `<=`/`>=` break that, which is
// undefined behaviour and can read past the range on inputs with duplicates.
bool inc(long long x,long long y){return x<y;}// for increasing order
bool dec(long long x,long long y){return x>y;}// for decreasing order
// Set operations on int vectors.
// CAUTION: both arguments are taken by non-const reference and are sorted in
// place before the set operation, so the caller's vectors are reordered.
// Duplicates follow std::set_intersection / std::set_union multiset rules.

// Returns the sorted intersection of v1 and v2.
vector<int>Intersection(vi &v1,vi&v2){
    std::sort(v1.begin(),v1.end());
    std::sort(v2.begin(),v2.end());
    vi common;
    std::set_intersection(v1.begin(),v1.end(),v2.begin(),v2.end(),std::back_inserter(common));
    return common;
}
// Returns the sorted union of v1 and v2.
vector<int>Union(vi &v1,vi&v2){
    std::sort(v1.begin(),v1.end());
    std::sort(v2.begin(),v2.end());
    vi merged;
    std::set_union(v1.begin(),v1.end(),v2.begin(),v2.end(),std::back_inserter(merged));
    return merged;
}


// ---- Global state shared by the tree preprocessing and Mo's algorithm. ----
// Arrays are indexed by vertex id (1-based), so vertex ids must stay below 100010.

// Adjacency lists of the tree.
vector<int>graph[100010];
// Depth of each vertex; the dfs root has depth 0.
int level[100010];
// Binary-lifting table: LCA[v][j] is the 2^j-th ancestor of v, or -1 if none.
int LCA[100010][21];
// Euler-tour positions of each vertex's entry (start) and exit (finish).
int start[100010];
int finish[100010];
// Euler tour: every vertex is written once on entry and once on exit.
int euler_tour[2*100010];
int name[100010];// the number assigned to each vertex
// Per-query answers (1 = "Find", 0 = "NotFind"), indexed by original query number.
int ans[200050];
// How many times each vertex occurs in the current Mo window (normally 0, 1 or 2).
int freq[100010];
// assigned[x] = number of window vertices with value x that occur exactly once.
// NOTE(review): indexed by vertex value, so values are assumed to lie in
// [0, 100010) — confirm against the problem limits.
int assigned[100010];
// Next free slot in euler_tour; dfs() writes the tour starting here.
int timer=1;
// Number of vertices and queries in the current test case.
int n,q;
// Mo's algorithm block size (presumably ~sqrt of the tour length 2*1e5 — TODO confirm).
const int block=450;
// One offline query, expressed as an Euler-tour window [l, r].
struct query{
    // Left end of the window.
    int l;
    // Right end of the window.
    int r;
    // Original (1-based) position of the query in the input.
    int index;
    // Value being searched for on the path.
    int c;
    // True when one endpoint is an ancestor of the other.
    bool is_LCA;
    // LCA vertex to be checked separately; -1 when is_LCA is true.
    int lca;
};
// Query storage (1-based).
query Q[200050];
bool comp(query q1,query q2){
    int val1=q1.l/block;
    int val2=q2.l/block;
    if(val1!=val2)
        return val1<val2;
    return q1.r<q2.r;
}

// Depth-first traversal from `node` (parent `par`, depth `lvl`) that records,
// for every vertex v reached:
//   start[v] / finish[v] - Euler-tour positions of v's entry and exit,
//   euler_tour[pos]      - the vertex written at each position,
//   level[v]             - depth (lvl for the start vertex, +1 per edge),
//   LCA[v][0]            - direct parent (par for the start vertex).
// Positions come from the global `timer`, which advances by 2 per vertex.
// An explicit stack replaces recursion so that deep (path-like) trees with
// ~1e5 vertices cannot overflow the call stack. Children are still visited in
// adjacency-list order, so the resulting tour is identical to the recursive one.
void dfs(int node,int par,int lvl){
    struct Frame{
        int v;            // vertex being expanded
        int p;            // its parent (-1 for the root)
        std::size_t next; // index of the next neighbour to examine
    };
    std::vector<Frame> frames;
    auto enter=[&](int v,int p,int depth){
        start[v]=timer;
        euler_tour[timer]=v;
        level[v]=depth;
        LCA[v][0]=p;
        timer++;
        frames.push_back(Frame{v,p,0});
    };
    enter(node,par,lvl);
    while(!frames.empty()){
        Frame &top=frames.back();
        const int v=top.v;
        const int p=top.p;
        if(top.next<graph[v].size()){
            const int child=graph[v][top.next++];
            // enter() may reallocate `frames`; only the copies v/p are used here.
            if(child!=p)
                enter(child,v,level[v]+1);
            continue;
        }
        finish[v]=timer;
        euler_tour[timer]=v;
        timer++;
        frames.pop_back();
    }
}

// Builds the Euler tour (rooted at vertex 1) and the binary-lifting table:
// LCA[v][k] becomes the 2^k-th ancestor of v, or -1 when it does not exist.
void sparse_table(){
    // Every ancestor slot starts as "absent".
    memset(LCA,-1,sizeof(LCA));
    // Fills LCA[v][0] (direct parents) along with the tour and depths.
    dfs(1,-1,0);
    for(int k=1;k<=20;++k){
        for(int v=1;v<=n;++v){
            const int half=LCA[v][k-1];
            if(half==-1)
                continue;
            // Two jumps of 2^(k-1) make one jump of 2^k.
            LCA[v][k]=LCA[half][k-1];
        }
    }
}
// Returns the lowest common ancestor of a and b using the binary-lifting table.
int find_LCA(int a,int b){
    // Make b the deeper of the two vertices.
    if(level[b]<level[a])
        std::swap(a,b);
    // Lift b to a's depth, one set bit of the depth difference at a time.
    int diff=level[b]-level[a];
    for(int bit=0;diff>0;++bit,diff>>=1){
        if(diff&1)
            b=LCA[b][bit];
    }
    if(a==b)
        return a;
    // Lift both in lockstep while their ancestors still differ.
    for(int j=20;j>=0;--j){
        const int up=LCA[a][j];
        if(up==-1 || up==LCA[b][j])
            continue;
        a=up;
        b=LCA[b][j];
    }
    return LCA[a][0];
}

// Extends the Mo window with Euler-tour position i. The first occurrence of a
// vertex adds its value to the window; the second cancels it, since a vertex
// whose entry and exit both lie in the window is not on the path.
void add(int i){
    const int v=euler_tour[i];
    assigned[name[v]] += (++freq[v]==2) ? -1 : 1;
}
// Shrinks the Mo window by Euler-tour position i. Falling back to a single
// occurrence puts the vertex's value back in the window; any other decrement
// takes one away.
void remove(int i){
    const int v=euler_tour[i];
    assigned[name[v]] += (--freq[v]==1) ? 1 : -1;
}

void clear_everything(){
    for(int i=0;i<100010;i++)
        graph[i].clear();
    memset(freq,0,sizeof(freq));
}

// Reads the next (optionally negative) decimal integer from stdin, skipping
// leading whitespace. The character after the number is consumed. Returns 0
// when EOF is reached before any number.
static int uget(){
    int c;
    do{
        c=getchar();
    }while(c!=EOF && isspace(c));
    if(c==EOF)
        return 0;
    bool negative=false;
    if(c=='-'){
        negative=true;
        c=getchar();
    }
    // Named `value` rather than `n` so it does not shadow the global vertex count.
    int value=0;
    while(c!=EOF && isdigit(c)){
        value=value*10+(c-'0');
        c=getchar();
    }
    return negative ? -value : value;
}

// Reads test cases until EOF. Each case has n vertex values, n-1 edges, then q
// queries (a, b, c). For each query it prints "Find" if some vertex on the a-b
// path carries value c, otherwise "NotFind"; consecutive cases are separated
// by a blank line. Queries are answered offline with Mo's algorithm over the
// Euler tour built by dfs(): a vertex lies on the path iff it occurs exactly
// once in the chosen tour window (the LCA is checked separately when needed).
int main(){
    // Only affects C++ iostreams; all I/O below goes through C stdio
    // (scanf/getchar/putchar/puts).
    BOOST();
    int t=0;
    while(scanf("%d%d", &n,&q) != EOF){
        // Blank line between the outputs of consecutive test cases.
        if(t>0)putchar('\n');
        t++;
        // Clears the adjacency lists and freq[].
        clear_everything();
        for(int i=1;i<=n;i++)
            name[i]=uget();
        for(int i=1;i<=n-1;i++){
            int u,v;
            u=uget();
            v=uget();
            graph[u].pb(v);
            graph[v].pb(u);
        }
        // Euler tour rooted at vertex 1 plus the binary-lifting table.
        sparse_table();
        for(int i=1;i<=q;i++){
            int a,b,c;
            a=uget();
            b=uget();
            c=uget();
            int lca=find_LCA(a,b);
            // Order the endpoints so that a is entered first in the tour.
            if(start[a]>start[b])
                swap(a,b);
            if(lca==a || lca==b){
                // One endpoint is an ancestor of the other. Since a is entered first,
                // it is the ancestor. The window [start[a], start[b]] contains every
                // path vertex exactly once, including a itself.
                Q[i].l=start[a];
                Q[i].r=start[b];
                Q[i].index=i;
                Q[i].c=c;
                Q[i].is_LCA=true;
                Q[i].lca=-1;
                continue;
            }
            // a and b lie in different subtrees of lca. The window
            // [finish[a], start[b]] holds the path minus the LCA, whose value is
            // checked separately when answering.
            Q[i].l=finish[a];
            Q[i].r=start[b];
            Q[i].index=i;
            Q[i].c=c;
            Q[i].is_LCA=false;
            Q[i].lca=lca;
        }

        // Mo's algorithm: visit queries in block order and slide one window.
        sort(Q+1,Q+1+q,comp);
        int MoLeft=1,MoRight=0;
        for(int i=1;i<=q;i++){
            int L=Q[i].l;
            int R=Q[i].r;
            int to_find=Q[i].c;
            int id=Q[i].index;

            // The left pointer moves first, so freq[] may briefly go negative
            // while the window is empty. add()/remove() change assigned[] by the
            // same +/-1 rule in both directions, so the counts once the window
            // equals [L, R] are still correct.
            while(L>MoLeft){remove(MoLeft);MoLeft++;}
            while(L<MoLeft){MoLeft--;add(MoLeft);}
            while(R>MoRight){MoRight++;add(MoRight);}
            while(R<MoRight){remove(MoRight);MoRight--;}

            if(Q[i].is_LCA){
                if(assigned[to_find]>0)
                    ans[id]=1;
                else
                    ans[id]=0;
            }
            else{
                // The window excludes the LCA, so also test its value directly.
                if(assigned[to_find])
                    ans[id]=1;
                else if(name[Q[i].lca]==to_find)
                    ans[id]=1;
                else
                    ans[id]=0;
            }
        }

        // Answers are printed in the original input order.
        for(int i=1;i<=q;i++){
            if(ans[i])puts("Find");
            else
                puts("NotFind");
        }
    }

}