【题解】Luogu-P4389 付公主的背包

好题!

题目链接:P4389 付公主的背包

给定 $n$ 种物品,第 $i$ 种体积为 $v_i$,有无限件。

对于 $s\in [1,m]$,求出用这些物品装恰好 $s$ 体积的方案数。答案对 $998244353$ 取模。

$1\le n,m\le 10^5, 1\le v_i\le m$。

首先有最简单的背包做法,但那并不重要。

由于信仰模数,并且对于 $[1,m]$ 中每个数都要求解,所以考虑多项式方法。

如果我们设第 $i$ 种物品的生成函数为:

$$G_i(x)=\sum_{i=0}^{\infty} x^{iv_i}$$

化成封闭形式:

$$G_i(x)=\dfrac{1}{1-x^{v_i}}$$

那么就是把他们全部卷起来:

$$F(x)=\prod_{i=1}^n G_i(x)$$

那背包大小为 $s$ 的答案就是 $[x^s]F(x)$。

如果直接上 $\texttt{NTT}$,复杂度是 $\mathcal{O}(nm\log m)$ 的,无法通过。但是由于这个多项式非常有特点,考虑用别的方法。

众所周知 我们可以对每个 $G_i(x)$ 做多项式 $\ln$,这样就可以把卷积转换成更为简单的系数对应相加了。最后再 $\exp$ 回去就可以了。但是对一个多项式做 $\ln$ 也是 $\mathcal{O}(m\log m)$ 的,总复杂度没变。

求一下上面那个东西的 $\ln$:

$$\ln \dfrac{1}{1-x^v}=\ln 1-\ln(1-x^v)=-\ln(1-x^v)$$

这时候考虑在推导多项式 $\ln$ 时候得到的柿子:

$$(\ln(1-x^v))'= \dfrac{(1-x^v)'}{(1-x^v)}$$

又 $\dfrac{1}{(1-x^v)}=\sum_{i=0}^{\infty} x^{iv}$,所以可以写成:

$$=(-vx^{v-1})\times \sum_{i=0}^{\infty} x^{iv}$$

$$=\sum_{i=1}^{\infty} -vx^{iv-1}$$

积分回去:

$$\int (\sum_{i=1}^{\infty} -vx^{iv-1})=-\sum_{i=1}^{\infty} \dfrac{x^{iv}}{i}$$

所以

$$\ln G(x)=0-(-\sum_{i=1}^{\infty} \dfrac{x^{iv}}{i})$$

$$=\sum_{i=1}^{\infty} \dfrac{x^{iv}}{i}$$

然后就可以 $\mathcal{O}(\log m)$ 求出一个 $G(x)$ 的 $\ln$。最后再 $\exp$ 回去即可。

注意对于相同的 $v_i$ 需要计数然后一次处理,否则调和级数的复杂度就假了。

时间复杂度 $\mathcal{O}(n\log n)$。

//Code By CXY07
#include<bits/stdc++.h>
using namespace std;

//#define FILE
#define int long long
//#define ull unsigned long long
#define LL long long
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define scd second
#define inv(x) qpow((x),mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define vec vector

const int MAXN = 3e5 + 10;
const int INF = 2e9;
const double PI = acos(-1);
const double eps = 1e-6;
//const int mod = 1e9 + 7;
const int mod = 998244353;
const int GG = 3;
//const int base = 131;

int n, m, Gi;
int lim, l, r[MAXN];
int F[MAXN], v[MAXN], ExpF[MAXN];

template<typename T> inline void read(T &a) {
    a = 0; char c = getchar(); int f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
    a *= f;
}

int qpow(int x, int b) {
    int res = 1;
    for(; b; b >>= 1, (x *= x) %= mod) if(b & 1) (res *= x) %= mod;
    return res;
}

void Init(int n) {
    lim = 1, l = 0;
    for(; lim <= n; lim <<= 1, l++);
    for(int i = 0; i < lim; ++i)
        r[i] = ((r[i >> 1] >> 1) | ((i & 1) << (l - 1)));
}

