杭电多校第一场


A: Blank

题意

有 $n (n \leq 100)$ 个格子,向其中填入 $0、1、2、3 $这$4$个数,但是有 $m ( m ≤ 100)$ 个限制

限制 $l$ $ r$ $x$ :表示 $l ~ r$ 的格子内不同的数的个数为$x$

要求满足所有限制的方案有多少种?

思路

我们首先设$dp[i][j][k][r]$为这$0,1,2,3$四个数字的最后一次出现的位置,$dp$值为方案数

那么转移可以这样写一下:

$dp[cur][j][k][r] += dp[i][j][k][r], dp[i][cur][k][r] += dp[i][j][k][r]$

$dp[i][j][cur][r] += dp[i][j][k][r], dp[i][j][k][cur] += dp[i][j][k][r]$

因为$i,j,k,r$互不相同, 且当位一定为一个数字并且相互之间有大小顺序,那么我们把$dp$按照大小来转移的话

还是$dp[cur][i][j][k]$ 其中$cur \geq i \geq j \geq k$

那么转移就变成

$dp[cur+1][i][j][k]+=dp[cur][i][j][k], dp[cur+1][cur][j][k] += dp[cur][i][j][k]$

$dp[cur+1][cur][i][k] += dp[cur][i][j][k], dp[cur+1][cur][i][j] += dp[cur][i][j]$

我们不必要区分$0,1,2,3$对应的是哪一个,因为这对结果没影响

这样的$dp$数组太大,我们可以用滚动数组来优化一下空间

AC代码

实测$dp$数组降序会$T$,可能是因为$dp$过程中地址变换太大造成超时.

#include<bits/stdc++.h>
using namespace std;

#define ll long long
const int maxn = 1e2 + 7;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
typedef pair<int, int> pis;

vector<pis> lo[maxn];

ll dp[maxn][maxn][maxn][2];
//dp[i][j][k][cur] 升序
void add(ll &a, ll b) {
    a = a + b;
    if(a > mod) a -= mod;
    if(a < 0) a += mod;
}
int main() { 
    int t;
    scanf("%d", &t);
    while(t --) {
        int n, m;
        scanf("%d %d", &n, &m);
        for (int i = 1; i <= n; i ++) {
            lo[i].clear();
            lo[i].push_back(pis{i, 1});
        }
        for (int i = 1; i <= m; i ++) {
            int l, r, x;
            scanf("%d %d %d", &l, &r, &x);
            lo[r].push_back(pis{l, x});
        }
        memset(dp, 0, sizeof(dp));
        dp[0][0][0][0] = 1;
        for (int cur = 1; cur <= n; cur ++) {
            int np = cur & 1;
            for (int i = 0; i <= cur; i ++) 
                for (int j = i; j <= cur; j ++) 
                    for (int k = j; k <= cur; k ++) 
                        dp[i][j][k][np] = 0;
        
            for (int i = 0; i <= cur; i ++) 
                for (int j = i; j <= cur; j ++) 
                    for (int k = j; k <= cur; k ++) {
                        /*add(dp[i][k][cur-1][np], dp[i][j][k][np^1]);
                        地址跨越比add(dp[np][cur-1][k][i], dp[np^1][k][j][i]);
                        要大,可能是造成超时的原因
                        */
                        add(dp[j][k][cur-1][np], dp[i][j][k][np^1]);
                        add(dp[i][k][cur-1][np], dp[i][j][k][np^1]);
                        add(dp[i][j][cur-1][np], dp[i][j][k][np^1]);
                        add(dp[i][j][k][np], dp[i][j][k][np^1]);
                    }

            for (int i = 0; i <= cur; i ++) 
                for (int j = i; j <= cur; j ++) 
                    for (int k = j; k <= cur; k ++) 
                        for (pis it: lo[cur]) {
                            int l = it.first, r = cur, x = it.second;
                            int cnt = (i >= l) + (j >= l) + (k >= l) + 1;
                            if(cnt != x) dp[i][j][k][np] = 0;
                        }
                        
        }
        ll ans = 0;
        for (int i = 0; i <= n; i ++) 
            for (int j = i; j <= n; j ++) 
                for (int k = j; k <= n; k ++) add(ans, dp[i][j][k][n&1]);
        printf("%lld\n", ans);
    }
    return 0;
}

