VirtualNode
gtrick.pyg.VirtualNode(in_feats, out_feats, dropout=0.5, residual=False)
Bases: nn.Module
Virtual node trick from the OGB graph property prediction examples.
It adds a virtual node that is connected to all nodes in the graph. This trick is helpful for graph-level tasks.
Note
To use this trick, call update_node_emb first, then call update_vn_emb.
Examples:
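A minimal sketch of the call order from the note above, written in plain PyTorch. `ToyVirtualNode` is a hypothetical stand-in that only mirrors the documented method names and signatures. Its internals (one shared learnable embedding, sum pooling, a linear projection) are assumptions for illustration, not gtrick's implementation, and `nn.Linear` is a placeholder for a real graph conv layer.

```python
import torch
from torch import nn


class ToyVirtualNode(nn.Module):
    """Hypothetical stand-in with the documented method names; not gtrick's code."""

    def __init__(self, in_feats, out_feats):
        super().__init__()
        # Assumption: one learnable embedding shared by every graph's virtual node.
        self.vn_emb = nn.Embedding(1, in_feats)
        self.proj = nn.Linear(in_feats, out_feats)

    def update_node_emb(self, x, edge_index, batch, vx=None):
        # edge_index is unused here; it is accepted only to match the documented signature.
        if vx is None:
            num_graphs = int(batch.max()) + 1
            idx = torch.zeros(num_graphs, dtype=torch.long, device=x.device)
            vx = self.vn_emb(idx)
        # Each node receives the embedding of its own graph's virtual node.
        return x + vx[batch], vx

    def update_vn_emb(self, x, batch, vx):
        # Sum node features per graph, then add the projected virtual node state.
        pooled = torch.zeros(vx.size(0), x.size(1), dtype=x.dtype, device=x.device)
        pooled = pooled.index_add(0, batch, x)
        return pooled + self.proj(vx)


torch.manual_seed(0)
# Two graphs in one batch: nodes 0-2 belong to graph 0, nodes 3-4 to graph 1.
x = torch.randn(5, 8)
edge_index = torch.tensor([[0, 1, 3], [1, 2, 4]])
batch = torch.tensor([0, 0, 0, 1, 1])

vn = ToyVirtualNode(in_feats=8, out_feats=16)
conv = nn.Linear(8, 16)  # placeholder for a real graph conv layer

h, vx = vn.update_node_emb(x, edge_index, batch)  # 1. virtual node -> graph nodes
h = conv(h)                                       # 2. the usual conv layer
vx = vn.update_vn_emb(h, batch, vx)               # 3. graph nodes -> virtual node
```

With several layers, the `vx` returned in step 3 would presumably be passed to the next layer's `update_node_emb` through its `vx` argument, so that the default `vx=None` is only used for the first layer.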
Parameters:
Name | Type | Description | Default |
---|---|---|---|
in_feats | int | Feature size before conv layer. | required |
out_feats | int | Feature size after conv layer. | required |
dropout | float | Dropout rate on virtual node embedding. | 0.5 |
residual | bool | If True, use residual connection. | False |
update_node_emb(x, edge_index, batch, vx=None)
Add message from virtual nodes to graph nodes.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The input node feature. | required |
edge_index | torch.LongTensor | Graph connectivity. | required |
batch | torch.LongTensor | Batch vector, which assigns each node to a specific example. | required |
vx | torch.Tensor | Optional virtual node embedding. | None |
Returns:
Type | Description |
---|---|
torch.Tensor | The output node feature. |
torch.Tensor | The output virtual node embedding. |
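To make the role of the batch vector concrete, here is a small worked example in plain PyTorch. It assumes the message from the virtual node is simply added to each node feature. That additive form is an assumption for illustration and is not stated by this page.

```python
import torch

# Three nodes in graph 0, two nodes in graph 1.
batch = torch.tensor([0, 0, 0, 1, 1])
x = torch.zeros(5, 2)
# One virtual node embedding per graph (values chosen for illustration).
vx = torch.tensor([[1.0, 1.0], [5.0, 5.0]])

# vx[batch] repeats each graph's embedding once per node in that graph,
# so every node is paired with the virtual node of its own graph.
h = x + vx[batch]
```

Here `h` has one row per node: the first three rows carry graph 0's embedding and the last two carry graph 1's.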
update_vn_emb(x, batch, vx)
Add message from graph nodes to virtual node.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
x | torch.Tensor | The input node feature. | required |
batch | torch.LongTensor | Batch vector, which assigns each node to a specific example. | required |
vx | torch.Tensor | Virtual node embedding. | required |
Returns:
Type | Description |
---|---|
torch.Tensor | The output virtual node embedding. |
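The reverse direction can be sketched the same way. The example below assumes the node messages are aggregated by a per-graph sum and added to the current virtual node embedding. The aggregation choice is an assumption, and the dropout and residual options documented for the class are left out.

```python
import torch

# Three nodes in graph 0, two nodes in graph 1.
batch = torch.tensor([0, 0, 0, 1, 1])
x = torch.tensor([[1.0], [2.0], [3.0], [10.0], [20.0]])
vx = torch.zeros(2, 1)  # current virtual node embedding, one row per graph

# Sum the node features of each graph into that graph's row.
pooled = torch.zeros_like(vx).index_add(0, batch, x)
new_vx = vx + pooled
```

Each row of `new_vx` now summarizes one whole graph, which is the information a virtual node can hand back to the nodes in the next layer.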