回滚莫队学习笔记

被这个东西把 WC 金牌搞没了……来学一下。

回滚莫队有两种,不删除莫队和不插入莫队,可以用来解决加点和删点只有一种可以在复杂度正确的时间内维护的情形。


1.不删除莫队

不删除莫队具体操作如下:

  1. 将所有询问按照左端点所在块升序,同一个块内右端点升序排序。
  2. 如果询问左右端点在同一个块内,特判,直接暴力。(否则右端点会左移导致需要删点)
  3. 如果左端点来到一个新的块,那么清空莫队,将指针放在这个块的右端点处。
  4. 然后左端点在这个块内的询问右端点是不降的,可以直接右移右端点。
  5. 每次询问都左移左端点,查询结束后将左端点恢复到当前块右端点的状态。

可以发现这样就避免了删点操作。

然后来分析复杂度:首先暴力的部分复杂度显然不会超过 O(nn)O(n\sqrt n)

然后是清空莫队的操作,这个一次 O(n)O(n),最多进行 O(n)O(\sqrt n) 次,也是 O(nn)O(n\sqrt n)

右移右端点操作,每个块最多 O(n)O(n),总的是 O(nn)O(n\sqrt n)

左移和重置左端点,每次最多 O(n)O(\sqrt n),全部询问加起来是 O(nn)O(n\sqrt n)

综上,复杂度是 O(nn)O(n\sqrt n) 的。


2.不插入莫队

类比不删除莫队:

  1. 将所有询问按照左端点所在块升序,同一个块内右端点降序排序。
  2. 不需要特判同一个块内的情况。(因为右端点不断左移不会影响)
  3. 到新的块之后将左端点移动到块的左端点,右端点移动到序列末尾。(当然这里需要支持能够在正确复杂度内完成插入序列并初始化的操作)
  4. 块内询问右端点不增,每次左移右端点。
  5. 对于每个询问右移左端点,查询结束后重置。

这样就避免了加点操作。

复杂度分析同理,也是 O(nn)O(n\sqrt n)


3.一些例题

(1) 「JOISC 2014 Day1」历史研究

给一个序列,询问区间数字出现次数与数字本身乘积的最大值。

加点容易维护,因为答案只可能被当前加入的点更新;但是删点不行,因为我们没法快速找出来删掉这个点之后的最大值。

于是套一个不删除莫队上去即可。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
int n,m,block,pos[100001],p[100001],sum[100001][2],node[100001],cnt,a[100001];
long long maxn,ans[100001];
struct element
{
    int l,r,id;
    bool operator <(const element &other) const
    {
        return pos[l]^pos[other.l]? pos[l]<pos[other.l]:r<other.r;
    }
}q[100001];
inline int read()
{
    int x=0;
    char c=getchar();
    while(c<'0'||c>'9')
        c=getchar();
    while(c>='0'&&c<='9')
    {
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }
    return x;
}
inline void insert(int x)
{
    ++sum[x][0];
    maxn=max(maxn,1ll*sum[x][0]*node[x]);
}
inline void del(int x)
{
    --sum[x][0];
}
int main()
{
    n=read(),m=read();
    block=sqrt(n);
    for(int i=1;i<=n;++i)
    {
        node[i]=a[i]=read();
        pos[i]=(i-1)/block+1;
        if(i!=1&&pos[i]!=pos[i-1])
            p[pos[i-1]]=i-1;
    }
    p[pos[n]]=n;
    sort(node+1,node+n+1);
    cnt=unique(node+1,node+n+1)-node-1;
    for(int i=1;i<=n;++i)
        a[i]=lower_bound(node+1,node+cnt+1,a[i])-node;
    for(int i=1;i<=m;++i)
    {
        q[i].l=read(),q[i].r=read();
        q[i].id=i;
    }
    sort(q+1,q+m+1);
    for(int i=1,l=1,r=0,lst=0;i<=m;++i)
    {
        if(pos[q[i].l]==pos[q[i].r])
        {
            for(int j=q[i].l;j<=q[i].r;++j)
            {
                ++sum[a[j]][1];
                ans[q[i].id]=max(ans[q[i].id],1ll*sum[a[j]][1]*node[a[j]]);
            }
            for(int j=q[i].l;j<=q[i].r;++j)
                --sum[a[j]][1];
            continue;
        }
        if(pos[q[i].l]^lst)
        {
            lst=pos[q[i].l];
            for(;l>p[pos[q[i].l]]+1;--l)
                insert(a[l-1]);
            for(;r<p[pos[q[i].l]];++r)
                insert(a[r+1]);
            for(;l<p[pos[q[i].l]]+1;++l)
                del(a[l]);
            for(;r>p[pos[q[i].l]];--r)
                del(a[r]);
            maxn=0;
        }
        for(;r<q[i].r;++r)
            insert(a[r+1]);
        long long tmp=maxn;
        for(;l>q[i].l;--l)
            insert(a[l-1]);
        ans[q[i].id]=maxn;
        maxn=tmp;
        for(;l<p[pos[q[i].l]]+1;++l)
            del(a[l]);
    }
    for(int i=1;i<=m;++i)
        cout<<ans[i]<<'\n';
    cout<<endl;
    return 0;
}