void NTT(int *A, int opt) {
    for(int i = 0; i < lim; ++i)
        if(r[i] > i) swap(A[r[i]], A[i]);
    for(int mid = 1, w, now, x, y; mid < lim; mid <<= 1) {
        w = qpow(opt == 1 ? GG : Gi, (mod - 1) / (mid << 1));
        for(int i = 0; i < lim; i += (mid << 1)) {
            now = 1;
            for(int j = 0; j < mid; ++j, (now *= w) %= mod) {
                x = A[i + j], y = A[i + j + mid] * now % mod;
                A[i + j] = (x + y) % mod;
                A[i + j + mid] = ((x - y) % mod + mod) % mod;
            }
        }
    }
    if(opt == -1) {
        int ilim = inv(lim);
        for(int i = 0; i < lim; ++i) A[i] = A[i] * ilim % mod;
    }
}

void Inv(int *F, int *G, int len) {
    static int Tmp[MAXN];
    if(len == 1) return (void) (G[0] = inv(F[0]));
    Inv(F, G, (len + 1) >> 1);
    Init(len << 1);
    for(int i = 0; i < len; ++i) Tmp[i] = F[i];
    for(int i = len; i < lim; ++i) Tmp[i] = 0;
    NTT(Tmp, 1); NTT(G, 1);
    for(int i = 0; i < lim; ++i) G[i] = ((2 - Tmp[i] * G[i] % mod) % mod * G[i] % mod + mod) % mod;
    NTT(G, -1);
    for(int i = len; i < lim; ++i) G[i] = 0;
    for(int i = 0; i < lim; ++i) Tmp[i] = 0;
}

void Ln(int *F, int *G, int len) {
    static int A[MAXN], B[MAXN];
    Inv(F, B, len);
    Init(len << 1);
    for(int i = 1; i < len; ++i) A[i - 1] = F[i] * i % mod;
    A[len - 1] = 0;
    NTT(A, 1); NTT(B, 1);
    for(int i = 0; i < lim; ++i) A[i] = A[i] * B[i] % mod;
    NTT(A, -1);
    for(int i = len; i < lim; ++i) A[i] = 0;
    for(int i = len - 1; i >= 1; --i) G[i] = A[i - 1] * inv(i) % mod;
    G[0] = 0;
    for(int i = 0; i < lim; ++i) A[i] = B[i] = 0;
} 

void Exp(int *F, int *G, int len) {
    static int ln[MAXN];
    if(len == 1) return (void)(G[0] = 1);
    Exp(F, G, (len + 1) >> 1);
    for(int i = 0; i < len; ++i) ln[i] = 0;
    Ln(G, ln, len);
    Init(len << 1);
    for(int i = 0; i < len; ++i) ln[i] = ((F[i] - ln[i]) % mod + mod) % mod;
    for(int i = len; i < lim; ++i) ln[i] = G[i] = 0;
    ln[0]++;
    NTT(ln, 1); NTT(G, 1);
    for(int i = 0; i < lim; ++i) G[i] = G[i] * ln[i] % mod;
    NTT(G, -1);
    for(int i = len; i < lim; ++i) G[i] = 0;
    for(int i = 0; i < lim; ++i) ln[i] = 0;     
}

signed main () {
#ifdef FILE
    freopen(".in","r",stdin);
    freopen(".out","w",stdout);
#endif
    read(n); read(m); Gi = inv(GG);
    for(int i = 1, tmp; i <= n; ++i) read(v[i]);
    sort(v + 1, v + n + 1);
    for(int i = 1, tot = 1; i <= n; ++i) {
        if(v[i] != v[i + 1]) {
            for(int j = v[i], cnt = 1; j <= m; j += v[i], cnt++)
                (F[j] += (inv(cnt) * tot % mod) % mod) %= mod;
            tot = 1;
        } else tot++;
    }
    Exp(F, ExpF, m + 1);
    for(int i = 1; i <= m; ++i)
        printf("%lld\n", (ExpF[i] % mod + mod) % mod);
    return 0;
}
添加新评论