题目链接:[WC2016]论战捆竹竿

题意:

给定长度为 $n$ 的字符串 $S$,现有一个空串 $T$,每次可将 $S$ 去掉一个 $\text{border}$ 后接在 $T$ 上。问 $T$ 的长度可以是 $[n,w]$ 中的多少个数。
$1\le n\le 5\times 10^5,1\le w\le 10^{18}$。

每次去掉一个 $\text{border}$ 加入,等价于加入一个 $\text{period}$。设 $S$ 的 $\text{period}$ 长度构成集合 $X$,取出 $X$ 中一个元素 $x$。

借用同余最短路的思想,设 $\text{mindis}_i$ 表示 $\bmod\ x=i$ 的所有长度中,能被 $X$ 中元素所拼出的最小长度。则能够被拼出的,长度 $\bmod\ x=i$ 的串长为 $\text{mindis}_i+k\times x$。可以通过该数组快速求出答案。

则现在需要做的是加速同余最短路的过程,快速求出 $\text{mindis}_i$。

考虑以下结论:

一个字符串 $S$ 的 $\text{border}$ 长度构成 $\mathcal{O}(\log n)$ 个等差数列。

而 $\text{period}$ 对应 $\text{border}$,所以也是等差数列。这启发我们将一个等差数列中的 $\text{border}$ 放在一起考虑。

对于一个等差数列 $a,a+b,a+2b\cdots a+kb$,若当前已经得到之前的等差数列所求出的,$\bmod\ a$ 意义下的 $\text{mindis}_i$,考虑如何快速用一个等差数列更新该数组。

在 $\bmod\ a$ 意义下,公差为 $b$ 的等差数列将所有 $[0,a-1]$ 中的元素划分为 $\gcd(a,b)$ 个等价类,每个等价类互不干扰。

对于一个等价类,找出其中 $\text{mindis}$ 值最小的位置 $p$,显然在这轮更新中,$\text{mindis}_p$ 不改变。则从 $p$ 开始,每次考虑使用等差数列中一个元素来更新后面的元素。

具体来说,对于一个下标为 $q$ 的位置,若 $q-p\le k$,则可以使用 $\text{mindis}_p+(q-p)\times b$ 来更新 $\text{mindis}_q$。(此处将在 $\bmod\ a$ 意义下的环展开,并认为 $p$ 是其中第一个元素,则 $q>p$)

这显然可以使用单调队列来实现。

还剩下一个问题,求解完当前等差数列后,如何跳转到下一个等差数列?对于一个首项为 $a_1$ 的等差数列,求解完后 $\text{mindis}$ 是在 $\bmod\ a_1$ 的意义下,而下一个首项为 $a_2$ 的等差数列中需在 $\bmod\ a_2$ 意义下进行。

首先,可以使用 $\text{mindis}_x$ 转移到 $\text{mindis'}_{\text{mindis}_x\bmod a_2}$,同时,还需要考虑每一个长度为 $a_1$ 的转移,即 $\text{mindis}_x+k\times a_1$ 也可能更新 $\text{mindis'}$。这和上面是类似的问题,但是由于没有项数限制,所以甚至不用单调队列,记一个前缀 $\min$ 即可。

时间复杂度 $\mathcal{O}(T\times n\log n)$。

//Code By CXY07
#include<bits/stdc++.h>
using namespace std;

//#define FILE
#define int long long
#define file(FILENAME) freopen(FILENAME".in", "r", stdin), freopen(FILENAME".out", "w", stdout)
#define randint(l, r) (rand() % ((r) - (l) + 1) + (l))
#define LINE() cout << "LINE = " << __LINE__ << endl
#define debug(x) cout << #x << " = " << x << endl
#define abs(x) ((x) < 0 ? (-(x)) : (x))
#define min(a, b) (a < b ? a : b)
#define inv(x) qpow((x), mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define ull unsigned long long
#define pii pair<int, int>
#define LL long long
#define mp make_pair
#define pb push_back
#define scd second
#define vec vector
#define fst first
#define endl '\n'

const int MAXN = 5e5 + 10;
const int INF = 1.1e18;
const double PI = acos(-1);
const double eps = 1e-6;
//const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;
//const int base = 131;

