Search for a command to run...
Progressive hints first, then the full explanation and implementation when you're ready to cash out.
Review status
AI-generated and still unreviewed. Double-check the details before internalizing them.
Hints
Open only as much as you need to keep the solve alive.
For a fixed query range $[l, r]$, split the entries into fixed numbers and holes ($a_i = -1$). If the fixed sum is $S$, then the holes must receive exactly $M = m - S$ total mass. If $M < 0$, the answer is already $0$.
Look at one prefix position $i$. Its prefix sum is not mysterious: it is
$$P_i = A_i + X_{u_i},$$
where $A_i$ is the sum of fixed values in this prefix, $u_i$ is the number of holes in this prefix, and $X_{u_i}$ is the sum assigned to the first $u_i$ holes.
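Over all weak compositions of $M$ into $K$ holes, the random variable $X_u$ (the total placed in the first $u$ holes) has very simple moments:
$$\mathbb{E}[X_u] = \frac{Mu}{K}$$
and
$$\mathbb{E}[X_u^2] = \frac{M(M+K)}{K(K+1)}\,u + \frac{M(M-1)}{K(K+1)}\,u^2.$$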
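So a range query only needs aggregate prefix statistics:
$$\sum_i A_i^2,\qquad \sum_i A_i u_i,\qquad \sum_i u_i,\qquad \sum_i u_i^2,$$
plus the total fixed sum and the number of holes (and $\sum_i A_i$, which the merge step needs). These statistics are mergeable, because concatenating two segments just shifts every prefix of the right segment by the left segment's totals $(S_L, K_L)$.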
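For $K > 0$, with $W = \binom{M+K-1}{K-1}$, the answer is exactly
$$W\left(\sum_i A_i^2 + \frac{2M}{K}\sum_i A_i u_i + \frac{M(M+K)}{K(K+1)}\sum_i u_i + \frac{M(M-1)}{K(K+1)}\sum_i u_i^2\right).$$
Maintain those values in a segment tree. That's the whole beast.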
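Consider one query range $[l, r]$.
Let:
$$S = \sum_{\substack{l \le i \le r \\ a_i \ne -1}} a_i,\qquad K = \#\{\, l \le i \le r : a_i = -1 \,\},\qquad M = m - S.$$
If $M < 0$, no valid sequence exists.
If $K = 0$, there is only one possible sequence, and it is valid only when $M = 0$.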
Now assume $K > 0$ (and $M \ge 0$). The holes receive a weak composition
$$x_1 + x_2 + \dots + x_K = M,\qquad x_j \ge 0.$$
The number of such compositions is
$$\binom{M+K-1}{K-1}.$$
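For every position $i$ inside the range, define:
$$A_i = \sum_{\substack{l \le j \le i \\ a_j \ne -1}} a_j,\qquad u_i = \#\{\, l \le j \le i : a_j = -1 \,\}.$$
For a concrete filling of the holes, let
$$X_i = x_1 + x_2 + \dots + x_{u_i}.$$
Then the actual prefix sum at position $i$ is
$$P_i = A_i + X_i.$$
So we need to sum
$$\sum_{i=l}^{r} P_i^2 = \sum_{i=l}^{r} A_i^2 + 2\sum_{i=l}^{r} A_i X_i + \sum_{i=l}^{r} X_i^2$$
over all weak compositions.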
This means we only need the first two moments of $X_i$ over all weak compositions. Nice. The combinatorics does the heavy lifting; we just don't drop it on our foot.
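Over all weak compositions of $M$ into $K$ parts (each composition equally likely), writing $X_u$ for the sum of the first $u$ parts, we have:
$$\mathbb{E}[X_u] = \frac{Mu}{K}$$
and
$$\mathbb{E}[X_u^2] = \frac{M(M+K)}{K(K+1)}\,u + \frac{M(M-1)}{K(K+1)}\,u^2.$$
Multiplying by the number of compositions $\binom{M+K-1}{K-1}$ turns these averages into sums over all compositions.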
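One way to justify this: a uniform weak composition is a Dirichlet-multinomial distribution with all parameters equal to $1$. If $p$ is the total probability mass of the first $u$ boxes, then $p \sim \mathrm{Beta}(u, K-u)$ (degenerate at $0$ or $1$ when $u = 0$ or $u = K$) and, given $p$, $X_u \sim \mathrm{Binomial}(M, p)$,
so
$$\mathbb{E}[X_u] = M\,\mathbb{E}[p] = \frac{Mu}{K},\qquad \mathbb{E}[X_u^2] = M\,\mathbb{E}[p] + M(M-1)\,\mathbb{E}[p^2].$$
Since $\mathbb{E}[p^2] = \dfrac{u(u+1)}{K(K+1)}$,
$$\mathbb{E}[X_u^2] = \frac{Mu}{K} + \frac{M(M-1)\,u(u+1)}{K(K+1)},$$
which simplifies to the formula above. If probability makes you itchy, the same identities follow from Vandermonde sums.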
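Define aggregate prefix sums over all positions in the queried range:
$$S_{A^2} = \sum_i A_i^2,\qquad S_{AU} = \sum_i A_i u_i,\qquad S_U = \sum_i u_i,\qquad S_{U^2} = \sum_i u_i^2.$$
Then for $K > 0$:
$$\text{answer} = \binom{M+K-1}{K-1}\cdot T,$$
where
$$T = S_{A^2} + \frac{2M}{K}\,S_{AU} + \frac{M(M+K)}{K(K+1)}\,S_U + \frac{M(M-1)}{K(K+1)}\,S_{U^2}.$$
All divisions are modular divisions modulo $998244353$. Since $998244353$ is prime and $1 \le K < K + 1 < 998244353$, both $K$ and $K+1$ are invertible.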
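For a segment, store:
- `len`: the number of positions; `unk`: the number of holes $K$;
- `totA`: the total fixed sum modulo $998244353$;
- `sumA` $= \sum_i A_i$, `sumA2` $= \sum_i A_i^2$, `sumU` $= \sum_i u_i$, `sumU2` $= \sum_i u_i^2$, `sumAU` $= \sum_i A_i u_i$ (all modulo $998244353$).

We also store the exact fixed sum as `long long`, because we must compare it with $m$ before taking the modulus.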
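When merging left segment $L$ and right segment $R$, every prefix inside $R$ gets shifted by the total pair $(a, k) = (\text{totA}_L, \text{unk}_L)$ of $L$:
$$A_i \mapsto A_i + a,\qquad u_i \mapsto u_i + k.$$
Therefore the merge formulas are direct expansions of squares/products. For example:
$$\text{sumA2} = \text{sumA2}_L + \text{sumA2}_R + 2a\,\text{sumA}_R + |R|\,a^2$$
and
$$\text{sumAU} = \text{sumAU}_L + \text{sumAU}_R + |R|\,ak + a\,\text{sumU}_R + k\,\text{sumA}_R.$$
The other sums are similarly boring — exactly what we want in a segment tree.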
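Precompute factorials and inverse factorials up to the largest index $M + K - 1$ that can occur (the code fixes this bound as `MAXC = 1300005`)
for combinations.
Each update and query costs
$$O(\log n),$$
and preprocessing costs
$$O(\text{MAXC})$$
for the factorial tables, plus $O(n)$ per test case to build the tree.
Across all test cases:
$$O\!\left(\text{MAXC} + \sum (n + q \log n)\right),$$
which easily fits. The segment tree is doing the accounting; the formula is doing the magic.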
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Puts the standard streams into fast mode: C++ iostreams stop synchronising
// with C stdio, and cin no longer flushes cout before every read.
void setIO() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
const int MOD = 998244353;
const int MAXC = 1300005;

// Returns a^e mod MOD by binary exponentiation.
// The base is first reduced into [0, MOD), so `a * a` can never overflow a
// long long and a negative base yields the canonical non-negative residue.
// The exponent must be non-negative; a negative exponent returns 1 instead of
// looping forever (an arithmetic right shift of a negative value never hits 0).
int mod_pow(long long a, long long e) {
    a %= MOD;
    if (a < 0) a += MOD;
    long long r = 1;
    while (e > 0) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return int(r);
}
vector<int> fact, ifact, invv;

// Fills the global lookup tables for every index 0..MAXC (all mod MOD):
//   fact[i]  = i!
//   ifact[i] = (i!)^{-1}
//   invv[i]  = i^{-1}, with invv[0] = 0 as a placeholder.
void init_combi() {
    const int N = MAXC + 1;
    fact.assign(N, 1);
    ifact.assign(N, 1);
    invv.assign(N, 1);

    for (int i = 1; i < N; ++i) {
        fact[i] = int(1LL * fact[i - 1] * i % MOD);
    }

    // A single Fermat inversion at the top, then walk downwards using
    // (i-1)!^{-1} = i!^{-1} * i.
    ifact[MAXC] = mod_pow(fact[MAXC], MOD - 2);
    for (int i = MAXC; i > 0; --i) {
        ifact[i - 1] = int(1LL * ifact[i] * i % MOD);
    }

    // Linear-time modular inverses: i^{-1} = -(MOD / i) * (MOD % i)^{-1}.
    invv[0] = 0;
    invv[1] = 1;
    for (int i = 2; i < N; ++i) {
        const long long quot = MOD / i;
        invv[i] = int(MOD - quot * invv[MOD % i] % MOD);
    }
}
// Binomial coefficient C(n, r) mod MOD, read from the precomputed factorial
// tables. Returns 0 when r lies outside [0, n]. Requires init_combi() to have
// run and n <= MAXC; no bounds check is performed.
int C(int n, int r) {
    const bool valid = (0 <= r && r <= n);
    if (!valid) return 0;
    const long long numer = fact[n];
    const long long denomInv = 1LL * ifact[r] * ifact[n - r] % MOD;
    return int(numer * denomInv % MOD);
}
// Summary of a contiguous segment. For each position i of the segment, let
// A_i be the sum of fixed values and u_i the number of holes (-1 entries) in
// the segment's prefix that ends at i. The sum* fields aggregate those
// per-prefix quantities over every position of the segment.
struct Node {
    int len{0};     // number of positions covered
    int unk{0};     // number of holes
    ll fixed{0};    // exact sum of fixed values (unreduced, compared against m)
    int totA{0};    // sum of fixed values, mod MOD
    int sumA{0};    // sum of A_i,        mod MOD
    int sumA2{0};   // sum of A_i^2,      mod MOD
    int sumU{0};    // sum of u_i,        mod MOD
    int sumU2{0};   // sum of u_i^2,      mod MOD
    int sumAU{0};   // sum of A_i * u_i,  mod MOD
};
// Builds the leaf summary for one array entry: -1 marks a hole, any other
// value is a fixed number. A one-element segment has a single prefix, so
// every aggregate equals that prefix's own value.
Node make_node(int v) {
    Node leaf;
    leaf.len = 1;
    if (v != -1) {
        const int red = v % MOD;
        leaf.fixed = v;
        leaf.totA = red;
        leaf.sumA = red;
        leaf.sumA2 = int(1LL * red * red % MOD);
        return leaf;
    }
    // A hole: the single prefix holds one hole (u = 1) and no fixed mass.
    leaf.unk = 1;
    leaf.sumU = 1;
    leaf.sumU2 = 1;
    return leaf;
}
// Concatenates two adjacent segments: L immediately followed by R.
// A prefix ending inside L is unchanged. A prefix ending inside R also spans
// all of L, so its fixed sum A grows by a = L.totA and its hole count u grows
// by k = L.unk. Summing over the R.len prefixes that end inside R gives:
//   sum (A+a)       = sumA_R  + lenR*a
//   sum (A+a)^2     = sumA2_R + 2a*sumA_R + lenR*a^2
//   sum (u+k)       = sumU_R  + lenR*k
//   sum (u+k)^2     = sumU2_R + 2k*sumU_R + lenR*k^2
//   sum (A+a)(u+k)  = sumAU_R + lenR*a*k + a*sumU_R + k*sumA_R
// An empty node (len == 0) acts as the identity on either side.
Node merge_node(const Node& L, const Node& R) {
if (L.len == 0) return R;
if (R.len == 0) return L;
Node res;
res.len = L.len + R.len;
res.unk = L.unk + R.unk;
// Exact, unreduced fixed sum so the query can compare it against m.
res.fixed = L.fixed + R.fixed;
// Both operands are already in [0, MOD), so one conditional subtraction reduces the sum.
res.totA = L.totA + R.totA;
if (res.totA >= MOD) res.totA -= MOD;
// a: shift added to every A_i inside R; k: shift added to every u_i inside R;
// lenR: number of prefixes that end inside R.
ll a = L.totA;
ll k = L.unk;
ll lenR = R.len;
// Products are reduced mod MOD before further multiplication, so every term
// stays below roughly 2.2e18 and the few terms added per field fit in long long.
res.sumA = (L.sumA + R.sumA + lenR * a) % MOD;
res.sumA2 = (
1LL * L.sumA2 + R.sumA2
+ 2LL * a % MOD * R.sumA
+ lenR % MOD * a % MOD * a
) % MOD;
res.sumU = (L.sumU + R.sumU + lenR % MOD * (k % MOD)) % MOD;
res.sumU2 = (
1LL * L.sumU2 + R.sumU2
+ 2LL * (k % MOD) % MOD * R.sumU
+ lenR % MOD * (k % MOD) % MOD * (k % MOD)
) % MOD;
res.sumAU = (
1LL * L.sumAU + R.sumAU
+ lenR % MOD * a % MOD * (k % MOD)
+ a * R.sumU % MOD
+ (k % MOD) * R.sumA % MOD
) % MOD;
return res;
}
// Bottom-up (iterative) segment tree over Node summaries.
// Leaf i lives at st[sz + i]; internal node i holds
// merge_node(st[2i], st[2i + 1]). Leaves past n stay as empty Nodes
// (len == 0), which merge_node passes through unchanged, so padding never
// disturbs the left-to-right order the non-commutative merge relies on.
struct SegTree {
    int n = 0;   // number of real leaves
    int sz = 0;  // leaf offset: smallest power of two >= n (set by init)
    vector<Node> st;

    // Empty tree with well-defined members; call init() before update()/query().
    SegTree() {}
    SegTree(const vector<int>& a) {
        init(a);
    }

    // (Re)builds the tree from raw entries (-1 marks a hole).
    void init(const vector<int>& a) {
        n = int(a.size());
        sz = 1;
        while (sz < n) sz <<= 1;
        st.assign(2 * sz, Node());
        for (int i = 0; i < n; i++) st[sz + i] = make_node(a[i]);
        for (int i = sz - 1; i >= 1; i--) st[i] = merge_node(st[i << 1], st[i << 1 | 1]);
    }

    // Replaces entry pos (0-indexed) with val and refreshes all its ancestors.
    void update(int pos, int val) {
        int p = sz + pos;
        st[p] = make_node(val);
        for (p >>= 1; p; p >>= 1) st[p] = merge_node(st[p << 1], st[p << 1 | 1]);
    }

    // Summary of entries l..r (0-indexed, inclusive). Pieces met on the left
    // boundary are appended to acc_left and pieces met on the right boundary
    // are prepended to acc_right, so array order is preserved. An empty range
    // (l > r) yields an empty Node.
    Node query(int l, int r) {
        Node acc_left, acc_right;
        l += sz;
        r += sz;
        while (l <= r) {
            if (l & 1) acc_left = merge_node(acc_left, st[l++]);
            if (!(r & 1)) acc_right = merge_node(st[r--], acc_right);
            l >>= 1;
            r >>= 1;
        }
        return merge_node(acc_left, acc_right);
    }
};
// Answers one query whose range is summarised by `nd`, with required range total m.
// With K = nd.unk holes and M = m - nd.fixed mass left for them, it returns
// (mod MOD) the sum, over all C(M+K-1, K-1) ways to fill the holes with
// non-negative integers summing to M, of sum_i (A_i + X_i)^2, where X_i is
// the mass placed in the holes of the prefix ending at i. That equals
//   ways * ( S_A2 + 2M/K * S_AU + M(M+K)/(K(K+1)) * S_U + M(M-1)/(K(K+1)) * S_U2 )
// by E[X] = M*u/K and E[X^2] = M(M+K)/(K(K+1))*u + M(M-1)/(K(K+1))*u^2.
// Returns 0 when M < 0.
// NOTE(review): assumes M + K - 1 <= MAXC and K + 1 <= MAXC so the
// precomputed tables cover every lookup; no bounds check here. Confirm
// against the problem limits.
int answer_query(const Node& nd, int m) {
ll Mll = 1LL * m - nd.fixed;
if (Mll < 0) return 0;
int K = nd.unk;
// No holes: the only configuration is the range itself, valid iff its sum is m.
if (K == 0) {
return Mll == 0 ? nd.sumA2 : 0;
}
int M = int(Mll % MOD);
// Number of weak compositions of M into K parts (stars and bars).
int ways = C(int(Mll) + K - 1, K - 1);
ll term = nd.sumA2;
// 2M/K * S_AU
term += 2LL * M % MOD * invv[K] % MOD * nd.sumAU;
term %= MOD;
// M(M+K)/(K(K+1)) * S_U
ll denom = 1LL * invv[K] * invv[K + 1] % MOD;
term += 1LL * M * ((Mll + K) % MOD) % MOD * denom % MOD * nd.sumU;
term %= MOD;
// M(M-1)/(K(K+1)) * S_U2
// (M - 1) mod MOD; it is negative only when M == 0, where the term vanishes anyway.
ll mm1 = (Mll - 1) % MOD;
if (mm1 < 0) mm1 += MOD;
term += 1LL * M * mm1 % MOD * denom % MOD * nd.sumU2;
term %= MOD;
return int(1LL * ways * term % MOD);
}
// Program entry point. Reads T test cases; each provides n and q, the n array
// entries (-1 marks a hole), then q operations:
//   op == 1   : "p v"   -> set entry p (1-indexed) to v
//   any other : "l r m" -> print answer_query for entries l..r (1-indexed,
//                          inclusive) with required range total m
int main() {
    setIO();
    init_combi();  // tables are shared by every test case; build them once
    int T;
    cin >> T;
    while (T--) {
        int n, q;
        cin >> n >> q;
        vector<int> a(n);
        for (int i = 0; i < n; i++) cin >> a[i];
        SegTree seg(a);
        while (q--) {
            int op;
            cin >> op;
            if (op == 1) {
                int p, v;
                cin >> p >> v;
                --p;  // to 0-indexed
                seg.update(p, v);
            } else {
                int l, r, m;
                cin >> l >> r >> m;
                --l; --r;  // to 0-indexed, still inclusive
                Node nd = seg.query(l, r);
                cout << answer_query(nd, m) << '\n';
            }
        }
    }
    return 0;
}
// NOTE(review): the rest of this file repeats the whole program above
// verbatim. Compiled as one translation unit, every definition below is a
// redefinition; keep a single copy.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// Puts the standard streams into fast mode: C++ iostreams stop synchronising
// with C stdio, and cin no longer flushes cout before every read.
void setIO() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);
}
const int MOD = 998244353;
const int MAXC = 1300005;

// Returns a^e mod MOD by binary exponentiation.
// The base is first reduced into [0, MOD), so `a * a` can never overflow a
// long long and a negative base yields the canonical non-negative residue.
// The exponent must be non-negative; a negative exponent returns 1 instead of
// looping forever (an arithmetic right shift of a negative value never hits 0).
int mod_pow(long long a, long long e) {
    a %= MOD;
    if (a < 0) a += MOD;
    long long r = 1;
    while (e > 0) {
        if (e & 1) r = r * a % MOD;
        a = a * a % MOD;
        e >>= 1;
    }
    return int(r);
}
vector<int> fact, ifact, invv;

// Fills the global lookup tables for every index 0..MAXC (all mod MOD):
//   fact[i]  = i!
//   ifact[i] = (i!)^{-1}
//   invv[i]  = i^{-1}, with invv[0] = 0 as a placeholder.
void init_combi() {
    const int N = MAXC + 1;
    fact.assign(N, 1);
    ifact.assign(N, 1);
    invv.assign(N, 1);

    for (int i = 1; i < N; ++i) {
        fact[i] = int(1LL * fact[i - 1] * i % MOD);
    }

    // A single Fermat inversion at the top, then walk downwards using
    // (i-1)!^{-1} = i!^{-1} * i.
    ifact[MAXC] = mod_pow(fact[MAXC], MOD - 2);
    for (int i = MAXC; i > 0; --i) {
        ifact[i - 1] = int(1LL * ifact[i] * i % MOD);
    }

    // Linear-time modular inverses: i^{-1} = -(MOD / i) * (MOD % i)^{-1}.
    invv[0] = 0;
    invv[1] = 1;
    for (int i = 2; i < N; ++i) {
        const long long quot = MOD / i;
        invv[i] = int(MOD - quot * invv[MOD % i] % MOD);
    }
}
// Binomial coefficient C(n, r) mod MOD, read from the precomputed factorial
// tables. Returns 0 when r lies outside [0, n]. Requires init_combi() to have
// run and n <= MAXC; no bounds check is performed.
int C(int n, int r) {
    const bool valid = (0 <= r && r <= n);
    if (!valid) return 0;
    const long long numer = fact[n];
    const long long denomInv = 1LL * ifact[r] * ifact[n - r] % MOD;
    return int(numer * denomInv % MOD);
}
// Summary of a contiguous segment. For each position i of the segment, let
// A_i be the sum of fixed values and u_i the number of holes (-1 entries) in
// the segment's prefix that ends at i. The sum* fields aggregate those
// per-prefix quantities over every position of the segment.
struct Node {
    int len{0};     // number of positions covered
    int unk{0};     // number of holes
    ll fixed{0};    // exact sum of fixed values (unreduced, compared against m)
    int totA{0};    // sum of fixed values, mod MOD
    int sumA{0};    // sum of A_i,        mod MOD
    int sumA2{0};   // sum of A_i^2,      mod MOD
    int sumU{0};    // sum of u_i,        mod MOD
    int sumU2{0};   // sum of u_i^2,      mod MOD
    int sumAU{0};   // sum of A_i * u_i,  mod MOD
};
// Builds the leaf summary for one array entry: -1 marks a hole, any other
// value is a fixed number. A one-element segment has a single prefix, so
// every aggregate equals that prefix's own value.
Node make_node(int v) {
    Node leaf;
    leaf.len = 1;
    if (v != -1) {
        const int red = v % MOD;
        leaf.fixed = v;
        leaf.totA = red;
        leaf.sumA = red;
        leaf.sumA2 = int(1LL * red * red % MOD);
        return leaf;
    }
    // A hole: the single prefix holds one hole (u = 1) and no fixed mass.
    leaf.unk = 1;
    leaf.sumU = 1;
    leaf.sumU2 = 1;
    return leaf;
}
// Concatenates two adjacent segments: L immediately followed by R.
// A prefix ending inside L is unchanged. A prefix ending inside R also spans
// all of L, so its fixed sum A grows by a = L.totA and its hole count u grows
// by k = L.unk. Summing over the R.len prefixes that end inside R gives:
//   sum (A+a)       = sumA_R  + lenR*a
//   sum (A+a)^2     = sumA2_R + 2a*sumA_R + lenR*a^2
//   sum (u+k)       = sumU_R  + lenR*k
//   sum (u+k)^2     = sumU2_R + 2k*sumU_R + lenR*k^2
//   sum (A+a)(u+k)  = sumAU_R + lenR*a*k + a*sumU_R + k*sumA_R
// An empty node (len == 0) acts as the identity on either side.
Node merge_node(const Node& L, const Node& R) {
if (L.len == 0) return R;
if (R.len == 0) return L;
Node res;
res.len = L.len + R.len;
res.unk = L.unk + R.unk;
// Exact, unreduced fixed sum so the query can compare it against m.
res.fixed = L.fixed + R.fixed;
// Both operands are already in [0, MOD), so one conditional subtraction reduces the sum.
res.totA = L.totA + R.totA;
if (res.totA >= MOD) res.totA -= MOD;
// a: shift added to every A_i inside R; k: shift added to every u_i inside R;
// lenR: number of prefixes that end inside R.
ll a = L.totA;
ll k = L.unk;
ll lenR = R.len;
// Products are reduced mod MOD before further multiplication, so every term
// stays below roughly 2.2e18 and the few terms added per field fit in long long.
res.sumA = (L.sumA + R.sumA + lenR * a) % MOD;
res.sumA2 = (
1LL * L.sumA2 + R.sumA2
+ 2LL * a % MOD * R.sumA
+ lenR % MOD * a % MOD * a
) % MOD;
res.sumU = (L.sumU + R.sumU + lenR % MOD * (k % MOD)) % MOD;
res.sumU2 = (
1LL * L.sumU2 + R.sumU2
+ 2LL * (k % MOD) % MOD * R.sumU
+ lenR % MOD * (k % MOD) % MOD * (k % MOD)
) % MOD;
res.sumAU = (
1LL * L.sumAU + R.sumAU
+ lenR % MOD * a % MOD * (k % MOD)
+ a * R.sumU % MOD
+ (k % MOD) * R.sumA % MOD
) % MOD;
return res;
}
// Bottom-up (iterative) segment tree over Node summaries.
// Leaf i lives at st[sz + i]; internal node i holds
// merge_node(st[2i], st[2i + 1]). Leaves past n stay as empty Nodes
// (len == 0), which merge_node passes through unchanged, so padding never
// disturbs the left-to-right order the non-commutative merge relies on.
struct SegTree {
    int n = 0;   // number of real leaves
    int sz = 0;  // leaf offset: smallest power of two >= n (set by init)
    vector<Node> st;

    // Empty tree with well-defined members; call init() before update()/query().
    SegTree() {}
    SegTree(const vector<int>& a) {
        init(a);
    }

    // (Re)builds the tree from raw entries (-1 marks a hole).
    void init(const vector<int>& a) {
        n = int(a.size());
        sz = 1;
        while (sz < n) sz <<= 1;
        st.assign(2 * sz, Node());
        for (int i = 0; i < n; i++) st[sz + i] = make_node(a[i]);
        for (int i = sz - 1; i >= 1; i--) st[i] = merge_node(st[i << 1], st[i << 1 | 1]);
    }

    // Replaces entry pos (0-indexed) with val and refreshes all its ancestors.
    void update(int pos, int val) {
        int p = sz + pos;
        st[p] = make_node(val);
        for (p >>= 1; p; p >>= 1) st[p] = merge_node(st[p << 1], st[p << 1 | 1]);
    }

    // Summary of entries l..r (0-indexed, inclusive). Pieces met on the left
    // boundary are appended to acc_left and pieces met on the right boundary
    // are prepended to acc_right, so array order is preserved. An empty range
    // (l > r) yields an empty Node.
    Node query(int l, int r) {
        Node acc_left, acc_right;
        l += sz;
        r += sz;
        while (l <= r) {
            if (l & 1) acc_left = merge_node(acc_left, st[l++]);
            if (!(r & 1)) acc_right = merge_node(st[r--], acc_right);
            l >>= 1;
            r >>= 1;
        }
        return merge_node(acc_left, acc_right);
    }
};
// Answers one query whose range is summarised by `nd`, with required range total m.
// With K = nd.unk holes and M = m - nd.fixed mass left for them, it returns
// (mod MOD) the sum, over all C(M+K-1, K-1) ways to fill the holes with
// non-negative integers summing to M, of sum_i (A_i + X_i)^2, where X_i is
// the mass placed in the holes of the prefix ending at i. That equals
//   ways * ( S_A2 + 2M/K * S_AU + M(M+K)/(K(K+1)) * S_U + M(M-1)/(K(K+1)) * S_U2 )
// by E[X] = M*u/K and E[X^2] = M(M+K)/(K(K+1))*u + M(M-1)/(K(K+1))*u^2.
// Returns 0 when M < 0.
// NOTE(review): assumes M + K - 1 <= MAXC and K + 1 <= MAXC so the
// precomputed tables cover every lookup; no bounds check here. Confirm
// against the problem limits.
int answer_query(const Node& nd, int m) {
ll Mll = 1LL * m - nd.fixed;
if (Mll < 0) return 0;
int K = nd.unk;
// No holes: the only configuration is the range itself, valid iff its sum is m.
if (K == 0) {
return Mll == 0 ? nd.sumA2 : 0;
}
int M = int(Mll % MOD);
// Number of weak compositions of M into K parts (stars and bars).
int ways = C(int(Mll) + K - 1, K - 1);
ll term = nd.sumA2;
// 2M/K * S_AU
term += 2LL * M % MOD * invv[K] % MOD * nd.sumAU;
term %= MOD;
// M(M+K)/(K(K+1)) * S_U
ll denom = 1LL * invv[K] * invv[K + 1] % MOD;
term += 1LL * M * ((Mll + K) % MOD) % MOD * denom % MOD * nd.sumU;
term %= MOD;
// M(M-1)/(K(K+1)) * S_U2
// (M - 1) mod MOD; it is negative only when M == 0, where the term vanishes anyway.
ll mm1 = (Mll - 1) % MOD;
if (mm1 < 0) mm1 += MOD;
term += 1LL * M * mm1 % MOD * denom % MOD * nd.sumU2;
term %= MOD;
return int(1LL * ways * term % MOD);
}
// Program entry point. Reads T test cases; each provides n and q, the n array
// entries (-1 marks a hole), then q operations:
//   op == 1   : "p v"   -> set entry p (1-indexed) to v
//   any other : "l r m" -> print answer_query for entries l..r (1-indexed,
//                          inclusive) with required range total m
// NOTE(review): this is a verbatim second copy of the program earlier in the
// file; compiled as one translation unit, this main is a redefinition. Keep a
// single copy.
int main() {
    setIO();
    init_combi();

    int T;
    cin >> T;
    while (T--) {
        int n, q;
        cin >> n >> q;

        vector<int> a(n);
        for (int& x : a) cin >> x;

        SegTree seg(a);
        while (q--) {
            int op;
            cin >> op;
            if (op != 1) {
                int l, r, m;
                cin >> l >> r >> m;
                cout << answer_query(seg.query(l - 1, r - 1), m) << '\n';
                continue;
            }
            int p, v;
            cin >> p >> v;
            seg.update(p - 1, v);
        }
    }
    return 0;
}