// CSES - Shared code. Link to this code:
// https://cses.fi/paste/4f4b13d8100becce8830e1/
#include<algorithm>
#include<cassert>
#include<iostream>
// Shorthand integer types used throughout the file. These are type aliases
// rather than macros, so they respect scope and cannot textually rewrite
// unrelated tokens.
using ll = long long;
using ull = unsigned long long;
// NOTE(review): file-scope using-directive; unqualified std names elsewhere
// in this file (cin, cout, min, max) rely on it.
using namespace std;
int q;                 // number of queries, read in main
int n;                 // number of elements; the tree covers positions 1..n
const int N = 200000;  // largest n the static arrays can hold
int A[N + 1];          // initial values, 1-indexed (A[0] unused)

// Lazy segment-tree node for range add / range assign / range sum.
//
// A pending tag describes an update that is already reflected in this node's
// `sum` but has not yet been pushed to its children. Invariant: while
// `has_set` is true, `lz_add` stays 0 (a later add is folded into `lz_set`
// instead), so at most one kind of tag is pending per node.
//
// The explicit `has_set` flag replaces the old "lz_set == 0 means nothing
// pending" sentinel. With the sentinel, assigning 0 (or an add that brought a
// pending assignment to 0) was silently dropped on push-down.
struct node {
    long long sum = 0;     // sum of this node's segment
    long long lz_add = 0;  // pending "add lz_add to every element" for the children
    long long lz_set = 0;  // pending "assign lz_set to every element" (valid only if has_set)
    bool has_set = false;  // whether an assignment tag is pending
} t[N << 2];  // heap layout: root is 1, children of i are 2i and 2i+1

// Recomputes node i's sum from its two children.
void pushUp(int i) {
    t[i].sum = t[i << 1].sum + t[i << 1 | 1].sum;
}

// Assigns x to all `len` elements covered by node i and records the tag.
static void applySet(int i, int len, long long x) {
    t[i].has_set = true;
    t[i].lz_set = x;
    t[i].lz_add = 0;  // an assignment overrides any earlier pending add
    t[i].sum = x * len;
}

// Adds x to all `len` elements covered by node i and records the tag.
static void applyAdd(int i, int len, long long x) {
    if (t[i].has_set)
        t[i].lz_set += x;  // keep the invariant: fold into the pending assignment
    else
        t[i].lz_add += x;
    t[i].sum += x * len;
}

// Pushes node i's pending tag (covering [l, r]) down to its two children.
void pushDown(int i, int l, int r) {
    int m = l + (r - l) / 2;
    int left_len = m - l + 1;
    int right_len = r - m;
    if (t[i].has_set) {
        applySet(i << 1, left_len, t[i].lz_set);
        applySet(i << 1 | 1, right_len, t[i].lz_set);
        t[i].has_set = false;
        t[i].lz_set = 0;
    } else if (t[i].lz_add) {
        applyAdd(i << 1, left_len, t[i].lz_add);
        applyAdd(i << 1 | 1, right_len, t[i].lz_add);
        t[i].lz_add = 0;
    }
}

// Builds the subtree rooted at node i over A[l..r], clearing all tags.
void build(int i, int l, int r) {
    t[i].lz_add = 0;
    t[i].lz_set = 0;
    t[i].has_set = false;
    if (l == r) {
        t[i].sum = A[l];
        return;
    }
    int m = l + (r - l) / 2;
    build(i << 1, l, m);
    build(i << 1 | 1, m + 1, r);
    pushUp(i);
}

// Adds x to every element of [a, b] inside node i's segment [l, r].
// Callers pass a sub-range of [l, r]; an empty range (a > b) is a no-op.
void add(int i, int l, int r, int a, int b, long long x) {
    if (a > b)
        return;
    if (l == a && r == b) {
        applyAdd(i, r - l + 1, x);
        return;
    }
    pushDown(i, l, r);
    int m = l + (r - l) / 2;
    add(i << 1, l, m, a, std::min(b, m), x);
    add(i << 1 | 1, m + 1, r, std::max(a, m + 1), b, x);
    pushUp(i);
}

// Sets every element of [a, b] inside node i's segment [l, r] to x.
// Any x is valid, including 0. An empty range (a > b) is a no-op.
void set(int i, int l, int r, int a, int b, long long x) {
    if (a > b)
        return;
    if (l == a && r == b) {
        applySet(i, r - l + 1, x);
        return;
    }
    pushDown(i, l, r);
    int m = l + (r - l) / 2;
    set(i << 1, l, m, a, std::min(b, m), x);
    set(i << 1 | 1, m + 1, r, std::max(a, m + 1), b, x);
    pushUp(i);
}

// Returns the sum of [a, b] inside node i's segment [l, r]; 0 if a > b.
long long sum(int i, int l, int r, int a, int b) {
    if (a > b)
        return 0;
    if (a == l && r == b)
        return t[i].sum;
    pushDown(i, l, r);
    int m = l + (r - l) / 2;
    return sum(i << 1, l, m, a, std::min(b, m))
         + sum(i << 1 | 1, m + 1, r, std::max(a, m + 1), b);
}
// Reads n, q and the initial array A[1..n], then processes q queries:
//   "1 a b x" -> add x to every element in [a, b]
//   "2 a b x" -> set every element in [a, b] to x
//   any other flag, "f a b" -> print the sum of [a, b]
int main() {
    // Many query lines may be read: drop C-stdio sync and untie cin from
    // cout so input is not slowed by per-operation syncing and flushing.
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> q;
    for (int i = 1; i <= n; ++i) std::cin >> A[i];
    build(1, 1, n);

    while (q--) {
        int flag, a, b;
        std::cin >> flag >> a >> b;
        if (flag == 1) {
            long long x;  // matches add's long long parameter
            std::cin >> x;
            add(1, 1, n, a, b, x);
        } else if (flag == 2) {
            long long x;  // matches set's long long parameter
            std::cin >> x;
            set(1, 1, n, a, b, x);
        } else {
            std::cout << sum(1, 1, n, a, b) << '\n';
        }
    }
}