Java基础内容/基础语法/流程控制

HDU - 1004 Let the Balloon Rise

  返回  

2017年西安区域赛 Sum of xor sum(线段树)

2021/8/21 20:43:26 浏览:

传送门

题意:

给出一个长度为 n n n 的数组,有 q q q 次查询,每次查询给出一个区间 [ l , r ] [l,r] [l,r] ,求这段区间里面所有子区间的异或和的总和。

题解:

不难想到,要按位考虑贡献,对于第 i i i 位的贡献是 2 i 2^i 2i 乘上区间 1 1 1的个数为奇数的子区间的数量。

考虑利用线段树维护一个区间中包含奇数个 1 1 1的子区间数量。

如何区间合并?

设左区间范围是 [ l , m i d ] [l,mid] [l,mid] ,右区间范围是 [ m i d + 1 , r ] [mid+1,r] [mid+1,r] , a n s ans ans表示区间中包含奇数个 1 1 1​的子区间数量。

那么 a n s = l e f t . a n s + r i g h t . a n s + ans=left.ans+right.ans+ ans=left.ans+right.ans+ m i d mid mid为右端点包含奇数个 1 1 1的区间数量 × \times × m i d + 1 mid+1 mid+1为左端点包含偶数个 1 1 1的区间数量 + + + m i d mid mid为右端点包含偶数个 1 1 1的区间数量 × \times × m i d + 1 mid+1 mid+1为左端点包含奇数个 1 1 1​的区间数量。

那么维护三个东西即可:区间中包含奇数个 1 1 1的子区间数量 ,以 m i d mid mid为右端点包含奇数个 1 1 1的区间数量,以 m i d + 1 mid+1 mid+1为左端点包含奇数个 1 1 1的区间数量。

代码:

#pragma GCC diagnostic error "-std=c++11"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int, int> pii;
const int mod = 1e9 + 7;
const int MAXN = 2e5 + 5;
const int inf = 0x3f3f3f3f;
int a[MAXN], base[22];
struct Node {
    int l, r, sum[22];
    ll ans[22], lsum[22], rsum[22];
} node[MAXN << 2];
Node combine(Node x,Node y)
{
    Node k;
    k.l = x.l;
    k.r = y.r;
    int len1 = x.r - x.l + 1;
    int len2 = y.r - y.l + 1;
    for (int i = 0; i <= 20;i++)
    {
        k.sum[i] = x.sum[i] + y.sum[i];
        k.ans[i] = (x.ans[i] + y.ans[i] + x.rsum[i] * (len2 - y.lsum[i])%mod + (len1 - x.rsum[i]) * y.lsum[i]%mod)%mod;
        if(x.sum[i]&1)
            k.lsum[i] = x.lsum[i] + (len2 - y.lsum[i]);
        else
            k.lsum[i] = x.lsum[i] + y.lsum[i];
        if(y.sum[i]&1)
            k.rsum[i] = y.rsum[i] + (len1 - x.rsum[i]);
        else
            k.rsum[i] = y.rsum[i] + x.rsum[i];
    }
    return k;
}
void build(int l, int r, int num)
{
    node[num].l = l;
    node[num].r = r;
    if (l == r) {
        for (int i = 20; i >= 0; i--) {
            if((a[l]>>i)&1){
                node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]=1;
            }
            else {
                node[num].ans[i] = node[num].lsum[i] = node[num].rsum[i] =node[num].sum[i]= 0;
            }
        }
        return;
    }
    int mid = (l + r) >> 1;
    build(l, mid, num << 1);
    build(mid + 1, r, num << 1|1);
    node[num] = combine(node[num << 1], node[num << 1 | 1]);
}
Node query(int l,int r,int num)
{
    if(node[num].l>=l&&node[num].r<=r)
    {
        return node[num];
    }
    int mid = (l + r) >> 1;
    if(r<=mid)
        return query(l, r, num << 1);
    else if(l>mid)
        return query(l, r, num << 1 | 1);
    else {
        Node tmp1 = query(l, r, num << 1);
        Node tmp2 = query(l, r, num << 1 | 1);
        Node tmp = combine(tmp1, tmp2);
        return tmp;
    }
}
int main()
{
    base[0] = 1;
    for (int i = 1; i <= 20; i++)
        base[i] = base[i - 1] * 2;
    int t;
    scanf("%d", &t);
    while (t--) {
        int n, q;
        scanf("%d%d", &n, &q);
        for (int i = 1; i <= n; i++) {
            scanf("%d", &a[i]);
        }
        build(1, n, 1);
        while ((q--))
        {
            int l, r;
            scanf("%d%d", &l, &r);
            Node ans = query(l, r, 1);
            ll sum = 0;
            for (int i = 0; i <= 20;i++)
            {
                sum = (sum + base[i] * ans.ans[i]%mod)%mod;
            }
            printf("%lld\n", sum);
        }
        
    }

}

联系我们

如果您对我们的服务有兴趣,请及时和我们联系!

服务热线:18288888888
座机:18288888888
传真:
邮箱:888888@qq.com
地址:郑州市文化路红专路93号