
VirtualNode

gtrick.pyg.VirtualNode(in_feats, out_feats, dropout=0.5, residual=False)

Bases: nn.Module

Virtual Node from OGB Graph Property Prediction Examples.

It adds a virtual node connected to all nodes of each graph, which gives every node a short path to graph-wide information. This trick is helpful for graph-level tasks.
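Conceptually, adding a virtual node means augmenting each graph with one extra node joined to every real node, so any two nodes end up at most two hops apart. The sketch below shows that view on a plain edge list. `add_virtual_node` is a hypothetical helper, not part of gtrick; the class documented here appears instead to keep a separate virtual node embedding per graph (the `vx` argument) rather than rewriting `edge_index`.

```python
def add_virtual_node(num_nodes, edges):
    """Return (new_num_nodes, new_edges) with one extra node linked to every node.

    `edges` is a list of (src, dst) pairs, like a PyG edge_index read column by column.
    The virtual node gets index `num_nodes` and is connected in both directions.
    """
    vn = num_nodes  # index of the new virtual node
    new_edges = list(edges)
    for v in range(num_nodes):
        new_edges.append((v, vn))  # node -> virtual node
        new_edges.append((vn, v))  # virtual node -> node
    return num_nodes + 1, new_edges


# A path graph 0-1-2, with each undirected edge stored in both directions.
n, e = add_virtual_node(3, [(0, 1), (1, 0), (1, 2), (2, 1)])
print(n)       # 4
print(len(e))  # 10
```

Nodes 0 and 2 were two hops apart before and still are, but in a large graph every pair of nodes is now within two hops through node 3.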

Note

To use this trick, call update_node_emb first and then update_vn_emb.
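The call order can be illustrated with a toy, framework-free model in which every node feature is a single float. This is a hypothetical sketch, assuming update_node_emb runs before each conv layer (features of size in_feats) and update_vn_emb after it (features of size out_feats), as the parameter descriptions suggest. The `toy_*` functions and the `conv` stand-in are illustrations, not gtrick calls.

```python
def toy_update_node_emb(x, batch, vx):
    # Virtual node -> graph nodes: each node adds its graph's virtual node value.
    return [xi + vx[g] for xi, g in zip(x, batch)], vx


def toy_update_vn_emb(x, batch, vx):
    # Graph nodes -> virtual node: sum each graph's node values into its virtual node.
    pooled = [0.0] * len(vx)
    for xi, g in zip(x, batch):
        pooled[g] += xi
    return [p + v for p, v in zip(pooled, vx)]


def conv(x):
    # Stand-in for a real message-passing layer; it just halves every feature.
    return [0.5 * xi for xi in x]


x = [1.0, 2.0, 3.0]   # three nodes
batch = [0, 0, 1]     # nodes 0 and 1 belong to graph 0, node 2 to graph 1
vx = [0.0, 0.0]       # one virtual node per graph, starting at zero

for _ in range(2):                              # two layers
    x, vx = toy_update_node_emb(x, batch, vx)   # 1) before the conv layer
    x = conv(x)                                 # 2) message passing among nodes
    vx = toy_update_vn_emb(x, batch, vx)        # 3) after the conv layer

print(x, vx)  # [1.0, 1.25, 1.5] [3.75, 3.0]
```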

Examples:

VirtualNode (PyG)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_feats | int | Feature size before the conv layer. | required |
| out_feats | int | Feature size after the conv layer. | required |
| dropout | float | Dropout rate applied to the virtual node embedding. | 0.5 |
| residual | bool | If True, use a residual connection. | False |
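Of the constructor arguments, dropout and residual only change how the virtual node embedding is refreshed. The sketch below is framework-free and assumes the OGB-style rule: dropout is applied to the newly computed virtual node state, and residual=True adds the previous embedding back on top. The helpers `dropout` and `next_vn_emb` are hypothetical; their defaults mirror the documented dropout=0.5 and residual=False.

```python
import random


def dropout(values, p, training, rng):
    """Inverted dropout: zero each entry with probability p, scale survivors by 1/(1-p)."""
    if not training or p == 0.0:
        return list(values)
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def next_vn_emb(vx, update, p=0.5, residual=False, training=True, rng=random):
    """Combine the old virtual node embedding `vx` with a freshly computed `update`.

    residual=True  -> vx + dropout(update)
    residual=False -> dropout(update)
    """
    dropped = dropout(update, p, training, rng)
    if residual:
        return [v + d for v, d in zip(vx, dropped)]
    return dropped


# In evaluation mode dropout is the identity, so only `residual` matters.
print(next_vn_emb([1.0, 2.0], [0.5, 0.5], residual=True, training=False))   # [1.5, 2.5]
print(next_vn_emb([1.0, 2.0], [0.5, 0.5], residual=False, training=False))  # [0.5, 0.5]
```

With residual=True the old embedding survives even when dropout zeroes part of the update, which tends to make deep stacks easier to train.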

update_node_emb(x, edge_index, batch, vx=None)

Add messages from the virtual nodes to the graph nodes.

Parameters:

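| Name | Type | Description | Default |
|---|---|---|---|
| x | torch.Tensor | The input node features. | required |
| edge_index | torch.LongTensor | Graph connectivity in COO format, shape [2, num_edges]. | required |
| batch | torch.LongTensor | Batch vector, which assigns each node to a specific example. | required |
| vx | torch.Tensor | Optional virtual node embedding, typically the one returned by update_vn_emb. | None |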
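The batch argument follows the usual PyG convention: the graphs of a mini-batch are stacked into one large disconnected graph, and batch[i] is the index of the graph that node i came from. A small pure-Python sketch of what that vector looks like for graphs stored one after another; `make_batch_vector` is a hypothetical helper, since PyG builds this vector for you when it batches `Data` objects.

```python
def make_batch_vector(graph_sizes):
    """Return the batch vector for graphs whose nodes are concatenated in order.

    Entry i is the index of the graph that node i belongs to.
    """
    batch = []
    for graph_id, size in enumerate(graph_sizes):
        batch.extend([graph_id] * size)
    return batch


print(make_batch_vector([2, 3, 1]))  # [0, 0, 1, 1, 1, 2]
```

The number of graphs, and therefore the number of virtual nodes, is max(batch) + 1.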

Returns:

| Type | Description |
|---|---|
| torch.Tensor | The output node features. |
| torch.Tensor | The output virtual node embedding, to be passed to update_vn_emb. |
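A framework-free sketch of this contract, assuming the OGB-style update in which each node simply adds its own graph's virtual node embedding, and that a missing vx is replaced by one shared initial vector per graph (the OGB example uses a learned embedding initialized to zero). edge_index is left out because the virtual node message does not depend on the graph's edges. The function and its `init` argument are hypothetical.

```python
def update_node_emb_sketch(x, batch, vx=None, init=None):
    """Hypothetical sketch of virtual node -> graph nodes on nested lists.

    x:     node features, num_nodes x in_feats
    batch: graph index of every node
    vx:    one virtual node embedding per graph, or None
    init:  shared initial virtual node vector used when vx is None (zeros if omitted)
    Returns (h, vx), mirroring the two documented return values.
    """
    if vx is None:
        num_graphs = max(batch) + 1  # graphs are numbered 0..num_graphs-1
        start = init if init is not None else [0.0] * len(x[0])
        vx = [list(start) for _ in range(num_graphs)]
    # h_i = x_i + vx[batch[i]]: every node receives its own graph's virtual node.
    h = [[a + b for a, b in zip(xi, vx[g])] for xi, g in zip(x, batch)]
    return h, vx


x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
batch = [0, 0, 1]
h, vx = update_node_emb_sketch(x, batch)  # first layer: no vx yet, so h equals x
h2, _ = update_node_emb_sketch(x, batch, vx=[[1.0, 1.0], [-1.0, 0.5]])
print(h2)  # [[2.0, 1.0], [1.0, 2.0], [1.0, 2.5]]
```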

update_vn_emb(x, batch, vx)

Add messages from the graph nodes to the virtual node.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | torch.Tensor | The input node features. | required |
| batch | torch.LongTensor | Batch vector, which assigns each node to a specific example. | required |
| vx | torch.Tensor | The virtual node embedding, typically the one returned by update_node_emb. | required |

Returns:

| Type | Description |
|---|---|
| torch.Tensor | The output virtual node embedding. |
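update_vn_emb goes the other way: each graph's node features are aggregated into that graph's virtual node. Below is a framework-free sketch, assuming OGB-style global sum pooling plus the previous virtual node state. It also assumes that, because x here comes after the conv layer (out_feats columns) while vx was built to match the input (in_feats columns), vx is first mapped to out_feats when the two sizes differ. The projection `W` and the helper names are hypothetical, and the learned transform, dropout and residual steps are omitted.

```python
def matvec(W, v):
    # Plain matrix-vector product: W has out_feats rows of length in_feats.
    return [sum(w * a for w, a in zip(row, v)) for row in W]


def update_vn_emb_sketch(x, batch, vx, W=None):
    """Hypothetical sketch of graph nodes -> virtual node.

    x:     post-conv node features, num_nodes x out_feats
    batch: graph index of every node
    vx:    virtual node embeddings, num_graphs x in_feats
    W:     optional out_feats x in_feats projection (assumed, for in_feats != out_feats)
    """
    if W is not None:
        vx = [matvec(W, v) for v in vx]  # bring vx to out_feats
    out_dim = len(x[0])
    if any(len(v) != out_dim for v in vx):
        raise ValueError("vx must have out_feats columns; pass a projection W")
    pooled = [[0.0] * out_dim for _ in vx]  # one accumulator per graph
    for xi, g in zip(x, batch):             # global sum pooling by graph id
        pooled[g] = [p + a for p, a in zip(pooled[g], xi)]
    # New virtual node state: pooled node features plus the (projected) old state.
    return [[p + v for p, v in zip(pg, vg)] for pg, vg in zip(pooled, vx)]


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # out_feats = 2
batch = [0, 0, 1]
vx = [[1.0], [2.0]]                        # in_feats = 1
W = [[1.0], [0.0]]                         # hypothetical 2 x 1 projection
print(update_vn_emb_sketch(x, batch, vx, W))  # [[5.0, 6.0], [7.0, 6.0]]
```

Sum pooling (rather than mean) lets the virtual node state reflect graph size, which is the choice made in the OGB example.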