【题解】[WC2021] 表达式求值

趣味妙题

题目链接:[WC2021] 表达式求值

题意:

给定 $m$ 个长度为 $n$ 的数组,现在有两种二元操作,表示两个数组每一位取 $\max$ 或 $\min$。

现给定一个包括该两个二元操作与 "$?$" 的表达式 $E$,求将 "$?$" 任意填写后,得到的所有数组的元素的和对 $10^9+7$ 取模。

$1\le n\le 5\times 10^4,1\le |E|\le 5\times 10^4,\ 1\le m\le 10$。

这题后半部分实际上是我高一军训期间无聊,出给机房同学的简单题,考场上瞬间联想到这个,然后被自己给震惊到。赞赞

考场只写了 $\text{70pts}$,但是实际上做法和正解只差一点了。 真的亏

首先肯定不能直接对着表达式做,所以建出表达式树。那么现在叶子上有若干值,每一个非叶子节点有一个 $\max,\min$ 或 "$?$" 的标记,表示对两个儿子做该操作。

发现每一位是独立的,因为二元操作都是按位操作的,所以枚举每一位,问题转化为叶子上是一个数,而非数组。

考虑一个错误的想法,如果考场上你以为这个答案可以利用单调性,然后二分一个数 $x$,那么就可以套路地将 $\ge x$ 的数写作 $1$,$<x$ 的部分写作 $0$,这样一个 $\max,\min$ 就变成了按位或、按位与。

但你发现这样会算重,因为二分的时候你不知道要给答案加入多少,所以不能二分。

考虑到 $m\le 10$,那就枚举好了。直接将某一位的 $m$ 个值从小到大排序,依次按照这 $m$ 个值进行划分。值的总和只需要进行简单容斥就可以轻松求出。

至于上述的计数 $\text{dp}$,设 $\text{dp}_{x,0/1}$ 为当前在 $x$ 点,数是 $0/1$ 的方案数,随便转移一下,最后答案就是 $\text{dp}_{\text{root},1}$,单次 $\mathcal{O}(|E|)$。

这样的复杂度是 $\mathcal{O}(nm|E|)$,加上没有”$?$“ 的暴力就可以拿到 $\text{70pts}$。由于本人当时还对 $\texttt{T1}$ 毫无思路,所以就直接走了,现在回想起来还是有些可惜的。

如何优化复杂度呢?每一位单独枚举是不太能优化的了,所以考虑从 $m$ 入手。

$m\le 10$,为什么不直接 $2^m$ 暴力枚举在某一种划分的方案下,每一个值是 $0/1$ 的方案数呢?预处理下来,之后每次查询就只需要 $\mathcal{O}(1)$ 了。

于是就可以通过本题,时间复杂度 $\mathcal{O}(2^m|E|+nm\log m)$。

$\text{Code:}$

//Code By CXY07
#include<bits/stdc++.h>
using namespace std;

//#define FILE
//#define int long long
#define debug(x) cout << #x << " = " << x << endl
#define file(FILENAME) freopen(FILENAME".in", "r", stdin), freopen(FILENAME".out", "w", stdout)
#define LINE() cout << "LINE = " << __LINE__ << endl
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define scd second
#define inv(x) qpow((x),mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define abs(x) ((x) < 0 ? (-(x)) : (x))
#define randint(l, r) (rand() % ((r) - (l) + 1) + (l))
#define vec vector

const int MAXN = 5e4 + 10;
const int MAXM = 11;
const int INF = 2e9;
const double PI = acos(-1);
const double eps = 1e-6;
const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;
//const int base = 131;

int n, m, cnt, top;
int a[MAXM][MAXN], opt[MAXN], st[MAXN];
int ls[MAXN], rs[MAXN];
int id[MAXM], p2[MAXM], sav[1 << MAXM];
LL dp[MAXN][2], Ans;
char stk[MAXN], s[MAXN], og[MAXN];
pii now[MAXM];