int T, n, w, mod, Ans;
int fail[MAXN], period[MAXN], cnt;
int dis[MAXN];
char s[MAXN];

template<typename T> inline bool read(T &a) {
    a = 0; char c = getchar(); int f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
    a *= f;
    return 1;
}

template<typename A, typename ...B>
inline bool read(A &x, B &...y) {return read(x) && read(y...);}

int Gcd(int x, int y) {
    if(!y) return x;
    return Gcd(y, x % y);
}

void clear() {
    memset(dis, 0x3f, sizeof dis);
    memset(fail, 0, sizeof fail);
}

void GetPeriod() {
    fail[1] = 0, fail[0] = -1, cnt = 0;
    for(int i = 2, j; i <= n; ++i) {
        j = fail[i - 1];
        while((~j) && s[j + 1] != s[i]) j = fail[j];
        fail[i] = ++j;
    }
    for(int i = fail[n]; ~i; i = fail[i]) period[++cnt] = n - i;
} 

void _run(int p, int gap, int c) { // i % mod == p, mod...mod + gap * c
    static int t[MAXN], m, pos; m = 0, pos = -1;
    static int q[MAXN], head, tail; head = 1, tail = 0;
    for(int i = p; ;) {
        if(pos == -1 || dis[i] < dis[pos]) pos = i;
        i = (i + gap) % mod; if(i == p) break;
    }
    for(int i = pos; ;) {
        t[++m] = i; i = (i + gap) % mod;
        if(i == pos) break;
    } assert(m == mod / Gcd(gap, mod));
    q[++tail] = 1;
    for(int i = 2; i <= m; ++i) {
        while(head <= tail && i - q[head] > c) head++;
        if(head <= tail && i - q[head] <= c) dis[t[i]] = min(dis[t[i]], dis[t[q[head]]] + (i - q[head]) * gap + mod);
        while(head <= tail && dis[t[i]] - i * gap < dis[t[q[tail]]] - q[tail] * gap) tail--;
        q[++tail] = i;
    }
}

void run(int gap, int c) { // mod...mod + gap * c
    int lim = Gcd(gap, mod);
    for(int i = 0; i < lim; ++i) _run(i, gap, c);
}

void trans(int _new) { // mod -> _new
    static int _d[MAXN]; memset(_d, 0x3f, sizeof _d);
    static int t[MAXN], m;
    int lim = Gcd(mod, _new);
    for(int i = 0; i < mod; ++i) _d[dis[i] % _new] = min(_d[dis[i] % _new], dis[i]);
    for(int p = 0, pos, mn; p < lim; ++p) {
        pos = -1, m = 0, mn = INF;
        for(int i = p; ;) {
            if(pos == -1 || _d[i] < _d[pos]) pos = i;
            i = (i + mod) % _new; if(i == p) break;
        }
        for(int i = pos; ;) {
            t[++m] = i;
            i = (i + mod) % _new; if(i == pos) break;
        } assert(m == _new / lim);
        for(int i = 1; i <= m; ++i) {
            _d[t[i]] = min(_d[t[i]], mn + i * mod);
            mn = min(mn, _d[t[i]] - i * mod);
        }
    }
    mod = _new;
    memcpy(dis, _d, sizeof dis);
}

void calc() {
    int L = 1, R; mod = period[L];
    while(L < cnt) {
        R = L + 1;
        while(R + 1 <= cnt && period[R + 1] - period[R] == period[R] - period[R - 1]) R++;
        run(period[L + 1] - period[L], R - L);
        L = R + 1; if(L <= cnt) trans(period[L]);
    }
}

void solve() {
    Ans = 0;
    clear(); read(n), read(w), scanf("%s", s + 1); w -= n;
    GetPeriod(); dis[0] = 0; calc();
    for(int i = 0; i < mod; ++i)
        if(dis[i] <= w) Ans = Ans + (w - dis[i]) / mod + 1;
    printf("%lld\n", Ans);
}

signed main () {
#ifdef FILE
    freopen("P4156.in", "r", stdin);
    freopen("P4156.out", "w", stdout);
#endif
    read(T);
    while(T--) solve();
    return 0;
}

标签: 字符串, 同余最短路, 组合字符串

添加新评论