FLAG

gtrick.FLAG(emb_dim, loss_func, optimizer, m=3, step_size=0.001, mag=-1)

FLAG is an adversarial data augmentation method for Graph Neural Networks, introduced in *Robust Optimization as Data Augmentation for Large-scale Graphs*.

This trick is helpful for node-level and graph-level tasks.

Example

FLAG (DGL), FLAG (PyG)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| emb_dim | int | Node feature dim. | required |
| loss_func | torch.nn.Module | Loss function. | required |
| optimizer | torch.optim.Optimizer | Optimizer. | required |
| m | int | Ascent steps. Train the same minibatch m times. | 3 |
| step_size | float | Ascent step size. If mag <= 0, perturb is initialized from a uniform distribution [-step_size, step_size]. | 0.001 |
| mag | float | If mag > 0, it controls the max norm of perturb. | -1 |
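
A minimal construction sketch, assuming FLAG is importable as shown in the signature above (the linear layer is only a stand-in for a real GNN and is not part of gtrick):

```python
import torch
from gtrick import FLAG  # import path assumed from the gtrick.FLAG signature above

emb_dim, num_classes = 128, 40
model = torch.nn.Linear(emb_dim, num_classes)  # stand-in for a real GNN
loss_func = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# m ascent steps per minibatch; with mag <= 0 the perturbation is initialized
# uniformly in [-step_size, step_size], otherwise mag caps its norm.
flag = FLAG(emb_dim, loss_func, optimizer, m=3, step_size=1e-3, mag=-1)
```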

__call__(model, forward, num_nodes, y)

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | torch.nn.Module | The model. | required |
| forward | Callable[[torch.Tensor], torch.Tensor] | The function that takes the perturbation as input and returns the model output. | required |
| num_nodes | int | The number of nodes. | required |
| y | torch.Tensor | The ground truth label. | required |

Returns:

| Type | Description |
| --- | --- |
| torch.Tensor | The loss. |
| torch.Tensor | The output of the model. |
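
A hedged usage sketch, continuing the construction example above (x and y are synthetic placeholders; a real forward callable would typically also pass the graph structure, e.g. edge_index, to the model):

```python
num_nodes = 1000
x = torch.randn(num_nodes, emb_dim)              # placeholder node features
y = torch.randint(0, num_classes, (num_nodes,))  # placeholder labels

# forward takes the node-feature perturbation and returns the model output;
# adding the perturbation to the input features is the usual FLAG setup.
forward = lambda perturb: model(x + perturb)

# Runs the m-step adversarial ascent over this minibatch and returns the
# final loss together with the model output.
loss, out = flag(model, forward, num_nodes, y)
print(loss.item(), out.shape)
```

Since the optimizer is handed to FLAG at construction time, the adversarial ascent, backward pass, and parameter update are expected to happen inside the call itself, so no separate loss.backward() or optimizer.step() appears in the sketch.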