LOJ

## 题解

• 1.插入x

设\(pre\)、\(nxt\)分别为\(x\)在按dfs序排序的集合中的前驱和后继（集合首尾相接，视为一个环）。对ans减去\(dis(pre,nxt)\)，再加上\(dis(pre,x)+dis(x,nxt)\)，然后插入\(x\)

• 2.删除x

先删去\(x\)，对ans减去\(dis(pre,x)+dis(x,nxt)\)，再加上\(dis(pre,nxt)\)

询问时输出\(ans/2\)：ans是集合中按dfs序相邻（首尾相接）的点对之间的距离之和，连接这些点的最小子树的每条边恰好被计算两次。

#include <bits/stdc++.h>
using namespace std;
// A preprocessor directive must start its own line; when it trailed the
// using-directive on the same line, `ll` was never defined.
#define ll long long
// Capacity of every per-node array (nodes are numbered from 1).
#define N 100100
// n      - number of tree nodes (read in main)
// m      - number of operations (read in main)
// lim    - highest exponent used in the binary-lifting table
// f[v][j]- the 2^j-th ancestor of v (0 when it does not exist)
// dep[v] - depth of v, with the root at depth 1
// dfn[v] - DFS visiting index of v
// id[k]  - node whose DFS index is k (inverse of dfn)
int n, m, lim, f[N][20], dep[N], dfn[N], id[N];
// d[v] - sum of edge weights on the path from the root to v.
ll d[N];
// One directed edge of the adjacency list.
//   to  - target node
//   nxt - index in `e` of the next edge leaving the same source (0 ends the list)
//   v   - edge weight
// `e` holds two directed edges per undirected tree edge, hence N << 1 slots.
struct edge {
    int to, nxt, v;
} e[N << 1];

// head[u] - index in `e` of the most recently inserted edge leaving u (0 = none).
// cnt     - number of directed edges stored so far; slot 0 is the list terminator.
int head[N], cnt;

// DFS indices (dfn values) of the currently marked nodes, kept in sorted order.
set<int> s;

// Prepends the directed edge u -> v with weight z to u's adjacency list.
void ins(int u, int v, int z) {
    ++cnt;
    e[cnt].to = v;
    e[cnt].nxt = head[u];
    e[cnt].v = z;
    // Link the new edge in as the list head; without this the traversal in
    // dfs (which starts from head[u]) would never see any edge.
    head[u] = cnt;
}

// Running DFS index counter; dfs pre-increments it before assigning dfn.
int tim = 0;
// Depth-first traversal from u. Assigns the DFS index of u and, for every
// child, fills in its parent, depth, root distance and binary-lifting row
// before recursing into it.
void dfs(int u) {
    dfn[u] = ++tim;
    for (int k = head[u]; k != 0; k = e[k].nxt) {
        const int child = e[k].to;
        // Skip the edge that leads back to the parent.
        if (child == f[u][0]) {
            continue;
        }
        f[child][0] = u;
        dep[child] = dep[u] + 1;
        d[child] = d[u] + e[k].v;
        // The ancestors of u are already complete, so the child's table can be
        // built by doubling.
        for (int j = 1; j <= lim; ++j) {
            f[child][j] = f[f[child][j - 1]][j - 1];
        }
        dfs(child);
    }
}

// Returns the lowest common ancestor of x and y using the lifting table f.
int lca(int x, int y) {
    // Make x the deeper (or equally deep) node.
    if (dep[x] < dep[y]) {
        std::swap(x, y);
    }
    // Raise x to the depth of y. Entry 0 of dep is 0 and real depths start at
    // 1, so jumps past the root are rejected by the comparison.
    for (int j = lim; j >= 0; --j) {
        const int up = f[x][j];
        if (dep[up] >= dep[y]) {
            x = up;
        }
    }
    if (x == y) {
        return y;
    }
    // Lift both nodes together while their ancestors still differ.
    for (int j = lim; j >= 0; --j) {
        if (f[x][j] == f[y][j]) {
            continue;
        }
        x = f[x][j];
        y = f[y][j];
    }
    return f[x][0];
}

// Weighted length of the tree path between x and y.
ll dis(int x, int y) {
    const int anc = lca(x, y);
    return d[x] + d[y] - 2 * d[anc];
}

// Sum of dis over cyclically adjacent marked nodes in DFS order; maintained
// by add/del and halved when printed.
ll ans = 0;
#define iter set<int>::iterator
if(s.empty()) { s.insert(dfn[x]); return; }
iter it = s.lower_bound(dfn[x]);
if(it == s.end() || it == s.begin()) {
iter st = s.begin(), ed = s.end(); --ed; // s.end()是一个空指针
ans -= dis(id[*st], id[*ed]);
ans += dis(x, id[*ed]); ans += dis(id[*st], x);
s.insert(dfn[x]);
} else {
iter pre = it, nxt = it; --pre;
ans -= dis(id[*pre], id[*nxt]);
ans += dis(id[*pre], x); ans += dis(x, id[*nxt]);
s.insert(dfn[x]);
}
} void del(int x) {
s.erase(dfn[x]); if(s.empty()) return;
iter it = s.lower_bound(dfn[x]);
if(it == s.end() || it == s.begin()) {
iter st = s.begin(), ed = s.end(); --ed;
ans += dis(id[*st], id[*ed]);
ans -= dis(id[*st], x) + dis(x, id[*ed]);
} else {
iter pre = it, nxt = it; --pre;
ans += dis(id[*pre], id[*nxt]);
ans -= dis(id[*pre], x) + dis(x, id[*nxt]);
}
} int main() {
scanf("%d", &n); char ch[2];
for(int x, y, z, i = 1; i < n; ++i)
scanf("%d%d%d", &x, &y, &z), ins(x, y, z), ins(y, x, z);
scanf("%d", &m);
lim = (int)(log(n)/log(2))+1;
dep[1] = 1; dfs(1);
for(int i = 1; i <= n; ++i) id[dfn[i]] = i;
for(int x, i = 1; i <= m; ++i) {
scanf("%s", ch);
if(ch[0] == '+') scanf("%d", &x), add(x);
else if(ch[0] == '-') scanf("%d", &x), del(x);
else printf("%lld\n", ans / 2);
}
}

