标签 单调栈 下的文章

保序回归入门题

题目链接:P7294 [USACO21JAN] Minimum Cost Paths P

题意:

现有 $n\times m$ 的一个矩形,第 $i$ 列有一个代价 $c_i$。如果现在你在 $(x,y)$,可以花费 $x^2$ 的代价到 $(x,y+1)$,或 $c_y$ 的代价到 $(x+1,y)$。

现有 $q$ 次询问,每次给定 $(x_i,y_i)$,求从 $(1,1)$ 到 $(x_i,y_i)$ 的最小代价。

$2\le n\le 10^9,2\le m,q\le 10^5$。

本题解提供两种做法,第一种是 $\text{B}\red{\text{enq}}$ 给出的正解,还有一种是更为无脑的保序回归。由于作者很懒所以只写了第二种的代码

听神仙zzm说本质是一样的 代码还差不多

Solution 1

本部分参照 $\text{B}\red{\text{enq}}$ 在 $\text{USACO}$ 发布的官方题解。

令从 $(1,1)$ 到 $(x,y)$ 的最小代价为 $\text{ans}_y(x)$,那么能够发现,对于一个固定的 $y$,其函数值随着 $x$ 单调增且 $\text{ans}_{y}(x)-\text{ans}_{y}(x-1)\le \text{ans}_{y}(x+1)-\text{ans}_{y}(x)$。可以这样解释:

考虑从 $\text{ans}_y(x)$ 转移到 $\text{ans}_{y+1}(x)$ 的时候:

  1. 先令 $\text{ans}_{y+1}(x)=\text{ans}_y(x)+x^2$。
  2. 接着令 $\text{ans}_{y+1}(x)=\min(\text{ans}_{y+1}(x),\text{ans}_{y+1}(x-1)+c_y)$。

对于操作 $2$,相当于拿一条斜率为 $c_y$ 的直线去切函数图像,然后把过高的部分砍下来。同时,看到 $1$ 中 $\text{ans}_{y+1}(x)$ 上加了 $x^2$ 就可以大概发现 $\text{ans}_{y}(x)-\text{ans}_{y}(x-1)\le \text{ans}_{y}(x+1)-\text{ans}_{y}(x)$ 的规律。

那么现在可以使用一个栈来维护这样一个类似凸壳的东西,每次在栈里二分找到被直线给砍掉的位置,然后加入一个新点。至于每次的 $1$ 操作,也就是 $+x^2$,可以记录一个时间戳,然后和当前时间戳作差。

是的代码咕了

Solution 2

本部分来自 $\text{zxyhymzg}$ 考场想法。

萌新刚学保序回归 有锅请轻喷

令 $s_i$ 为我们在第 $i$ 行处走到 $(i,s_i)$,其中 $s_0=1$。那么显然有 $s_{i-1}\le s_i$,因为只能向右或者向下走。

计算一下代价,对于从 $(1,1)$ 到 $(x,y)$,$\forall s_i\le x$,代价为:

$$\sum_{i=1}^{y-1} s_i^2+(s_i-s_{i-1})\times c_i$$

后面的部分交换和式,把一项变成只和 $s_i$ 相关:

$$x\times c_y-c_1+\sum_{i=1}^{y-1} s_i^2+(c_i-c_{i+1})\times s_i$$

和式外面是常数,那么只要最小化和式内部。发现是二次函数,先写成顶点式:

$$(s_i-\dfrac{c_{i+1}-c_i}{2})^2-\dfrac{(c_{i+1}-c_i)^2}{4}$$

后面的也是常数,不管。变成最小化:

$$\sum_{i=1}^{y-1} (s_i-\dfrac{c_{i+1}-c_i}{2})^2$$

$\dfrac{c_{i+1}-c_i}{2}$ 是常数,那么就变成了经典的保序回归中的特殊的 $L_2$ 问题,在 $\text{IOI2018}$ 国家候选队论文集中有写到。

需要注意的是,现在对于一个固定的 $y$,发现其 $s_i$ 是固定的。而我们要求 $1\le s_i\le x$,实际上可能会不满足。可以发现将 $s_i$ 向 $[1,x]$ 取整,同时浮点数四舍五入即可。

那么同样的,在保序回归该特殊情况的贪心求解中,我们需要维护一个单调栈,因此对于向 $[1,x]$ 取整的操作可以在单调栈上二分,取整后一段区间的贡献是可以 $\mathcal{O}(1)$ 计算的。

