1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
| #include<iostream> #include<cstdio> #include<cstring> #include<string> #include<algorithm> using namespace std;
// Widens every `int` token that follows in this translation unit to `long long`
// (which is why the entry point further down is spelled `signed main`).
#define int long long
// CN : capacity of the per-vertex arrays (the arc pool is declared with CN << 1 slots).
// INF: seed value for minimum searches. Every byte is 0x3f, so INF + INF still
//      fits in a signed 64-bit integer.
const int CN = 1e6 + 6; const int INF = 0x3f3f3f3f3f3f3f3f;
// Linked-list adjacency storage: every arc lives in E[], hd[x] is the index of
// the most recently added arc leaving x, and each arc points at the arc that
// was added before it through `nxt`. Index 0 means "no arc".
struct fs {
    int to;   // vertex the arc points to
    int nxt;  // index in E[] of the next arc with the same tail (0 ends the list)
    int w;    // arc weight

    // Sets all three fields at once.
    void init(int t, int n, int d) {
        to = t;
        nxt = n;
        w = d;
    }
};

fs E[CN << 1];  // arc pool; slots are handed out starting at index 1
int hd[CN];     // hd[x]: first arc of x's list, 0 when the list is empty
int ecnt = 0;   // index of the last slot handed out

// Prepends the arc x -> y with weight z to x's list.
void add(int x, int y, int z) {
    const int slot = ++ecnt;
    E[slot].init(y, hd[x], z);
    hd[x] = slot;
}
// Vertex count of the input tree (main reads n - 1 edges after it).
int n;
// Query counter; main reads it once and counts it down while answering.
int q;
// Rooted-tree tables.
//   fa[u][k] : ancestor of u after 2^k parent steps, 0 when there is none.
//              dfs() fills only k = 0; the caller completes k >= 1.
//   dep[u]   : dep[parent] + 1 (the dfs start vertex gets dep[p] + 1 for the p passed in)
//   dis[u]   : sum of arc weights on the path from the dfs start vertex to u
//   dfn[u]   : preorder visiting number, starting at 1
int fa[CN][21];
int dep[CN];
int dis[CN];
int dfn[CN];
int idx = 0;  // last preorder number handed out

// Preorder traversal from u with parent p; fills fa[.][0], dep, dis and dfn
// for u and everything below it.
void dfs(int u, int p) {
    fa[u][0] = p;
    dep[u] = dep[p] + 1;
    dfn[u] = ++idx;
    for (int e = hd[u]; e != 0; e = E[e].nxt) {
        const int child = E[e].to;
        if (child == p) continue;  // do not walk back up the arc we came from
        dis[child] = dis[u] + E[e].w;
        dfs(child, u);
    }
}

// Lowest common ancestor of u and v by binary lifting over fa[][0..20].
// Requires the fa table to be fully populated.
int lca(int u, int v) {
    if (dep[u] < dep[v]) std::swap(u, v);

    // Raise the deeper vertex to the depth of the other one.
    for (int k = 20; k >= 0; --k) {
        if (dep[fa[u][k]] >= dep[v]) u = fa[u][k];
    }
    if (u == v) return u;

    // Lift both while their ancestors differ; they end up just below the answer.
    for (int k = 20; k >= 0; --k) {
        if (fa[u][k] != fa[v][k]) {
            u = fa[u][k];
            v = fa[v][k];
        }
    }
    return fa[u][0];
}
bool cmp(int i, int j) {return dfn[i] < dfn[j];} int stk[CN], top = 0, a[CN], rt; void bd(){ sort(a + 1, a + a[0] + 1, cmp); stk[top = 1] = a[1], hd[ a[1] ] = 0, ecnt = 0; for(int i = 2; i <= a[0]; i++){ int l = lca(stk[top], a[i]); if(l ^ stk[top]){ while(dfn[ stk[top - 1] ] > dfn[l]) add(stk[top - 1], stk[top], 0), top--; if(l ^ stk[top - 1]) hd[l] = 0, add(l, stk[top], 0), stk[top] = l; else add(l, stk[top], 0), top--; } hd[ a[i] ] = 0, stk[++top] = a[i]; } rt = stk[1]; for(int i = 1; i < top; i++) add(stk[i], stk[i + 1], 0); }
int mn[CN], mx[CN], sz[CN], amn, amx, ans, tmp1[4], tmp2[4]; bool is[CN]; void dp(int u, int p){ if(!hd[u]) return (void)(mn[u] = mx[u] = 0, sz[u] = 1);
mn[u] = INF, mx[u] = sz[u] = 0; int Mn = INF, pMn = INF, Mx = 0, pMx = 0; for(int k = hd[u]; k; k = E[k].nxt){ int v = E[k].to, d = dis[v] - dis[u]; if(v == p) continue; dp(v, u); mn[u] = min(mn[u], mn[v] + d), tmp1[0] = mn[v] + d, mx[u] = max(mx[u], mx[v] + d), tmp2[0] = mx[v] + d; sz[u] += sz[v];
ans += d * (a[0] - sz[v]) * sz[v];
tmp1[1] = Mn, tmp1[2] = pMn, tmp2[1] = Mx, tmp2[2] = pMx; sort(tmp1, tmp1 + 3), sort(tmp2, tmp2 + 3, greater<int>()); Mn = tmp1[0], pMn = tmp1[1], Mx = tmp2[0], pMx = tmp2[1]; }
amn = min(amn, Mn + pMn), amx = max(amx, Mx + pMx); if(is[u]) sz[u]++, amn = min(amn, Mn), mn[u] = 0; }
// Entry point. Reads a tree of n vertices with unit-weight edges, prepares the
// LCA tables, then for each query reads a vertex list into a[], builds the
// compressed tree and prints "ans amn amx" on one line.
// `signed` is used because `int` is macro-replaced by `long long` in this file.
// read() is defined outside this block; it is assumed to return the next
// integer from the input.
signed main() {
    n = read();
    for (int e = 1; e < n; ++e) {
        const int u = read();
        const int v = read();
        add(u, v, 1);
        add(v, u, 1);
    }

    dfs(1, 0);
    // Complete the doubling table: 2^k steps = two jumps of 2^(k-1) steps.
    for (int k = 1; k <= 20; ++k) {
        for (int v = 1; v <= n; ++v) {
            fa[v][k] = fa[fa[v][k - 1]][k - 1];
        }
    }

    q = read();
    while (q--) {
        a[0] = read();
        for (int i = 1; i <= a[0]; ++i) {
            a[i] = read();
            is[a[i]] = true;
        }

        // bd() resets ecnt, so from here on E[] holds compressed-tree arcs;
        // the original arcs are no longer needed once dfs() has run.
        bd();

        ans = 0;
        amx = 0;
        // A single listed vertex has no pair, so the minimum is reported as 0.
        amn = (a[0] == 1) ? 0 : INF;

        dp(rt, 0);
        printf("%lld %lld %lld", ans, amn, amx);
        puts("");

        // Unmark only the vertices of this query.
        for (int i = 1; i <= a[0]; ++i) {
            is[a[i]] = false;
        }
    }
}
|