用于计算频率的段树

时间:2015-12-30 07:53:25

标签: arrays data-structures segment-tree

有没有办法使用Segment Tree结构来计算数组中给定值的频率?

假设有一个大小为N的数组A,并且数组的每个元素A [i]都包含值0,1或2.我想执行以下操作:

  • 计算数组
  • 的任何范围[a,b]中的零的数量
  • 增加(mod 3)数组
  • 的任何范围[a,b]中的每个元素

示例:如果A = [0,1,0,2,0]:

  • 查询[2,4]必须返回1,因为[2,4]
  • 范围内有一个0
  • 增量[2,4]将A更新为[0,2,1,0,0]

这看起来非常类似于Range Sum Query问题,可以使用Segment Trees(在这种情况下使用Lazy Propagation因为范围更新)来解决,但是我没有成功调整我的seg树代码来解决这个问题,因为如果我将值存储在树中,就像在普通RSQ中一样,任何包含值“3”的父节点(例如)都不会毫无意义,因为有了这些信息,我无法提取在此范围内存在多少个零

提前致谢!

-

编辑:

段树是二叉树结构,用于在其节点中存储与数组相关的间隔。叶节点存储实际的阵列单元,并且每个父节点存储其子节点的函数f(节点 - >左,节点 - >右)。段树通常用于执行范围求和查询,其中我们想要计算数组范围[a,b]中所有元素的总和。在这种情况下,父节点计算的函数是其子节点中的值的总和。我们想使用segtrees来解决Range Sum Query问题,因为它允许在O(log n)中解决它(我们只需要下降树,直到我们找到完全被我们的范围查询覆盖的节点),远远好于天真的O(n)算法。

2 个答案:

答案 0 :(得分:1)

由于实际数组值存储在叶子中(级别L),因此让级别为L-1的节点存储它们包含的零个数(这将是范围[0,2]中的值)。除此之外,一切都是相同的,其余的节点将f(node-> left,node-> right)计算为node->left + node->right,并且零的数量将传播到根。

增加范围后,如果该范围不包含任何零,则不需要执行任何操作。然而,如果该范围具有零,那么所有那些零现在将是1并且当前节点的函数值(称为F)现在变为零。现在,值的变化需要向上传播到根,每次都从函数值中减去F.

答案 1 :(得分:0)

使用平方根分解可以轻松解决此问题 首先创建新的前缀和数组,对每个前缀和取3。 将整个数组划分为sqrt(n)个块。每个块将具有0、1、2的计数。同时创建一个临时数组,该数组将包含要添加到块元素中的总和 这是c ++中的实现:

#include <bits/stdc++.h>
using namespace std;
#define si(a) scanf("%d",&a)
#define sll(a) scanf("%lld",&a)
#define sl(a) scanf("%ld",&a)
#define pi(a) printf("%d\n",a)
#define pl(a) printf("%ld\n",a)
#define pll(a) printf("%lld\n",a) 
#define sc(a) scanf("%c",&a)
#define pc(a) printf("%c",a)
#define ll long long
#define mod 1000000007
#define w while
#define pb push_back
#define mp make_pair
#define f first
#define s second
#define INF INT_MAX
#define fr(i,a,b) for(int i=a;i<=b;i++)



///////////////////////////////////////////////////////////////
struct block
{
    int one;
    int two;
    int zero;
    block()
    {
        one=two=zero=0;
    }
};
ll a[100005],a1[100005];
ll sum[400];
int main()
{
    int n,m;
    cin>>n>>m;
    string s;
    cin>>s;
    int N=(int)(sqrt(n));
    struct block b[N+10];
    for(int i=0;i<n;i++)
    {
        a[i]=s[i]-'0';
        a[i]%=3;
        a1[i]=a[i];
    }
    for(int i=1;i<n;i++)
    a[i]=(a[i]+a[i-1])%3;
    for(int i=0;i<n;i++)
    {
        if(a[i]==0)
        b[i/N].zero++;
        else if(a[i]==1)
        b[i/N].one++;
        else
        b[i/N].two++;
    }
    w(m--)
    {
        int type;
        si(type);
        if(type==1)
        {
            int ind,x;
            si(ind);
            si(x);
            x%=3;
            ind--;
                int diff=(x-a1[ind]+3)%3;
                if(diff==1)
                {
                    int st=ind/N;
                    int end=(n-1)/N;
                    int kl=(st+1)*N;
                    int hj=min(n,kl);
                    for(int i=st*N;i<hj;i++)
                    {
                        a[i]=(a[i]+sum[st])%3;
                    }
                    sum[st]=0;
                    for(int i=ind;i<hj;i++)
                    {
                        if(a[i]==0)
                        b[st].zero--;
                        else if(a[i]==1)
                        b[st].one--;
                        else
                        b[st].two--;


                        a[i]=(a[i]+diff)%3;



                        if(a[i]==0)
                        b[st].zero++;
                        else if(a[i]==1)
                        b[st].one++;
                        else
                        b[st].two++;
                    }

                    for(int i=st+1;i<=end;i++)
                    {
                        int yu=b[i].zero;
                        b[i].zero=b[i].two;
                        b[i].two=b[i].one;
                        b[i].one=yu;
                        sum[i]=(sum[i]+diff)%3;
                    }
                }
                else if(diff==2)
                {


                    int st=ind/N;
                    int end=(n-1)/N;
                    int kl=(st+1)*N;
                    int hj=min(n,kl);
                    for(int i=st*N;i<hj;i++)
                    {
                        a[i]=(a[i]+sum[st])%3;
                    }
                    sum[st]=0;
                    for(int i=ind;i<hj;i++)
                    {
                        if(a[i]==0)
                        b[st].zero--;
                        else if(a[i]==1)
                        b[st].one--;
                        else
                        b[st].two--;


                        a[i]=(a[i]+diff)%3;



                        if(a[i]==0)
                        b[st].zero++;
                        else if(a[i]==1)
                        b[st].one++;
                        else
                        b[st].two++;
                    }

                    for(int i=st+1;i<=end;i++)
                    {
                        int yu=b[i].zero;
                        b[i].zero=b[i].one;
                        b[i].one=b[i].two;
                        b[i].two=yu;
                        sum[i]=(sum[i]+diff)%3;
                    }
                }

            a1[ind]=x%3;
        }
        else
        {
            int l,r;
            ll x=0,y=0,z=0;
            si(l);
            si(r);
            l--;
            r--;
            int st=l/N;
            int end=r/N;
            if(st==end)
            {
                for(int i=l;i<=r;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
            }
            else
            {
                for(int i=l;i<(st+1)*N;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
                for(int i=end*N;i<=r;i++)
                {
                    ll op=(a[i]+sum[i/N])%3;
                    if(op==0)
                    x++;
                    else if(op==1)
                    y++;
                    else 
                    z++;
                }
                for(int i=st+1;i<=end-1;i++)
                {
                    x+=b[i].zero;
                    y+=b[i].one;
                    z+=b[i].two;
                }
            }
            ll temp=0;
            if(l!=0)
            {
                temp=(a[l-1]+sum[(l-1)/N])%3;
            }
            ll ans=(x*(x-1))/2;
            ans+=((y*(y-1))/2);
            ans+=((z*(z-1))/2);
            if(temp==0)
            ans+=x;
            else if(temp==1)
            ans+=y;
            else
            ans+=z;
            pll(ans);
        }
    }
    return 0;
}
相关问题