Tutorial: Training a Node Classification Model (PyG) on Your Local Machine#
This tutorial presents an end-to-end example that illustrates how GraphScope trains the GraphSAGE model (implemented in PyG) for a node classification task.
Load Graph#
import time
import torch
import torch.nn.functional as F
from ogb.nodeproppred import Evaluator
from torch_geometric.nn import GraphSAGE
import graphscope as gs
import graphscope.learning.graphlearn_torch as glt
from graphscope.dataset import load_ogbn_arxiv
from graphscope.learning.graphlearn_torch.typing import Split
# Turn on GraphScope log output (show_log=True) for the rest of the session.
gs.set_option(show_log=True)
# Load the ogbn_arxiv graph as an example; `g` is the graph handed to the
# learning engine later in this tutorial.
g = load_ogbn_arxiv()
Define the evaluation function#
@torch.no_grad()
def test(model, test_loader, dataset_name):
    """Run ``model`` over every batch of ``test_loader`` and return its accuracy.

    Args:
        model: Module invoked as ``model(x, edge_index)``. It is switched to
            eval mode here and is still in eval mode when this returns.
        test_loader: Iterable of mini-batches exposing ``x``, ``edge_index``,
            ``y`` and ``batch_size``.
        dataset_name: Dataset name passed to the OGB ``Evaluator``.

    Returns:
        The ``"acc"`` entry of the evaluator's result.

    Raises:
        RuntimeError: If ``test_loader`` yields no batches.
    """
    evaluator = Evaluator(name=dataset_name)
    model.eval()

    outputs = []
    labels = []
    for batch in test_loader:
        # Cast the features to float32 before the forward pass.
        batch.x = batch.x.to(torch.float32)  # TODO
        # Only the first `batch_size` rows of the output and of `y` are scored
        # (presumably the seed nodes of the sampled subgraph -- confirm).
        out = model(batch.x, batch.edge_index)[: batch.batch_size]
        # Accumulate on the CPU and keep the results there: the evaluator only
        # needs plain tensors, so copying them back to the batch device would
        # be wasted work.
        outputs.append(out.cpu())
        labels.append(batch.y[: batch.batch_size].cpu())
        del batch  # drop the reference to the batch before fetching the next one

    if not outputs:
        # torch.cat() on an empty list fails with an unhelpful message.
        raise RuntimeError("test_loader produced no batches; nothing to evaluate")

    y_pred = torch.cat(outputs, dim=0).argmax(dim=-1, keepdim=True)
    y_true = torch.cat(labels, dim=0).unsqueeze(-1)
    return evaluator.eval({"y_true": y_true, "y_pred": y_pred})["acc"]
Launch the Learning Engine#
# Launch the GraphLearn-Torch learning engine on graph `g`, selecting the
# paper-citation-paper edges, the 128 `feat_*` columns as node features and
# the `label` column as node labels.
glt_graph = gs.graphlearn_torch(
    g,
    edges=[("paper", "citation", "paper")],
    node_features={"paper": [f"feat_{idx}" for idx in range(128)]},
    node_labels={"paper": "label"},
    edge_dir="out",
    random_node_split={"num_val": 0.1, "num_test": 0.1},
)

print("-- Initializing client ...")
# Register this process as client rank 0 (one server, one client), using the
# master address and port exposed by the engine handle created above.
glt.distributed.init_client(
    num_servers=1,
    num_clients=1,
    client_rank=0,
    master_addr=glt_graph.master_addr,
    master_port=glt_graph.server_client_master_port,
    num_rpc_threads=4,
    is_dynamic=True,
)
Create neighbor loaders for training, testing and validation#
device = torch.device("cpu")


def _make_neighbor_loader(split, shuffle, workload_type):
    """Build a DistNeighborLoader for one node split, sampled on the remote server.

    The train and test loaders share every setting except the three
    parameters below, so the common configuration lives here.

    Args:
        split: Node split to draw input nodes from (e.g. ``Split.train``).
        shuffle: Whether the loader shuffles its input nodes.
        workload_type: Workload tag passed to the remote sampling workers
            (``"train"`` or ``"test"`` in this tutorial).

    Returns:
        The configured ``glt.distributed.DistNeighborLoader``.
    """
    worker_options = glt.distributed.RemoteDistSamplingWorkerOptions(
        num_workers=1,
        worker_devices=[torch.device("cpu")],
        worker_concurrency=1,
        buffer_size="1GB",
        prefetch_size=1,
        glt_graph=glt_graph,
        workload_type=workload_type,
    )
    return glt.distributed.DistNeighborLoader(
        data=None,
        # Three fan-out values; the model defined later uses num_layers=3.
        num_neighbors=[15, 10, 5],
        input_nodes=split,
        batch_size=512,
        shuffle=shuffle,
        collect_features=True,
        to_device=device,
        worker_options=worker_options,
    )


# Create distributed neighbor loader on remote server for training.
print("-- Creating training dataloader ...")
train_loader = _make_neighbor_loader(Split.train, shuffle=True, workload_type="train")

# Create distributed neighbor loader on remote server for testing.
print("-- Creating testing dataloader ...")
test_loader = _make_neighbor_loader(Split.test, shuffle=False, workload_type="test")
Define the PyG GraphSAGE Model and optimizer#
print("-- Initializing model and optimizer ...")
# in_channels=128 matches the 128 `feat_*` columns selected as node features.
# NOTE(review): out_channels=47 should equal the number of label classes;
# ogbn-arxiv is usually described as having 40 classes -- confirm.
model = GraphSAGE(in_channels=128, hidden_channels=256, num_layers=3, out_channels=47)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
Train and test#
print("-- Start training and testing ...")
epochs = 10
dataset_name = "ogbn-arxiv"
# The client is shut down in `finally` so that an exception during training or
# evaluation does not leave the connection to the learning engine open.
try:
    for epoch in range(epochs):
        model.train()  # test() leaves the model in eval mode, so reset it each epoch
        start = time.time()
        for batch in train_loader:
            optimizer.zero_grad()
            # Cast the features to float32 before the forward pass.
            batch.x = batch.x.to(torch.float32)  # TODO
            # Only the first `batch_size` rows of the output and of `y`
            # contribute to the loss.
            out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(dim=-1)
            loss = F.nll_loss(out, batch.y[: batch.batch_size])
            loss.backward()
            optimizer.step()
        end = time.time()
        # The reported loss is that of the last batch of the epoch.
        print(f"-- Epoch: {epoch:03d}, Loss: {loss:.4f}, Epoch Time: {end - start}")
        # Test accuracy: after the first epoch and during the second half of training.
        if epoch == 0 or epoch > (epochs // 2):
            test_acc = test(model, test_loader, dataset_name)
            print(f"-- Test Accuracy: {test_acc:.4f}")
finally:
    print("-- Shutting down ...")
    glt.distributed.shutdown_client()