FLAG
gtrick.FLAG(emb_dim, loss_func, optimizer, m=3, step_size=0.001, mag=-1)
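FLAG (Free Large-scale Adversarial Augmentation on Graphs) is an adversarial data augmentation method for graph neural networks, introduced in the paper Robust Optimization as Data Augmentation for Large-scale Graphs. During training, it adds a small perturbation to the node features, runs m forward/backward passes on the same minibatch while updating the perturbation by gradient ascent between passes, and accumulates the model gradients from every pass before a single optimizer update.
The trick can be used for both node-level and graph-level tasks.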
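The procedure described above can be sketched without any deep learning framework. The block below is a minimal, hypothetical illustration, not gtrick's implementation: it uses a toy linear regression model `pred_i = (x_i + delta_i) . w` with mean squared error so that gradients can be written by hand. The perturbation starts uniform in [-step_size, step_size], each of the m passes adds 1/m of its weight gradient to an accumulator, the perturbation moves along the sign of its gradient between passes (a sign-gradient ascent step, assumed here), and one plain SGD update is applied at the end. The names `flag_step` and `lr` (standing in for the optimizer) are assumptions.

```python
import random


def sign(v):
    """Scalar sign: 1, -1, or 0 (same convention as torch.sign)."""
    return (v > 0) - (v < 0)


def flag_step(w, x, y, m=3, step_size=1e-3, lr=0.1, rng=None):
    """One FLAG-style step for a toy linear model pred_i = (x_i + delta_i) . w
    trained with mean squared error. Returns (updated weights, loss of the last pass)."""
    rng = rng or random.Random(0)
    n, d = len(x), len(w)
    # As in the mag <= 0 case: start from a uniform perturbation in [-step_size, step_size].
    delta = [[rng.uniform(-step_size, step_size) for _ in range(d)] for _ in range(n)]
    grad_w = [0.0] * d  # weight gradient accumulated across all m passes
    loss = 0.0
    for step in range(m):
        xp = [[x[i][j] + delta[i][j] for j in range(d)] for i in range(n)]
        resid = [sum(xp[i][j] * w[j] for j in range(d)) - y[i] for i in range(n)]
        loss = sum(r * r for r in resid) / n
        # Accumulate the gradient of loss / m with respect to the weights.
        for j in range(d):
            grad_w[j] += sum(2 * resid[i] * xp[i][j] for i in range(n)) / (n * m)
        if step < m - 1:
            # Ascent on the perturbation: d loss / d delta_ij = 2 * resid_i * w_j / n,
            # moved by step_size along the sign of that gradient.
            delta = [[delta[i][j] + step_size * sign(2 * resid[i] * w[j] / n)
                      for j in range(d)] for i in range(n)]
    # A single optimizer (plain SGD) update after all m passes.
    new_w = [w[j] - lr * grad_w[j] for j in range(d)]
    return new_w, loss
```

Calling it repeatedly, for example `w, loss = flag_step(w, [[1.0], [2.0]], [2.0, 4.0])` fifty times starting from `w = [0.0]`, drives `w` toward 2 while every update is computed on slightly adversarial inputs.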
Parameters:

Name | Type | Description | Default |
---|---|---|---|
emb_dim | int | Dimension of the node features. | required |
loss_func | torch.nn.Module | Loss function. | required |
optimizer | torch.optim.Optimizer | Optimizer used to update the model. | required |
m | int | Number of ascent steps; the same minibatch is trained m times. | 3 |
step_size | float | Ascent step size. If mag <= 0, the perturbation is initialized from the uniform distribution on [-step_size, step_size]. | 0.001 |
mag | float | If mag > 0, the maximum norm of the perturbation. | -1 |
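The interaction between step_size and mag can be illustrated with two small helpers. This is a sketch under assumptions: the mag <= 0 branch follows the table (uniform on [-step_size, step_size]), while for mag > 0 it assumes a uniform draw on [-mag/sqrt(emb_dim), mag/sqrt(emb_dim)], which keeps every row's L2 norm at most mag, and a per-row L2 projection back onto the mag-ball after each ascent step. The helper names `init_perturb` and `project_rows` are hypothetical, and gtrick may normalize differently.

```python
import math
import random


def init_perturb(num_nodes, emb_dim, step_size=1e-3, mag=-1.0, rng=None):
    """Sample an initial perturbation of shape (num_nodes, emb_dim) as nested lists."""
    rng = rng or random.Random(0)
    if mag > 0:
        # Assumed scheme: entries in [-mag/sqrt(d), mag/sqrt(d)], so each row norm <= mag.
        bound = mag / math.sqrt(emb_dim)
    else:
        # Documented scheme: uniform in [-step_size, step_size].
        bound = step_size
    return [[rng.uniform(-bound, bound) for _ in range(emb_dim)] for _ in range(num_nodes)]


def project_rows(perturb, mag):
    """Rescale any row whose L2 norm exceeds mag down to norm mag (no-op if mag <= 0)."""
    if mag <= 0:
        return perturb
    out = []
    for row in perturb:
        norm = math.sqrt(sum(v * v for v in row))
        scale = mag / norm if norm > mag else 1.0
        out.append([v * scale for v in row])
    return out
```

With the default mag = -1 only step_size matters; setting a positive mag bounds the size of each node's perturbation, while step_size still sets how far each ascent step moves it.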
__call__(model, forward, num_nodes, y)
Parameters:

Name | Type | Description | Default |
---|---|---|---|
model | torch.nn.Module | The model. | required |
forward | Callable[[torch.Tensor], torch.Tensor] | A function that takes the perturbation as input and returns the model output. | required |
num_nodes | int | The number of nodes. | required |
y | torch.Tensor | The ground-truth labels. | required |
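The forward argument is a one-argument callable: given emb_dim and num_nodes, it presumably receives a perturbation of shape (num_nodes, emb_dim) and returns the model output, so the caller decides where the perturbation enters the model. A common pattern is a closure that adds it to the input node features. The framework-free sketch below assumes a hypothetical one-layer mean-aggregation model over plain lists; `make_forward`, `neighbors`, and `weight` are illustrative names, not part of gtrick.

```python
def make_forward(features, neighbors, weight):
    """Build a forward(perturb) closure for a toy, hypothetical GNN layer:
    each node averages (feature + perturb) over itself and its neighbors,
    then projects the mean onto a weight vector to get one score per node."""
    def forward(perturb):
        # Add the perturbation to the input node features.
        h = [[f + p for f, p in zip(features[i], perturb[i])] for i in range(len(features))]
        out = []
        for i, nbrs in enumerate(neighbors):
            group = [i] + list(nbrs)
            mean = [sum(h[k][j] for k in group) / len(group) for j in range(len(weight))]
            out.append(sum(mj * wj for mj, wj in zip(mean, weight)))
        return out
    return forward
```

Because the closure captures the graph and the parameters, the caller only has to supply the perturbation, which matches the single-argument forward signature in the table.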
Returns:

Type | Description |
---|---|
torch.Tensor | The loss. |
torch.Tensor | The output of the model. |