(2)mex

给一个序列,询问区间 mex。

首先忽略所有比 nn 大的数。

然后,删点是容易维护的,开桶维护每个数的数量,删除到 0 之后直接看能不能更新答案即可,但是加点不行,我们还是找不出最大值。

于是套一个不插入莫队即可。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
int n,m,a[200001],sum[200001][2],pos[200001],p[200001],block,ans[200001],mex;
struct element
{
    int l,r,id;
    bool operator <(const element &other) const
    {
        return pos[l]^pos[other.l]? pos[l]<pos[other.l]:r>other.r;
    }
}q[200001];
inline int read()
{
    int x=0;
    char c=getchar();
    while(c<'0'||c>'9')
        c=getchar();
    while(c>='0'&&c<='9')
    {
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }
    return x;
}
inline void insert(int x)
{
    if(x<=n)
        ++sum[x][0];
}
inline void del(int x)
{
    if(x<=n&&!--sum[x][0])
        mex=min(mex,x);
}
int main()
{
    n=read(),m=read();
    block=sqrt(n);
    for(int i=1;i<=n;++i)
    {
        a[i]=read();
        pos[i]=(i-1)/block+1;
        if(pos[i]!=pos[i-1])
            p[pos[i]]=i;
    }
    for(int i=1;i<=m;++i)
    {
        q[i].l=read(),q[i].r=read();
        q[i].id=i;
    }
    sort(q+1,q+m+1);
    for(int i=1,l=1,r=0,lst=0;i<=m;++i)
    {
        if(lst^pos[q[i].l])
        {
            lst=pos[q[i].l];
            for(;l>p[pos[q[i].l]];--l)
                insert(a[l-1]);
            for(;r<n;++r)
                insert(a[r+1]);
            for(;l<p[pos[q[i].l]];++l)
                del(a[l]);
            for(int j=0;j<=n;++j)
                if(!sum[j][0])
                {
                    mex=j;
                    break;
                }
        }
        for(;r>q[i].r;--r)
            del(a[r]);
        int tmp=mex;
        for(;l<q[i].l;++l)
            del(a[l]);
        ans[q[i].id]=mex;
        for(;l>p[pos[q[i].l]];--l)
            insert(a[l-1]);
        mex=tmp;
    }
    for(int i=1;i<=m;++i)
        cout<<ans[i]<<'\n';
    return 0;
}

(3)【WC2022】 秃子酋长

把我送走的题。

首先有一个 naive 的 O(nnlogn)O(n\sqrt n\log n) 的做法,就是跑普通莫队,然后用 set 维护数字集合,每次插入和删除都只会改变 O(1)O(1) 组相邻关系,直接维护即可,这样可以拿到 5050 分,也是我的考场做法。