至此即可解决本题,时间复杂度 $\mathcal{O}(n\log n)$,实际上好像效率海星?

$\text{Code:}$

//Code By CXY07
#include<bits/stdc++.h>
using namespace std;

//#define FILE
#define int long long
#define debug(x) cout << #x << " = " << x << endl
#define file(FILENAME) freopen(FILENAME".in", "r", stdin), freopen(FILENAME".out", "w", stdout)
#define LINE() cout << "LINE = " << __LINE__ << endl
#define LL long long
#define ull unsigned long long
#define pii pair<int,int>
#define mp make_pair
#define pb push_back
#define fst first
#define scd second
#define inv(x) qpow((x),mod - 2)
#define lowbit(x) ((x) & (-(x)))
#define abs(x) ((x) < 0 ? (-(x)) : (x))
#define randint(l, r) (rand() % ((r) - (l) + 1) + (l))
#define vec vector

const int MAXN = 2e5 + 10;
const int INF = 2e9;
const double PI = acos(-1);
const double eps = 1e-6;
//const int mod = 1e9 + 7;
//const int mod = 998244353;
//const int G = 3;
//const int base = 131;

struct Query {
    int x, y, id;
    Query(int _x = 0, int _y = 0, int _id = 0) : x(_x), y(_y), id(_id) {}
    bool operator < (const Query &b) const {return y < b.y;}
} q[MAXN];

int n, m, qs, top;
int c[MAXN], sum[MAXN], stk[MAXN];
//为了尽量避免浮点数计算,sum数组中先不将c_{i+1}-c_i除2
int Ans[MAXN], num[MAXN], val[MAXN];

template<typename T> inline bool read(T &a) {
    a = 0; char c = getchar(); int f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') {a = a * 10 + (c ^ 48); c = getchar();}
    a *= f;
    return 1;
}

template<typename A, typename ...B>
inline bool read(A &x, B &...y) {return read(x) && read(y...);}

double aver(int l, int r) {
    return 1. * (sum[r] - sum[l - 1]) / (2. * (r - l + 1));
}

int calc(int l, int r, int x) {
    if(l > r) return 0;
    return x * x * (r - l + 1) - x * (sum[r] - sum[l - 1]);
}

void Insert(int x) {
    while(top && aver(stk[top - 1] + 1, stk[top]) + eps > aver(stk[top] + 1, x)) top--;
    stk[++top] = x;
    num[top] = (int)round(aver(stk[top - 1] + 1, stk[top]));
    val[top] = val[top - 1] + calc(stk[top - 1] + 1, stk[top], num[top]);
}

signed main () {
#ifdef FILE
    freopen(".in","r",stdin);
    freopen(".out","w",stdout);
#endif
    read(n), read(m);
    for(int i = 1; i <= m; ++i) read(c[i]);
    for(int i = 1; i <= m; ++i) sum[i] = sum[i - 1] + c[i + 1] - c[i];
    read(qs);
    for(int i = 1, x, y; i <= qs; ++i) {
        read(x), read(y);
        q[i] = Query(x, y, i);
    }
    sort(q + 1, q + qs + 1);
    int nowy = 0;
    for(int p = 1; p <= qs; ++p) {
        if(q[p].y == 1) {
            Ans[q[p].id] = (q[p].x - 1) * c[1];
            continue;
        }
        while(nowy + 1 < q[p].y) {
            nowy++;
            Insert(nowy);
        }
        Ans[q[p].id] = val[top] + q[p].x * c[q[p].y] - c[1];
        int L = 0, R = top, mid;
        while(L < R) {
            mid = (L + R + 1) >> 1;
            if(num[mid] > q[p].x) R = mid - 1;
            else L = mid;
        }
        Ans[q[p].id] += calc(stk[L] + 1, nowy, q[p].x) - (val[top] - val[L]);
        L = 1, R = top + 1;
        while(L < R) {
            mid = (L + R) >> 1;
            if(num[mid] <= 0) L = mid + 1;
            else R = mid;
        }
        Ans[q[p].id] += calc(1, stk[L - 1], 1) - val[L - 1];
    }
    for(int i = 1; i <= qs; ++i)
        printf("%lld\n", Ans[i]);
    return 0;
}