CorrectAndSmooth
gtrick.dgl.CorrectAndSmooth(num_correction_layers, correction_alpha, num_smoothing_layers, smoothing_alpha, autoscale=True, scale=1.0)
Bases: nn.Module
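This module implements the correct and smooth (C&S) post-processing model from the "Combining Label Propagation And Simple Models Out-performs Graph Neural Networks" paper. Soft predictions \(\mathbf{Z}\) obtained from a simple base predictor are first corrected using the ground-truth training labels \(\mathbf{Y}\) via residual propagation
\[
\mathbf{e}^{(0)}_i = \begin{cases} \mathbf{y}_i - \mathbf{z}_i, & \text{if } i \text{ is a training node,} \\ \mathbf{0}, & \text{otherwise,} \end{cases}
\]
\[
\mathbf{E}^{(\ell)} = \alpha_1 \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{E}^{(\ell - 1)} + (1 - \alpha_1) \mathbf{E}^{(0)}, \qquad \mathbf{\hat{Z}} = \mathbf{Z} + \gamma \cdot \mathbf{E}^{(L_1)},
\]
where \(\mathbf{A}\) is the adjacency matrix, \(\mathbf{D}\) its diagonal degree matrix, and \(\gamma\) the scaling factor (either fixed or automatically determined). The corrected predictions are then smoothed over the graph via label propagation
\[
\mathbf{\hat{z}}^{(0)}_i = \begin{cases} \mathbf{y}_i, & \text{if } i \text{ is a training node,} \\ \mathbf{\hat{z}}_i, & \text{otherwise,} \end{cases}
\]
\[
\mathbf{\hat{Z}}^{(\ell)} = \alpha_2 \mathbf{D}^{-1/2} \mathbf{A} \mathbf{D}^{-1/2} \mathbf{\hat{Z}}^{(\ell - 1)} + (1 - \alpha_2) \mathbf{\hat{Z}}^{(0)}
\]
to obtain the final prediction \(\mathbf{\hat{Z}}^{(L_2)}\).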
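Both stages run the same residual propagation \(\mathbf{X}^{(\ell)} = \alpha \mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2}\mathbf{X}^{(\ell-1)} + (1-\alpha)\mathbf{X}^{(0)}\), with \((\alpha_1, L_1)\) for correction and \((\alpha_2, L_2)\) for smoothing. The following is a minimal NumPy sketch of that iteration on a dense adjacency matrix, assuming an undirected graph. gtrick itself operates on a dgl.DGLGraph; the helper names normalized_adjacency and residual_propagate are hypothetical and not part of its API.

```python
import numpy as np


def normalized_adjacency(adj: np.ndarray) -> np.ndarray:
    """Return D^{-1/2} A D^{-1/2} for a dense, symmetric adjacency matrix A."""
    deg = adj.sum(axis=1)
    inv_sqrt = np.zeros(deg.shape, dtype=float)
    nonzero = deg > 0
    # Isolated nodes (degree 0) keep a zero factor instead of dividing by zero.
    inv_sqrt[nonzero] = deg[nonzero] ** -0.5
    return inv_sqrt[:, None] * adj * inv_sqrt[None, :]


def residual_propagate(s: np.ndarray, x0: np.ndarray, alpha: float, num_layers: int) -> np.ndarray:
    """Iterate X^(l) = alpha * S X^(l-1) + (1 - alpha) * X^(0) for num_layers steps."""
    x = x0.copy()
    for _ in range(num_layers):
        x = alpha * (s @ x) + (1.0 - alpha) * x0
    return x


# Path graph 0 - 1 - 2 - 3 with a signal only on node 0.
adj = np.array([[0, 1, 0, 0],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [0, 0, 1, 0]], dtype=float)
s = normalized_adjacency(adj)
x0 = np.array([[1.0], [0.0], [0.0], [0.0]])
out = residual_propagate(s, x0, alpha=0.5, num_layers=2)
print(out.ravel().round(3))  # the signal decays with distance from node 0
```

Setting alpha to 0 returns \(\mathbf{X}^{(0)}\) unchanged, while a larger alpha spreads more of the signal to neighbours.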
This trick is useful for node-level tasks.
Note
To use this trick, call correct() first, then call smooth() on its output.
Examples: CorrectAndSmooth (DGL)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
num_correction_layers | int | The number of correction propagations \(L_1\). | required |
correction_alpha | float | The propagation coefficient \(\alpha_1\) of the correction step. | required |
num_smoothing_layers | int | The number of smoothing propagations \(L_2\). | required |
smoothing_alpha | float | The propagation coefficient \(\alpha_2\) of the smoothing step. | required |
autoscale | bool | If set to True, the scaling factor \(\gamma\) is determined automatically. | True |
scale | float | The scaling factor \(\gamma\), used when autoscale is False. | 1.0 |
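One common autoscale rule (used, for instance, by PyTorch Geometric's CorrectAndSmooth) picks \(\gamma\) per node so that every propagated error row has the same L1 norm as the average training residual: \(\gamma_i = \sigma / \lVert \mathbf{e}^{(L_1)}_i \rVert_1\), where \(\sigma\) is the mean of \(\lVert \mathbf{e}^{(0)}_j \rVert_1\) over training nodes \(j\). This page does not state how gtrick computes it, so the sketch below is an assumption modeled on that rule, including a fallback to 1 for infinite or very large factors; autoscale_gamma is a hypothetical helper.

```python
import numpy as np


def autoscale_gamma(error0: np.ndarray, smoothed_error: np.ndarray, train_idx: np.ndarray) -> np.ndarray:
    """Per-node factor gamma_i = sigma / ||E_i^(L1)||_1 (assumed autoscale rule)."""
    # sigma: mean L1 norm of the initial residual over training nodes.
    sigma = np.abs(error0[train_idx]).sum(axis=1).mean()
    norms = np.abs(smoothed_error).sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        gamma = sigma / norms
        # Zero-norm rows give inf/nan; very large factors are also reset to 1.
        gamma[~np.isfinite(gamma) | (gamma > 1000)] = 1.0
    return gamma


error0 = np.array([[0.4, -0.4], [0.0, 0.0], [-0.2, 0.2]])           # E^(0); training nodes 0 and 2
smoothed_error = np.array([[0.3, -0.3], [0.1, -0.05], [0.0, 0.0]])  # E^(L1) after propagation
y_soft = np.array([[0.6, 0.4], [0.5, 0.5], [0.3, 0.7]])             # base predictions Z
gamma = autoscale_gamma(error0, smoothed_error, train_idx=np.array([0, 2]))
y_corrected = y_soft + gamma * smoothed_error                       # Z + gamma * E^(L1), row-wise
```

Normalizing each row this way keeps the size of the correction comparable to the errors actually observed on training nodes, rather than depending on a hand-tuned constant.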
correct(graph, y_soft, y_true, mask, edge_weight=None)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph | dgl.DGLGraph | The graph. | required |
y_soft | Tensor | The soft predictions \(\mathbf{Z}\) obtained from a simple base predictor. | required |
y_true | Tensor | The ground-truth label information \(\mathbf{Y}\) of training nodes. | required |
mask | LongTensor or BoolTensor | A mask or index tensor denoting which nodes were used for training. | required |
edge_weight | Tensor | The edge weights. | None |
Returns:
Type | Description |
---|---|
torch.Tensor | The corrected prediction. |
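The correction step can be sketched in a few lines. This dense NumPy version is hypothetical (correct_sketch is not the gtrick method): it takes the normalized adjacency s directly, assumes y_true holds integer labels for every node with only the masked rows read, uses a fixed scale (autoscale off), and accepts either a boolean mask or an index array, mirroring the mask parameter above.

```python
import numpy as np


def correct_sketch(s, y_soft, y_true, mask, alpha, num_layers, scale=1.0):
    """Hypothetical dense sketch of the correction step with a fixed scale gamma.

    s: normalized adjacency D^{-1/2} A D^{-1/2}, shape [N, N].
    y_soft: base predictions Z, shape [N, C].
    y_true: integer labels for all N nodes (only rows selected by mask are read).
    mask: boolean array of length N, or an integer index array.
    """
    num_classes = y_soft.shape[1]
    # Normalize both mask styles to an index array.
    idx = np.flatnonzero(mask) if mask.dtype == bool else np.asarray(mask)
    # E^(0): residual Y - Z on training nodes, zero elsewhere.
    error0 = np.zeros_like(y_soft)
    error0[idx] = np.eye(num_classes)[y_true[idx]] - y_soft[idx]
    # Residual propagation for L_1 layers.
    error = error0.copy()
    for _ in range(num_layers):
        error = alpha * (s @ error) + (1.0 - alpha) * error0
    return y_soft + scale * error


# 4-cycle 0-1-2-3-0: every node has degree 2, so D^{-1/2} A D^{-1/2} = A / 2.
adj = np.array([[0, 1, 0, 1],
                [1, 0, 1, 0],
                [0, 1, 0, 1],
                [1, 0, 1, 0]], dtype=float)
s = adj / 2
y_soft = np.full((4, 2), 0.5)
y_true = np.array([0, 0, 1, 1])
bool_mask = np.array([True, False, False, True])
index_mask = np.array([0, 3])
out_bool = correct_sketch(s, y_soft, y_true, bool_mask, alpha=0.8, num_layers=3)
out_index = correct_sketch(s, y_soft, y_true, index_mask, alpha=0.8, num_layers=3)
```

With num_layers=0 and scale=1.0, training rows become exactly the one-hot labels (\(\mathbf{Z} + (\mathbf{Y} - \mathbf{Z}) = \mathbf{Y}\)) and all other rows stay unchanged; propagation is what carries the training residual to unlabeled neighbours.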
smooth(graph, y_soft, y_true, mask, edge_weight=None)
Parameters:
Name | Type | Description | Default |
---|---|---|---|
graph | dgl.DGLGraph | The graph. | required |
y_soft | Tensor | The corrected predictions \(\mathbf{\hat{Z}}\) returned by correct(). | required |
y_true | Tensor | The ground-truth label information \(\mathbf{Y}\) of training nodes. | required |
mask | LongTensor or BoolTensor | A mask or index tensor denoting which nodes were used for training. | required |
edge_weight | Tensor | The edge weights. | None |
Returns:
Type | Description |
---|---|
torch.Tensor | The final prediction. |
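The smoothing step can be sketched the same way. In this hypothetical dense version (smooth_sketch is not the gtrick method), training rows are reset to their one-hot labels before label propagation. Assumptions: s is the precomputed normalized adjacency, y_corrected is the output of the correction step, y_true holds integer labels for every node (only masked rows are read), and mask may be boolean or an index array.

```python
import numpy as np


def smooth_sketch(s, y_corrected, y_true, mask, alpha, num_layers):
    """Hypothetical dense sketch of the smoothing step (label propagation).

    s: normalized adjacency D^{-1/2} A D^{-1/2}, shape [N, N].
    y_corrected: corrected predictions Z_hat from the correction step, shape [N, C].
    y_true: integer labels for all N nodes (only rows selected by mask are read).
    mask: boolean array of length N, or an integer index array.
    """
    num_classes = y_corrected.shape[1]
    idx = np.flatnonzero(mask) if mask.dtype == bool else np.asarray(mask)
    # Z_hat^(0): one-hot labels on training nodes, corrected predictions elsewhere.
    z0 = y_corrected.copy()
    z0[idx] = np.eye(num_classes)[y_true[idx]]
    # Label propagation for L_2 layers.
    z = z0.copy()
    for _ in range(num_layers):
        z = alpha * (s @ z) + (1.0 - alpha) * z0
    return z


# 6-cycle: every node has degree 2, so D^{-1/2} A D^{-1/2} = A / 2.
n = 6
adj = np.zeros((n, n))
for i in range(n):
    adj[i, (i + 1) % n] = adj[(i + 1) % n, i] = 1.0
s = adj / 2
y_corrected = np.full((n, 2), 0.5)       # uninformative predictions
y_true = np.array([0, 0, 1, 1, 1, 0])    # only rows 0 and 3 are read
train_mask = np.array([True, False, False, True, False, False])
y_final = smooth_sketch(s, y_corrected, y_true, train_mask, alpha=0.8, num_layers=1)
print(y_final.argmax(axis=1))  # → [0 0 1 1 1 0]
```

On the 6-cycle, the two unlabeled neighbours of each training node adopt its class after a single propagation step.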