template<typename T> inline bool read(T &a) {
    a = 0; char c = getchar(); int f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
    a *= f;
    return 1;
}

void Modadd(LL &x, LL b) {
    x += b;
    if(x >= mod) x -= mod;
}

void build() {
    scanf("%s", og + 1);
    int len = strlen(og + 1), pos = 0;
    for(int i = 1; i <= len; ++i) {
        if('0' <= og[i] && og[i] <= '9') s[++pos] = og[i];
        else {
            if(og[i] == '<' || og[i] == '>' || og[i] == '?') {
                while(top && stk[top] != '(' && stk[top] != ')') s[++pos] = stk[top--];
                stk[++top] = og[i];
            } else {
                if(og[i] == '(') stk[++top] = '(';
                else {
                    while(top && stk[top] != '(') s[++pos] = stk[top--];
                    top--;
                }
            }
        }
    }
    while(top) if(stk[top] != '(' && stk[top] != ')') s[++pos] = stk[top--];
    top = 0;
    for(int i = 1; i <= pos; ++i) {
        cnt++;
        if('0' <= s[i] && s[i] <= '9') opt[cnt] = s[i] - '0' + 1, st[++top] = cnt;
        else {
            if(s[i] == '?') opt[cnt] = 11;
            else if(s[i] == '<') opt[cnt] = 12;
            else opt[cnt] = 13;
            ls[cnt] = st[top - 1], rs[cnt] = st[top];
            top = top - 2; st[++top] = cnt;
        }
    }
}

void calc(int x) {
    dp[x][0] = dp[x][1] = 0;
    if(opt[x] <= m) dp[x][id[opt[x]]] = 1;
    int u = ls[x], v = rs[x];
    if(!u || !v) return;
    calc(u), calc(v);
    if(opt[x] == 11 || opt[x] == 12) {
        Modadd(dp[x][0], dp[u][0] * dp[v][0] % mod);
        Modadd(dp[x][0], dp[u][0] * dp[v][1] % mod);
        Modadd(dp[x][0], dp[u][1] * dp[v][0] % mod);
        Modadd(dp[x][1], dp[u][1] * dp[v][1] % mod);
    } 
    if(opt[x] == 11 || opt[x] == 13) {
        Modadd(dp[x][0], dp[u][0] * dp[v][0] % mod);
        Modadd(dp[x][1], dp[u][0] * dp[v][1] % mod);
        Modadd(dp[x][1], dp[u][1] * dp[v][0] % mod);
        Modadd(dp[x][1], dp[u][1] * dp[v][1] % mod);
    }
}

inline int DP() {calc(cnt); return dp[cnt][1];}

inline void solve() {
    for(int S = 0; S < p2[m]; ++S) {
        for(int k = 0; k < m; ++k) id[k + 1] = (S & p2[k]) ? 1 : 0;
        sav[S] = DP();
    }
    for(int i = 1; i <= n; ++i) {
        for(int j = 1; j <= m; ++j) now[j] = mp(a[j][i], j);
        sort(now + 1, now + m + 1); 
        now[0] = mp(0, 0);
        int S = p2[m] - 1, pos = 1;
        for(int j = 1; j <= m; ++j) {
            while(pos < j && now[pos].fst < now[j].fst) S ^= p2[now[pos++].scd - 1];
            Modadd(Ans, 1ll * (now[j].fst - now[j - 1].fst) * sav[S] % mod);
        }
    }
    printf("%lld\n", Ans);
}

signed main () {
#ifdef FILE
    freopen("expr.in","r",stdin);
    freopen("expr.out","w",stdout);
#endif
    read(n), read(m);
    p2[0] = 1;
    for(int i = 1; i <= m; ++i) p2[i] = p2[i - 1] << 1;
    for(int i = 1; i <= m; ++i)
        for(int j = 1; j <= n; ++j) read(a[i][j]);
    build(); solve();
    return 0;
}
添加新评论