2.5 Graph Neural Networks
GNN 处理的是离散结构:节点、边、邻接关系。它的关键不是把 graph flatten 成向量,而是在 graph 上定义可学习的 message passing,使每个节点的表示吸收局部拓扑和属性信息。
Running Example: Three Papers Cite Each Other
考虑一个很小的 citation graph:
A -- B -- C
三个节点是三篇论文,边表示引用或主题相近。每篇论文有一个初始特征,比如 [is_theory, is_system]:
\[ x_A=[1,0],\qquad x_B=[1,1],\qquad x_C=[0,1]. \]
如果我们想判断每篇论文属于 theory-heavy 还是 system-heavy,仅看自身特征可能不够。节点 \(B\) 同时连接 \(A\) 和 \(C\),它的表示应该融合两边信息;节点 \(A\) 也应该知道自己连接了一个 mixed paper。
一轮最简单的 mean aggregation 是:
\[ h_A' = \frac{x_A+x_B}{2}=[1,0.5], \]
\[ h_B' = \frac{x_A+x_B+x_C}{3}=\left[\frac{2}{3},\frac{2}{3}\right], \]
\[ h_C' = \frac{x_B+x_C}{2}=[0.5,1]. \]
这就是 message passing 的最小直觉:每个节点用邻居的信息修正自己。真正的 GNN 会在 aggregation 前后加线性层、非线性、attention 或 edge features。
Graphs and Permutation Equivariance
A graph is \(G=(V,E)\) with node set \(V\), edge set \(E\), node features \(x_v\), and optionally edge features \(e_{uv}\). A graph neural network maps these discrete relational structures into learned representations.
GNN 必须尊重 permutation symmetry:节点编号本身没有语义。若我们重排节点顺序,输出也应对应重排。
Let \(P\) be a permutation matrix. A node-level model \(f\) is permutation equivariant if \[ f(PAP^\top, PX)=P f(A,X). \] For graph-level outputs, the model should be permutation invariant.
这个条件不是“好看”的数学性质,而是必要条件。节点编号通常只是数据文件里的行号。若把节点 \(A,B,C\) 改名为 \(2,0,1\),模型结论不应该改变,只应该随编号重排。
Message Passing
A message passing layer updates node states by \[ m_v^{(k)} = \operatorname{AGG}^{(k)} \left( \{M^{(k)}(h_v^{(k)},h_u^{(k)},e_{uv}):u\in\mathcal{N}(v)\} \right), \] \[ h_v^{(k+1)} = U^{(k)}(h_v^{(k)},m_v^{(k)}). \] The aggregation operator must be permutation invariant.
常见的 aggregation 包括 sum、mean、max 和 attention-weighted sum。sum 最有表达力,mean 更稳定,max 更像检测局部 motif。
为什么 aggregation 必须是 permutation invariant?因为邻居集合
\[ \mathcal{N}(v)=\{u_1,u_2,\ldots,u_d\} \]
没有自然顺序。若把邻居输入顺序打乱,节点 \(v\) 的表示不应该变。也就是说,aggregation 是一个 multiset function:
\[ \operatorname{AGG}(\{h_{u_1},\ldots,h_{u_d}\}) = \operatorname{AGG}(\{h_{u_{\pi(1)}},\ldots,h_{u_{\pi(d)}}\}) \]
对任意置换 \(\pi\) 成立。
For any permutation of neighbor order, sum, mean, and coordinatewise max aggregations return the same value.
对 sum:
\[ \sum_{i=1}^{d}h_{u_{\pi(i)}} = \sum_{i=1}^{d}h_{u_i}, \]
因为加法交换律和结合律成立。mean 只是 sum 除以固定的 \(d\),所以也不变。coordinatewise max 对每个维度取集合最大值,输入顺序不改变集合,因此最大值不变。
实现上,GNN aggregation 通常不是手写 Python for-loop,而是基于 edge list 的 scatter/reduce。设有 directed edge list:
src = [u_1, u_2, ...]
dst = [v_1, v_2, ...]
一层 sum aggregation 可以写成:
msg = message_mlp(h[src], edge_attr)
out = torch.zeros(num_nodes, msg.size(-1), device=h.device)
out.index_add_(0, dst, msg)mean aggregation 还要除以入度:
deg = torch.bincount(dst, minlength=num_nodes).clamp_min(1)
out = out / deg[:, None]这段代码暴露了两个真实工程问题:第一,边方向必须约定清楚,src -> dst 还是 dst -> src 写反会 silently wrong;第二,isolated node 的 degree 为 0,需要 self-loop、residual 或 clamp_min(1) 处理。
GCN
Graph Convolutional Network 使用归一化邻接矩阵:
\[ H^{(k+1)} = \sigma \left( \tilde{D}^{-\frac12} \tilde{A} \tilde{D}^{-\frac12} H^{(k)}W^{(k)} \right), \]
其中 \(\tilde{A}=A+I\),\(\tilde{D}\) 是 degree matrix。
这个公式可以看成:每个节点把邻居表示加权平均后再线性变换。归一化避免高 degree 节点的表示尺度过大。
如果不用归一化,直接用
\[ AHW \]
那么 degree 高的节点会收到更多消息,表示范数可能系统性变大。最朴素的 random-walk normalization 是
\[ D^{-1}A, \]
它让每个节点取邻居平均。但 \(D^{-1}A\) 通常不是对称矩阵,在无向图上会破坏一些谱性质。GCN 使用 symmetric normalization:
\[ S=\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}, \]
每条边 \((u,v)\) 的权重是
\[ S_{vu} = \frac{1}{\sqrt{\tilde{d}_v\tilde{d}_u}}. \]
于是节点更新可写成局部形式:
\[ h_v^{(k+1)} = \sigma\left( \sum_{u\in\mathcal{N}(v)\cup\{v\}} \frac{1}{\sqrt{\tilde{d}_v\tilde{d}_u}} h_u^{(k)}W^{(k)} \right). \]
self-loop 的意义也很具体:没有 self-loop 时,节点下一层完全由邻居决定,自身信息只能通过偶数层间接回来;加上 \(I\) 后,节点每层都保留自己的表示通道。
The GCN renormalization trick replaces \(A\) by \(\tilde{A}=A+I\) and uses \(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\) to combine self-information and degree-normalized neighbor messages.
稀疏实现通常不显式构造 dense \(N\times N\) 矩阵,而是在 edge list 上预计算 norm:
row, col = edge_index # messages col -> row by convention
deg = torch.bincount(row, minlength=num_nodes).float()
deg_inv_sqrt = deg.clamp_min(1).pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
msg = h[col] * norm[:, None]
agg = torch.zeros_like(h)
agg.index_add_(0, row, msg)
out = agg @ weight实际库里还要先加 self-loops,并确保 degree 是加 self-loop 后的 \(\tilde{d}\)。常见 bug 是先算 degree 再加 self-loop,导致 normalization 和公式不一致。
For any permutation matrix \(P\), the GCN update \[ F(A,H)=\sigma(\tilde{D}^{-\frac12}\tilde{A}\tilde{D}^{-\frac12}HW) \] satisfies \[ F(PAP^\top,PH)=P F(A,H). \]
令 \(\tilde{A}=A+I\)。重排后
\[ \tilde{A}'=P\tilde{A}P^\top. \]
degree matrix 也随之重排:
\[ \tilde{D}'=P\tilde{D}P^\top. \]
因为 \(P\) 是置换矩阵,\(P^\top P=I\),所以
\[ (\tilde{D}')^{-\frac12}\tilde{A}'(\tilde{D}')^{-\frac12} = P\tilde{D}^{-\frac12}\tilde{A}\tilde{D}^{-\frac12}P^\top. \]
对输入特征 \(H'=PH\),
\[ F(A',H') = \sigma(P\tilde{D}^{-\frac12}\tilde{A}\tilde{D}^{-\frac12}P^\top PHW) = \sigma(P\tilde{D}^{-\frac12}\tilde{A}\tilde{D}^{-\frac12}HW). \]
非线性 \(\sigma\) 逐元素作用,因此与置换可交换:
\[ \sigma(PZ)=P\sigma(Z). \]
所以 \(F(A',H')=PF(A,H)\)。
GAT
Graph Attention Network 学习邻居权重:
\[ e_{uv} = a^\top[Wh_u\Vert Wh_v], \]
\[ \alpha_{uv} = \frac{\exp(\operatorname{LeakyReLU}(e_{uv}))} {\sum_{r\in\mathcal{N}(v)}\exp(\operatorname{LeakyReLU}(e_{rv}))}, \]
\[ h_v' = \sigma\left(\sum_{u\in\mathcal{N}(v)}\alpha_{uv}Wh_u\right). \]
GAT 让模型决定哪些邻居更重要,这与 Transformer attention 的思想相通,但 graph attention 受邻接结构限制。
注意 \(\alpha_{uv}\) 的 softmax 是对固定目标节点 \(v\) 的入边集合做归一化:
\[ \sum_{u\in\mathcal{N}(v)}\alpha_{uv}=1. \]
这叫 edge softmax 或 segment softmax。不能把全图所有边一起 softmax,否则不同节点之间的注意力会互相竞争,完全改变模型语义。
概念实现:
score = leaky_relu(attn_mlp(torch.cat([Wh[src], Wh[dst]], dim=-1)))
alpha = segment_softmax(score, dst) # normalize over edges with same dst
msg = alpha[:, None] * Wh[src]
out = scatter_sum(msg, dst, dim=0, dim_size=num_nodes)multi-head GAT 通常把多个 head concat 或 average:
\[ h_v' = \Vert_{m=1}^{M} \sigma\left( \sum_{u\in\mathcal{N}(v)} \alpha_{uv}^{(m)}W^{(m)}h_u \right). \]
concat 增加表示维度,average 保持维度并更稳定。和 Transformer 类似,多头不是为了“更大参数量”本身,而是让不同 head 学不同关系:同领域引用、同作者、同方法、同时间等。
GAT attention is only computed on existing graph edges unless extra edges are added. It cannot attend to an arbitrary non-neighbor the way a full Transformer can.
GraphSAGE and Inductive Node Embeddings
GCN 常被讲成整图矩阵乘法,但许多真实任务是 inductive 的:测试时会出现训练中没见过的新节点。GraphSAGE 的思想是学习一个邻居采样和聚合函数,而不是为每个节点学习一个固定 embedding。
一层 mean GraphSAGE 可以写成:
\[ m_v^{(k)} = \frac{1}{|\mathcal{S}(v)|} \sum_{u\in\mathcal{S}(v)} h_u^{(k)}, \]
\[ h_v^{(k+1)} = \sigma\left( W^{(k)} [h_v^{(k)}\Vert m_v^{(k)}] \right), \]
其中 \(\mathcal{S}(v)\) 是采样到的邻居子集。采样让计算从“全邻居”变成固定 fanout:
\[ \text{nodes sampled} \approx B\prod_{\ell=1}^{K}f_\ell, \]
其中 \(B\) 是 target batch size,\(f_\ell\) 是第 \(\ell\) 层 fanout。两层 fanout [15, 10] 意味着一个 target batch 会扩展到约 \(150B\) 个邻居节点,实际去重后会小一些。
An inductive GNN learns functions of node features and neighborhoods so that it can compute embeddings for nodes or graphs not seen during training.
GraphSAGE 的工程关键是 mini-batch 里有多层 block:
input nodes for layer 0 -> hidden nodes for layer 1 -> target nodes
每一层只对当前 block 的边做 message passing。伪代码:
seed = target_nodes
blocks = []
frontier = seed
for fanout in reversed([15, 10]):
src, dst, edge = sample_neighbors(frontier, fanout)
blocks.append((src, dst, edge))
frontier = unique(src)
h = x[frontier]
for block, layer in zip(blocks, gnn_layers):
h = layer(block, h)
logits = classifier(h[target_positions])这和普通 dataloader 最大的不同是:一个 batch 的输入节点不等于监督节点。监督只在 target nodes 上,计算却需要 sampled source nodes。
Expressivity: Sum, Mean, and GIN
不是所有 aggregation 一样强。考虑两个邻居 multiset:
\[ \{1,1,2\} \qquad\text{and}\qquad \{1,2,2\}. \]
mean 分别为 \(4/3\) 和 \(5/3\),能区分;但
\[ \{1,3\} \qquad\text{and}\qquad \{2,2\} \]
mean 都是 \(2\),max 都是 \(3\) vs \(2\) 可区分;再换一些集合,max 又会丢失计数信息。sum aggregation 在足够丰富的特征映射后可以保留 multiset 计数,因此更接近 Weisfeiler-Lehman test 的表达力。
Graph Isomorphism Network (GIN) 的形式是
\[ h_v^{(k+1)} = \operatorname{MLP}^{(k)} \left( (1+\epsilon^{(k)})h_v^{(k)} + \sum_{u\in\mathcal{N}(v)}h_u^{(k)} \right). \]
The 1-WL test iteratively colors each node by hashing its current color together with the multiset of neighbor colors. Many message-passing GNNs can be viewed as differentiable analogues of this refinement process.
这给 GNN 一个很清楚的上限:普通 message passing 很难区分 1-WL 也区分不了的图结构。要超越这个限制,通常要加入 higher-order GNN、subgraph features、positional encodings、random features 或 domain-specific structure。
更形式化地说,一层 message passing 的更新只看
\[ \left(h_v^{(k)},\ \{h_u^{(k)}:u\in\mathcal{N}(v)\}\right). \]
如果两个节点在某一轮有相同自身表示和相同邻居表示 multiset,那么任何使用同一个 deterministic aggregation/update 的 MPNN 都会给它们相同的下一层表示。这正是 1-WL color refinement 的连续版本。
If two nodes have identical 1-WL colors at iteration \(k\), then a message-passing GNN initialized from those colors assigns them identical representations at layer \(k\).
用归纳法。初始层 \(k=0\),节点表示由初始 color/feature 决定,因此相同 color 给出相同表示。假设第 \(k\) 层成立。若两个节点在第 \(k+1\) 轮 1-WL 中仍有相同 color,则它们的自身 color 相同,邻居 color multiset 相同。根据归纳假设,这等价于自身表示相同、邻居表示 multiset 相同。message passing 使用同一个 multiset aggregation 和 update function,所以输出表示也相同。
sum aggregation 的优势来自它可以保留计数。对离散有限特征集合,如果为每种 feature 分配一个 one-hot basis,则
\[ \sum_{u\in\mathcal{N}(v)}\phi(h_u) \]
就是邻居 multiset 的 histogram。mean 会除以 degree,可能丢失计数;max 会只保留是否出现或最大响应,也容易丢 multiplicity。GIN 用 sum + MLP,就是为了尽量接近“可学习的 multiset hash”。
Making the MLP deeper or wider does not let a standard message-passing GNN distinguish structures that its aggregation inputs make identical.
Edge Features and Relational Graphs
很多图不只是“有边/无边”。分子图有 bond type,知识图谱有 relation type,交通图有距离和方向。此时 message function 应该依赖边特征:
\[ m_{u\to v}^{(k)} = M^{(k)}(h_u^{(k)},h_v^{(k)},e_{uv}). \]
Relational GCN 对不同关系 \(r\) 使用不同权重:
\[ h_v^{(k+1)} = \sigma\left( \sum_{r\in\mathcal{R}} \sum_{u\in\mathcal{N}_r(v)} \frac{1}{c_{v,r}}W_r^{(k)}h_u^{(k)} +W_0^{(k)}h_v^{(k)} \right). \]
这里 \(r\) 可以是 citation type、knowledge-graph predicate、molecule bond type。关系太多时,\(W_r\) 参数量会爆炸,常用 basis decomposition 或 block decomposition 降参数。
Oversmoothing and Oversquashing
Stacking many message passing layers can cause oversmoothing, where node embeddings become indistinguishable, and oversquashing, where information from exponentially many distant nodes is compressed into fixed-size vectors.
Oversmoothing 来自反复邻居平均;oversquashing 来自图上远距离依赖需要穿过少数瓶颈边。解决方法包括 residual connection、normalization、jumping knowledge、graph rewiring、positional encoding。
可以用一个极简线性 GCN 看 oversmoothing。若忽略 \(W\) 和非线性,反复更新是
\[ H^{(K)} = S^KX, \qquad S=\tilde{D}^{-\frac12}\tilde{A}\tilde{D}^{-\frac12}. \]
在连通图上,\(S^K\) 会逐渐压制非主特征方向,节点表示变得接近低频平滑信号。对 node classification,有时这正是需要的 inductive bias;但层数太深时,不同类别边界也被抹平。
Oversquashing 则更像信息瓶颈。若一个节点 \(v\) 要接收 \(K\) 跳外指数增长的节点信息,但 hidden size 固定为 \(d\),所有信息都被压进一个 \(d\) 维向量。图上存在桥边或树状扩张时,这个问题尤其严重。
谱视角下,如果 \(S\) 的特征分解为
\[ S=U\Lambda U^\top, \]
则
\[ S^KX = U\Lambda^K U^\top X. \]
当 \(|\lambda_i|<1\) 时,\(\lambda_i^K\to 0\),高频或非主方向逐渐被抹掉,只剩低频平滑分量。这就是 oversmoothing 的数学图像:深层 GCN 越来越像在图上反复做 low-pass filtering。
Oversquashing occurs when information from many distant nodes must be compressed through a small number of paths or a fixed-dimensional hidden state before reaching a target node.
oversquashing 和图曲率/瓶颈有关。比如一棵二叉树从深度 \(K\) 的叶子向根传信息,叶子数是 \(2^K\),但根节点每层只接收固定维度 message。即使没有 oversmoothing,信息量也被结构性压缩。常见缓解策略:
| problem | symptom | mitigation |
|---|---|---|
| oversmoothing | node embeddings become too similar | residual, normalization, jumping knowledge |
| oversquashing | long-range signals cannot pass bottlenecks | graph rewiring, virtual nodes, attention edges |
| heterophily | neighbors often have different labels | separate ego/neighbor channels, signed/typed edges |
heterophily 特别值得注意。GCN 的 local smoothing 假设邻居倾向于相似;在 fraud detection、交易图、某些网页图里,连接可能表示交互或对抗,不表示同类。此时简单平均邻居会把有用的类别边界抹掉。一个常见改法是保留 ego representation:
\[ h_v^{(k+1)} = \sigma\left( W_{\text{self}}h_v^{(k)} + W_{\text{nbr}} \operatorname{AGG}\{h_u^{(k)}:u\in\mathcal{N}(v)\} \right), \]
让模型自己学“自身特征”和“邻居特征”应不应该相似。
Training Objectives
GNN 不只是 supervised node classification。常见训练范式包括:
| Task | Objective | Example |
|---|---|---|
| node classification | cross entropy over labeled nodes | citation networks |
| link prediction | score positive/negative edges | recommender systems |
| graph classification | pooled graph representation | molecules |
| contrastive learning | align augmented graph views | self-supervised GNN |
| masked attribute modeling | reconstruct masked nodes/edges | graph pretraining |
这就引出了离散结构学习的一个重要主题:很多监督信号不是连续像素,而是 node id、edge existence、subgraph pattern、token identity 这类离散对象。
node classification 常见于半监督 transductive setting:训练时整张图的结构和节点特征可见,但只有一部分节点有 label。loss 只在 labeled mask 上计算:
\[ \mathcal{L}_{\text{node}} = \frac{1}{|\mathcal{V}_{\text{train}}|} \sum_{v\in\mathcal{V}_{\text{train}}} \operatorname{CE}(y_v,\hat{y}_v). \]
这和 attention mask / loss mask 的区别很像:message passing 可以使用 unlabeled nodes 的特征和边,但监督项只来自 train mask。若把 validation/test labels 用进训练 loss,就是直接泄漏;若图结构包含未来边,则可能是更隐蔽的泄漏。
Link Prediction and Negative Sampling
Link prediction 常用一个 edge score:
\[ s(u,v)=h_u^\top h_v \]
或 MLP:
\[ s(u,v)=\operatorname{MLP}([h_u\Vert h_v\Vert h_u\odot h_v]). \]
正样本是存在的边 \((u,v)\in E\),负样本是不存在的边 \((u,v)\notin E\)。binary cross entropy:
\[ \mathcal{L} = - \sum_{(u,v)\in E^+}\log\sigma(s(u,v)) - \sum_{(u,v)\in E^-}\log(1-\sigma(s(u,v))). \]
负采样不是小细节。若负样本太容易,模型只学会 degree 或 popularity;若负样本包含未来信息,评估会泄漏。推荐系统和 citation prediction 中,时间切分通常比随机切边更可信。
负采样分布会改变学习到的 score calibration。若从所有非边均匀采样,绝大多数负样本可能非常容易;若按 degree 或 hard negative 采样,任务更难但更贴近 ranking。一个常见 sampled BCE 是
\[ \mathcal{L} = - \sum_{(u,v)\in E^+} \left[ \log\sigma(s(u,v)) + \sum_{v^-\sim q(\cdot\mid u)} \log\sigma(-s(u,v^-)) \right]. \]
这里 \(q\) 是负采样分布。训练的是 sampled objective,不是全体非边上的完整 likelihood。因此比较不同论文/系统时,需要同时看 negative sampler、正负比例和评估候选集合。
temporal link prediction 应该按时间切分:
train edges: t <= T_train
valid edges: T_train < t <= T_valid
test edges: t > T_valid
并且构造 node features 时也不能使用未来信息。随机删边评估更容易,但经常高估真实预测能力。
If validation edges are removed from the label set but their endpoints remain connected through derived features or future edges, link prediction metrics can be overly optimistic.
Graph Readout
graph-level representation 通常需要 readout:
\[ h_G=\operatorname{READOUT}(\{h_v:v\in V\}). \]
READOUT 必须 permutation invariant。常见选择是 sum/mean pooling、attention pooling、Set2Set。
Mini-Batch Training on Large Graphs
整图 GCN 的矩阵形式清楚,但大图上不能每次都处理所有节点。常见 mini-batch 策略:
| Strategy | Idea | Risk |
|---|---|---|
| neighbor sampling | sample fixed neighbors per layer | biased neighborhood |
| subgraph sampling | train on induced subgraphs | boundary effects |
| cluster sampling | partition graph into clusters | cross-cluster edges lost |
| layer-wise sampling | sample nodes per layer | implementation complexity |
以两层 neighbor sampling 为例,要预测一批 target nodes \(B\),需要先采它们的一跳邻居,再采这些邻居的一跳邻居。实际计算图是一个 computation subgraph:
2-hop sampled nodes -> 1-hop sampled nodes -> target nodes
这和普通 i.i.d. mini-batch 不一样。一个 target node 的表示依赖被采到的邻居集合,因此采样策略本身就是模型近似的一部分。
fanout 的计算非常直观但容易低估。若 batch 有 \(B\) 个 target nodes,两层 fanout 是 \((f_1,f_2)\),最坏情况下计算节点数近似
\[ B(1+f_1+f_1f_2). \]
三层就变成 \(B(1+f_1+f_1f_2+f_1f_2f_3)\)。这就是为什么深层 GNN 在大图上不只是 oversmoothing,还会有 sampling explosion。
Neighbor Sampling as an Estimator
neighbor sampling 不是普通 dataloader 的小优化,它改变了每层 aggregation 的估计方式。以 mean aggregator 为例,完整邻居平均是:
\[ \mu_v = \frac{1}{|\mathcal{N}(v)|} \sum_{u\in\mathcal{N}(v)}h_u. \]
若从 \(\mathcal{N}(v)\) 中均匀无放回采样 \(s\) 个邻居,记采样集合为 \(\mathcal{S}(v)\),sample mean 是:
\[ \hat{\mu}_v = \frac{1}{s} \sum_{u\in\mathcal{S}(v)}h_u. \]
For uniform sampling without replacement from \(\mathcal{N}(v)\), \[ \mathbb{E}[\hat{\mu}_v]=\mu_v. \]
令 \(I_u\) 表示邻居 \(u\) 是否被采到。无放回均匀采样 \(s\) 个邻居时:
\[ \mathbb{E}[I_u]=\frac{s}{|\mathcal{N}(v)|}. \]
于是:
\[ \mathbb{E}[\hat{\mu}_v] = \mathbb{E} \left[ \frac1s\sum_{u\in\mathcal{N}(v)}I_uh_u \right] = \frac1s \sum_{u\in\mathcal{N}(v)} \mathbb{E}[I_u]h_u \]
\[ = \frac1s \sum_{u\in\mathcal{N}(v)} \frac{s}{|\mathcal{N}(v)|}h_u = \frac{1}{|\mathcal{N}(v)|} \sum_{u\in\mathcal{N}(v)}h_u = \mu_v. \]
这条结论只说明“线性 mean aggregation 的邻居均值估计无偏”。完整 GNN layer 通常还有 nonlinear update:
\[ h_v^{(k+1)} = \sigma \left( W[h_v^{(k)}\Vert \hat{\mu}_v^{(k)}] \right). \]
一般来说:
\[ \mathbb{E} \left[ \sigma(W[h_v\Vert \hat{\mu}_v]) \right] \neq \sigma(W[h_v\Vert \mu_v]). \]
所以 neighbor sampling 不是“无损缩小 batch”,而是一个随机近似。fanout 越小,方差越大;层数越深,采样噪声会穿过更多 nonlinear updates。
The sampled mean can be an unbiased estimator of the full-neighborhood mean, but the full nonlinear GNN output and training gradient are generally biased after nonlinear transformations and multi-layer sampling.
若采样不是 uniform,而是按分布 \(q(u\mid v)\) 抽样,可以用 importance correction 估计 sum aggregator:
\[ \sum_{u\in\mathcal{N}(v)}h_u \approx \frac1s \sum_{j=1}^{s} \frac{h_{u_j}}{q(u_j\mid v)}. \]
mean aggregator 则除以 \(|\mathcal{N}(v)|\):
\[ \mu_v \approx \frac{1}{s|\mathcal{N}(v)|} \sum_{j=1}^{s} \frac{h_{u_j}}{q(u_j\mid v)}. \]
这个修正理论上漂亮,但工程上不总是用。原因是 \(1/q\) 可能让高方差邻居支配训练,而 GNN 任务往往更关心 ranking/分类效果和吞吐。真实系统需要把 sampler 当成模型的一部分来调,而不是当成透明加速器。
一个最小 uniform fanout sampler 可以写成:
import torch
def sample_fanout(csr_indptr, csr_indices, seed_nodes, fanout, generator):
sampled_src = []
sampled_dst = []
for dst in seed_nodes.tolist():
start = int(csr_indptr[dst])
end = int(csr_indptr[dst + 1])
neigh = csr_indices[start:end]
deg = neigh.numel()
if deg == 0:
continue
if fanout < 0 or deg <= fanout:
picked = neigh
else:
perm = torch.randperm(deg, generator=generator, device=neigh.device)
picked = neigh[perm[:fanout]]
sampled_src.append(picked)
sampled_dst.append(torch.full_like(picked, dst))
if not sampled_src:
empty = torch.empty(0, dtype=seed_nodes.dtype, device=seed_nodes.device)
return empty, empty
return torch.cat(sampled_src), torch.cat(sampled_dst)这里使用 CSR adjacency 是因为大图采样的瓶颈通常是邻接访问,而不是神经网络本身。若每次从 COO edge_index 里筛选邻居,会把采样变成昂贵的全边扫描。
Log sampled degree histograms per layer. If most high-degree nodes are truncated to the same fanout, the model sees a flattened neighborhood distribution and may overfit to sampler artifacts.
真实系统常把采样结果表示成 blocks:
| object | meaning |
|---|---|
| seed nodes | 有监督 loss 的 target nodes |
| input nodes | 为计算 seed 表示而需要读取特征的所有节点 |
| block edges | 当前层从 source nodes 到 destination nodes 的 sampled edges |
| node mapping | global node ids 到 mini-batch local ids 的映射 |
Block Mapping and Feature Fetch Contract
mini-batch GNN 的实现细节常败在 id mapping。采样器输出的是 global node id;message passing kernel 需要的是 batch-local contiguous id。若映射错了,shape 仍然对,但特征会对错节点。
A GNN block is a bipartite computation graph for one message-passing layer, containing source nodes, destination nodes, sampled edges from sources to destinations, and the mapping between global ids and local tensor rows.
对一层 block,应该明确四个对象:
\[ \texttt{src\_global}, \quad \texttt{dst\_global}, \quad \texttt{edge\_local}, \quad \texttt{dst\_pos}. \]
其中 edge_local[0] 索引 src_global 的行,edge_local[1] 索引 dst_global 的行。一个安全构造流程是:
def build_block(src_global, dst_global, edge_src_global, edge_dst_global):
src_global = torch.unique(src_global)
dst_global = torch.unique(dst_global)
src_pos = {int(n): i for i, n in enumerate(src_global.tolist())}
dst_pos = {int(n): i for i, n in enumerate(dst_global.tolist())}
edge_src = torch.tensor(
[src_pos[int(n)] for n in edge_src_global.tolist()],
dtype=torch.long,
)
edge_dst = torch.tensor(
[dst_pos[int(n)] for n in edge_dst_global.tolist()],
dtype=torch.long,
)
edge_local = torch.stack([edge_src, edge_dst], dim=0)
return src_global, dst_global, edge_local这个 Python dict 版本适合教学;生产实现会用排序、hash table 或框架内置的 relabeling kernel。无论实现如何,必须满足:
\[ \texttt{feature\_rows}[i] = x_{\texttt{src\_global}[i]}. \]
一个非常有用的 smoke test 是用 node id 自身作为特征,检查 message 方向和 mapping:
def check_block_mapping(x_global, src_global, edge_local, edge_src_global):
x_block = x_global[src_global]
src_local = edge_local[0]
recovered_global = x_block[src_local]
expected_global = x_global[edge_src_global]
assert torch.equal(recovered_global, expected_global)如果这个测试失败,GNN 仍可能跑出 loss,但它学的是错位特征。大图系统里这种 bug 很隐蔽,因为每个 batch 的局部 id 都不同,错误不会表现为固定一行错位。
Temporal and Leakage-Aware Sampling
很多图是时间图:交易、点击、引用、社交互动都带时间。此时训练节点或训练边的邻居集合不能包含未来边。对时间 \(t\) 的样本,message passing 应该只使用:
\[ \mathcal{N}_{\leq t}(v) = \{u:(u,v,\tau)\in E,\ \tau\leq t\}. \]
若 sampler 在全图上采邻居,即使 loss 只在 train labels 上算,也可能把未来结构泄漏进 node representation。
def sample_temporal_neighbors(csr_by_node, node, cutoff_time, fanout, generator):
neigh, times = csr_by_node[node]
valid = neigh[times <= cutoff_time]
if valid.numel() <= fanout:
return valid
idx = torch.randperm(valid.numel(), generator=generator)[:fanout]
return valid[idx]Removing future labels is not enough for temporal graph prediction. Node features, degree features, sampled neighborhoods, and negative candidates must also be constructed using only information available at the prediction time.
一个训练 step 的高层逻辑:
for seed_nodes, blocks in dataloader:
x = feature_store.fetch(blocks[0].src_nodes)
h = x
for layer, block in zip(layers, blocks):
h = layer(block.edge_index, h)
logits = classifier(h[block.target_positions])
loss = F.cross_entropy(logits, y[seed_nodes])大图训练的瓶颈经常不是矩阵乘,而是采样、特征读取和 CPU-GPU 传输。因此要记录的不只是 accuracy:
| metric | why it matters |
|---|---|
| sampled nodes/edges per batch | fanout 是否爆炸 |
| feature fetch time | 是否被 CPU 或存储拖慢 |
| GPU utilization | message passing kernel 是否吃满 |
| duplicate node ratio | 采样去重效率 |
| sampled degree histogram | fanout 是否截断了重要高阶节点 |
| mapping check failures | global/local id 是否错位 |
| train/valid edge-time split | 是否发生未来泄漏 |
Neighbor-sampled batches overlap and depend on graph topology. Shuffling seed nodes does not make sampled computation graphs independent in the usual tabular-data sense.
Relation to Transformers
Transformer 可以看作完全图上的 message passing:每个 token 都能向每个 token 发送消息,edge weight 由 attention 动态决定。GNN 则通常有稀疏显式 graph structure。
| Model | Connectivity | Weighting | Inductive bias |
|---|---|---|---|
| GCN | fixed graph edges | degree normalized | local smoothing |
| GAT | fixed graph edges | learned attention | relational selectivity |
| Transformer | complete graph | learned attention | content-addressed routing |
所以 GNN 是理解 Transformer 的另一个入口:attention 本质上也是 message passing,只是图结构从固定邻接变成了动态可学习邻接。
Implementation Checklist
实现或调试 GNN 时,可以按下面顺序检查:
edge_index方向是否符合 message convention,例如src -> dst;- 是否按公式添加 self-loops,并在加 self-loop 后重新计算 degree;
- GCN normalization 是否使用 \(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}\) 对应的边权;
- isolated nodes 是否有 self-loop、residual 或安全的 degree clamp;
- aggregation 是否 permutation invariant,且没有依赖邻居输入顺序;
- GAT softmax 是否在同一个目标节点的入边上做 segment softmax;
- node classification loss 是否只在 train node mask 上计算;
- link prediction negative sampler 是否和评估候选集合一致;
- edge split 是否按时间或任务语义避免未来信息泄漏;
- neighbor sampling 的 fanout 是否导致 sampled nodes 爆炸;
- batch blocks 的 global/local node id mapping 是否正确;
- 是否记录 sampled edge 数、feature fetch time、GPU utilization 和 duplicate ratio。
- temporal graph sampler 是否只使用 cutoff time 前的边和特征;
- sampled degree distribution 是否被 fanout 截断得过于平坦。
两个 smoke tests 很有帮助:
# 1. permutation equivariance test
perm = torch.randperm(num_nodes)
out1 = model(edge_index, x)
edge_perm = remap_edges(edge_index, perm)
out2 = model(edge_perm, x[perm])
assert torch.allclose(out2, out1[perm], atol=1e-5)
# 2. edge-direction test on a tiny chain
# A -> B -> C should update B from A and C from B, not the reverse.如果这两个测试失败,模型即使在某个 benchmark 上能训练,也很可能是在利用实现偏差或数据泄漏,而不是正确的 graph inductive bias。