提供一个模拟赛场上的思路。
给你一颗树,点有点权,你需要求出下列式子模 的值( 是质数且有原根 ):
其中 表示 到 的最短路径上所有点的点权按位与在一起之后的值。
保证 ,其中 是叶子个数。
首先考虑一个比较暴力的想法,枚举路径的起点 ,显然,任意一条路径都是 到叶子的路径一个前缀,所以以 为起点不同权值的路径最多只有 种。
如果能快速找到这些终点就可以做到 的时间复杂度。
我们考虑把度数 的点作为关键点建虚树,显然虚树上的边就是原树上的一条路径,把边内部的单独算掉,剩下的就是要经过至少一个关键点的路径。
显然是由一条边的前缀再拓展出去,那么本质不同的起点就是 个,然后本质不同的终点也只有 ,直接枚举即可。
可以预处理线性对数做到 求幂。
constexpr int maxn = 2e5 + 10, maxm = 510, mod = 786433;
constexpr int base = 1 << 30;
int a[maxn], pw[mod], dlog[mod];
int n, m, id[maxn], rnk[maxn];
vector<int> g[maxn];
struct E {
int v;
vector<int> a;
vector<pair<int, int>> pre;
};
vector<E> ed[maxn];
vector<int> cur;
void dfs(int x, int fa, int rt) {
if (g[x].size() != 2 && rt != x) {
ed[id[rt]].emplace_back(E{id[x], cur});
return;
}
if (x != rt) cur.emplace_back(a[x]);
for (int y : g[x]) if (y != fa) dfs(y, x, rt);
if (x != rt) cur.pop_back();
}
void init() {
pw[0] = 1; for (int i = 1; i < mod - 1; ++i) pw[i] = 10 * pw[i - 1] % mod;
for (int i = 0; i < mod - 1; ++i) dlog[pw[i]] = i;
}
int qpow(int a, int b) {
a %= mod; if (!a) return 0;
b %= mod - 1; b = 1ll * b * dlog[a] % (mod - 1);
return pw[b];
}
int calc(int x) {
int A = base - 1 - x % base;
int O = x / base;
if (!A) return 0;
return qpow(A, A);
}
int case1(vector<int> a) {
ll ans = 0;
vector<pair<int, int>> cur;
for (int x : a) {
for (auto &e : cur) e.first |= x;
cur.emplace_back(x, 1);
vector<pair<int, int>> nxt;
for (auto e : cur) {
if (nxt.empty() || nxt.back().first != e.first) nxt.emplace_back(e);
else nxt.back().second += e.second;
}
cur = nxt;
for (auto e : cur) ans += 1ll * calc(e.first) * e.second % mod;
}
return ans % mod;
}
vector<pair<int, int>> res;
void findres(int x, int fa, int cur) {
if (base - 1 - cur % base == 0) return;
cur |= a[rnk[x]];
res.emplace_back(cur, 1);
for (auto &e : ed[x]) if (e.v != fa) {
int tmp = cur;
for (auto &p : e.pre) tmp |= p.first, res.emplace_back(tmp, p.second);
findres(e.v, x, tmp);
}
}
int main() {
init(); n = read();
for (int i = 1; i <= n; ++i) a[i] = base - 1 - read();
for (int i = 1; i < n; ++i) {
int u = read(), v = read();
g[u].emplace_back(v); g[v].emplace_back(u);
}
for (int i = 1; i <= n; ++i) if (g[i].size() != 2) id[i] = ++m, rnk[m] = i;
for (int i = 1; i <= n; ++i) if (g[i].size() != 2) dfs(i, 0, i);
for (int x = 1; x <= m; ++x) {
for (auto &e : ed[x]) {
int cur = 0;
vector<pair<int, int>> tmp;
for (auto k : e.a) cur |= k, tmp.emplace_back(cur, 1);
for (auto p : tmp) {
if (e.pre.empty() || e.pre.back().first != p.first) e.pre.emplace_back(p);
else e.pre.back().second += p.second;
}
}
}
ll ans1 = 0;
for (int x = 1; x <= m; ++x) {
for (auto &e : ed[x]) {
ans1 += case1(e.a);
}
}
ans1 = 1ll * ans1 * (mod + 1 >> 1) % mod;
ll ans2 = 0;
for (int i = 1; i <= m; ++i) {
res.clear();
findres(i, 0, 0);
for (auto e : res) ans2 += 1ll * calc(e.first) * e.second % mod;
}
for (int x = 1; x <= m; ++x) {
for (auto &e : ed[x]) {
res.clear();
findres(x, e.v, 0);
for (auto p : e.pre) {
for (auto q : res) {
ans2 += 1ll * calc(p.first | q.first) * p.second % mod * q.second % mod;
}
}
}
}
for (int i = 1; i <= m; ++i) ans2 += calc(a[rnk[i]]);
ans2 = 1ll * ans2 * (mod + 1 >> 1) % mod;
write((ans1 + ans2) % mod), pc(10);
}