L: Sequence

题意

给一个长度为n的数组,有m次操作,操作有3种,给一个x,每次改变序列的值$b_i=\sum\limits_{j=i-k*x}a_j$

求改变完了的序列的$(i\times a[i])$值的异或和

思路

通过打表观察可以发现,一种操作多次操作就是把序列$a$和组合数序列进行卷积,然后就直接用ntt就行了

AC代码

#include<bits/stdc++.h>
using namespace std;

#define ll long long
const int maxn = 5e5 + 7;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
typedef pair<int, int> pis;
#define g 3
#define Mod(x) ((x)>=mod?(x)-mod:(x))

ll rnk[maxn];
ll a[maxn], b[maxn];

ll Ksm(ll a, ll b) {
    ll res = 1;
    while(b) {
        if(b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}

ll Fac[1000005], inv[1000005];

void FacPre() {
    inv[0] = Fac[0] = 1;
    for (int i = 1; i <= 1000000; i ++)
        Fac[i] = 1ll * Fac[i-1] * i % mod;
    inv[1000000] = Ksm(Fac[1000000], mod-2);
    for (int i = 999999; i >= 1; i --)
        inv[i] = 1ll * inv[i+1] * (i+1) % mod;
}

ll C(int n, int m) {
    if(m > n) return 0;
    return 1ll * Fac[n] * inv[m] % mod * inv[n-m] % mod;
}

void ntt(long long *a, int op, int n) {
    for (int i = 0; i < n; i ++) 
        if(i < rnk[i]) swap(a[i], a[rnk[i]]);
        for (int i = 2; i <= n; i <<= 1) {
            int nw = Ksm(g, (mod-1)/i);
            if(op == -1) nw = Ksm(nw, mod-2);
            for (int j = 0, m = i >> 1; j < n; j += i) 
                for (int k = 0, w = 1; k < m; k ++) {
                    int t = 1ll * a[j+k+m] * w % mod;
                    a[j+k+m] = Mod(a[j+k]-t+mod);
                    a[j+k] = Mod(a[j+k]+t);
                    w = 1ll * w * nw % mod;
                }
        }
        if(op == -1) 
            for (int i = 0, inv = Ksm(n, mod-2); i < n; i ++)
                a[i] = 1ll * a[i] * inv % mod;
}

void solve(ll *a, ll *b, int len) {
    int n = 1, lim = 0;
    while(n <= len + len) n <<= 1, lim++;
    for (int i = 0; i < n; i ++)
        rnk[i] = (rnk[i>>1]>>1) | ((i&1) << (lim-1));
    ntt(a, 1, n); ntt(b, 1, n);
    for (int i = 0; i < n; i ++)
        a[i] = (1ll * a[i] * b[i]) % mod;
    ntt(a, -1, n);
    for (int i = len; i < n; i ++) a[i] = 0;
} 

int cnt[5];

int main() { 
    FacPre();
    int t;
    scanf("%d", &t);
    while(t --) {
        memset(cnt, 0, sizeof(cnt));
        int n, m;
        scanf("%d %d", &n, &m);
        for (int i = 0; i < n; i ++)
            scanf("%lld", &a[i]);
        for (int i = 1, op; i <= m; i ++) {
            scanf("%d", &op);
            cnt[op] ++;
        }
        for (int i = 1; i <= 3; i ++) {
            memset(b, 0, sizeof(b));
            for (int j = 0; j * i < n; j ++) 
                b[j*i] = C(cnt[i]-1+j, j);
            if(cnt[i] == 0) b[0] = 1;
            solve(a, b, n);
        }
        ll ans = 0;
        for (int i = 0; i < n; i ++) ans = ans ^ (1ll * (i+1) * a[i]);
        printf("%lld\n", ans);
    } 
    return 0;
}

文章作者: Mug-9
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Mug-9 !
评论
  目录