LabelPropagation
gtrick.dgl.LabelPropagation(num_layers, alpha)
Bases: nn.Module
The label propagation operator from the "Learning from Labeled and Unlabeled Data with Label Propagation" paper.
This trick is helpful for node-level tasks.
\[
\mathbf{Y}^{\prime} = \alpha \cdot \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{Y} + (1 - \alpha) \mathbf{Y},
\]
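where \(\mathbf{A}\) is the adjacency matrix, \(\mathbf{D}\) is the diagonal degree matrix, and the labels of unlabeled nodes are inferred from the labeled nodes via propagation.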
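To make the update rule concrete, here is a minimal dense NumPy sketch of a single application of the formula above. It assumes a small dense adjacency matrix instead of a dgl.DGLGraph, and the helper names normalized_adjacency and propagation_step are hypothetical, not part of gtrick.

```python
import numpy as np

def normalized_adjacency(adj: np.ndarray) -> np.ndarray:
    """Symmetrically normalize a dense adjacency matrix: D^{-1/2} A D^{-1/2}."""
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros_like(deg)
    nonzero = deg > 0  # isolated nodes keep a zero scale instead of dividing by zero
    inv_sqrt[nonzero] = deg[nonzero] ** -0.5
    return inv_sqrt[:, None] * adj * inv_sqrt[None, :]

def propagation_step(adj: np.ndarray, y: np.ndarray, alpha: float) -> np.ndarray:
    """One application of Y' = alpha * D^{-1/2} A D^{-1/2} Y + (1 - alpha) * Y."""
    return alpha * normalized_adjacency(adj) @ y + (1 - alpha) * y

# Path graph 0 - 1 - 2: node 0 is labeled class 0, node 2 class 1, node 1 is unlabeled.
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
Y = np.array([[1, 0], [0, 0], [0, 1]], dtype=float)
Y_next = propagation_step(A, Y, alpha=0.9)
# Node 1 receives equal mass from both neighbours: Y_next[1] is about [0.636, 0.636]
```

The unlabeled middle node ends up with equal scores for both classes because it is symmetric with respect to the two labeled endpoints.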
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| num_layers | int | The number of propagations. | required |
| alpha | float | The \(\alpha\) coefficient. | required |
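The sketch below illustrates what num_layers controls: each extra propagation step lets label information travel one more hop. It is a dense NumPy approximation, not gtrick's implementation; the function label_propagation is hypothetical, and it assumes (following the PyTorch Geometric formulation of this operator) that the \((1 - \alpha)\) residual term always uses the initial labels. It also omits the post_step and edge_weight arguments of forward.

```python
import numpy as np

def label_propagation(adj, y, num_layers, alpha):
    """Dense sketch: repeat Y <- alpha * S Y + (1 - alpha) * Y0 for num_layers steps."""
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros_like(deg)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    norm_adj = inv_sqrt[:, None] * adj * inv_sqrt[None, :]  # S = D^{-1/2} A D^{-1/2}
    residual = (1 - alpha) * y  # assumption: the residual term uses the initial labels Y0
    out = y.copy()
    for _ in range(num_layers):
        out = alpha * norm_adj @ out + residual
    return out

# Path graph 0 - 1 - 2 - 3 with only node 0 labeled (a single class column).
A = np.array([[0, 1, 0, 0],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [0, 0, 1, 0]], dtype=float)
Y = np.array([[1.0], [0.0], [0.0], [0.0]])

one_step = label_propagation(A, Y, num_layers=1, alpha=0.9)
three_steps = label_propagation(A, Y, num_layers=3, alpha=0.9)
# Node 3 is three hops from node 0, so it only receives label mass once num_layers >= 3.
```

A larger alpha puts more weight on the propagated neighbourhood signal, while a smaller one keeps predictions closer to the original labels.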
forward(graph, y, mask=None, edge_weight=None, post_step=lambda y: y.clamp_(0.0, 1.0))
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| graph | dgl.DGLGraph | The input graph. | required |
| y | torch.Tensor | The ground-truth label information of training nodes. | required |
| mask | torch.LongTensor or torch.BoolTensor | A mask or index tensor denoting which nodes were used for training. | None |
| edge_weight | torch.Tensor | The edge weights. | None |
| post_step | Callable[[torch.Tensor], torch.Tensor] | The post-processing function (by default, clamps values to [0, 1]). | lambda y: y.clamp_(0.0, 1.0) |
Returns:

| Type | Description |
|---|---|
| torch.Tensor | The obtained prediction. |
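The mask and post_step arguments can be illustrated with a small NumPy sketch. It assumes y holds one row per node (for example one-hot labels) and that only the rows selected by mask are kept as seed labels; the helpers initial_labels and clamp_post_step are hypothetical and only mimic the documented semantics, not gtrick's internal code.

```python
import numpy as np

def initial_labels(y, mask=None):
    # Hypothetical helper: keep the label rows of training nodes, zero out the rest.
    if mask is None:
        return y.copy()
    out = np.zeros_like(y)
    out[mask] = y[mask]  # fancy indexing accepts a boolean mask or an index array
    return out

def clamp_post_step(y):
    # NumPy analogue of the default post_step, y.clamp_(0.0, 1.0); modifies y in place.
    np.clip(y, 0.0, 1.0, out=y)
    return y

# One-hot labels for 4 nodes and 2 classes; nodes 0 and 3 are training nodes.
y = np.eye(2)[[0, 1, 1, 1]]
bool_mask = np.array([True, False, False, True])
index_mask = np.array([0, 3])

y0 = initial_labels(y, bool_mask)  # same result as initial_labels(y, index_mask)

scores = np.array([[1.2, -0.1], [0.5, 0.4]])
clamp_post_step(scores)  # scores is now [[1.0, 0.0], [0.5, 0.4]]
```

Both mask forms select the same rows, which is why the parameter can be either a boolean mask or an index tensor.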