考虑去掉 log\log。我们发现如果我们用链表就不带 log\log,但是链表只能支持快速删除,无法支持快速插入,因为插入时要遍历整个链表。

那么考虑不插入莫队,每次移动左端点的时候记录一下要还原的点的链表信息即可,细节稍微多一些。

据说常数小就过了,还没有数据,不知道行不行。

#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
int n,m,a[500001],p[500001],block,pos[500001],s[500001],pre[500001],nxt[500001],sum[500001],h[500001][2];
long long ans[500001],res;
struct element
{
    int l,r,id;
    bool operator <(const element &other) const
    {
        return pos[l]^pos[other.l]? pos[l]<pos[other.l]:r>other.r;
    }
}q[500001];
inline int read()
{
    int x=0;
    char c=getchar();
    while(c<'0'||c>'9')
        c=getchar();
    while(c>='0'&&c<='9')
    {
        x=(x<<1)+(x<<3)+(c^48);
        c=getchar();
    }
    return x;
}
inline void print(long long x)
{
    if(x>=10)
        print(x/10);
    putchar(x%10+'0');
}
inline int Abs(int x)
{
    return x>=0? x:-x;
}
inline void insert(int x)
{
    ++sum[x];
}
inline void del(int x)
{
    --sum[x];
    if(pre[x])
    {
        res-=Abs(p[pre[x]]-p[x]);
        nxt[pre[x]]=nxt[x];
    }
    if(nxt[x])
    {
        res-=Abs(p[nxt[x]]-p[x]);
        pre[nxt[x]]=pre[x];
    }
    if(pre[x]&&nxt[x])
        res+=Abs(p[pre[x]]-p[nxt[x]]);
    pre[x]=nxt[x]=0;
}
int main()
{
    n=read(),m=read();
    block=sqrt(n);
    for(int i=1;i<=n;++i)
    {
        p[a[i]=read()]=i;
        pos[i]=(i-1)/block+1;
        if(pos[i]^pos[i-1])
            s[pos[i]]=i;
    }
    for(int i=1;i<=m;++i)
    {
        q[i].l=read(),q[i].r=read();
        q[i].id=i;
    }
    sort(q+1,q+m+1);
    for(int i=1,l=1,r=0,lst=0;i<=m;++i)
    {
        if(lst^pos[q[i].l])
        {
            lst=pos[q[i].l];
            for(;l>s[pos[q[i].l]];--l)
                insert(a[l-1]);
            for(;r<n;++r)
                insert(a[r+1]);
            for(;l<s[pos[q[i].l]];++l)
                del(a[l]);
            int k=0;
            res=0;
            for(int j=1;j<=n;++j)
            {
                pre[j]=nxt[j]=0;
                if(sum[j])
                {
                    pre[j]=k;
                    if(k)
                    {
                        nxt[k]=j;
                        res+=Abs(p[j]-p[k]);
                    }
                    k=j;
                }
            }
        }
        for(;r>q[i].r;--r)
            del(a[r]);
        long long tmp=res;
        for(int j=l;j<q[i].l;++j)
        {
            h[j][0]=pre[a[j]];
            h[j][1]=nxt[a[j]];
        }
        for(;l<q[i].l;++l)
            del(a[l]);
        ans[q[i].id]=res;
        for(;l>s[pos[q[i].l]];--l)
        {
            pre[a[l-1]]=h[l-1][0];
            nxt[a[l-1]]=h[l-1][1];
            if(pre[a[l-1]])
                nxt[pre[a[l-1]]]=a[l-1];
            if(nxt[a[l-1]])
                pre[nxt[a[l-1]]]=a[l-1];
            h[l-1][0]=h[l-1][1]=0;
            insert(a[l-1]);
        }
        res=tmp;
    }
    for(int i=1;i<=m;++i)
    {
        print(ans[i]);
        putchar('\n');
    }
    return 0;
}

怄火。挥手。转圈。街舞。跳跳。献吻。跳绳。激动。发抖。磕头。爱情。飞吻。左太极。右太极。回头。