Tutorial: Training a Node Classification Model (PyG) on a K8S Cluster

This tutorial presents a server-client example that illustrates how GraphScope trains a GraphSAGE model (implemented in PyG) for a node classification task on a Kubernetes cluster: the GraphScope engines act as remote sampling servers, while training runs in separate client pods.
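
Before running this example, make sure the GraphScope Python client is installed locally and that your Kubernetes cluster provides the Kubeflow PyTorchJob resource used by the client manifest later in this tutorial. A minimal check is sketched below; the package and CRD names are the standard ones, but adjust them to your environment if needed.

pip3 install graphscope-client

# Verify that the Kubeflow training operator's PyTorchJob CRD is available.
kubectl get crd pytorchjobs.kubeflow.org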

Set parameters & load graph

Create a GraphScope session on the Kubernetes cluster and load the ogbn_arxiv dataset, which is available under /dataset when with_dataset=True is specified.

import graphscope as gs
from graphscope.dataset import load_ogbn_arxiv

gs.set_option(log_level="DEBUG")
gs.set_option(show_log=True)

params = {
    "NUM_SERVER_NODES": 2,
    "NUM_CLIENT_NODES": 2,
}

# Create a GraphScope session on the Kubernetes cluster.
sess = gs.session(
    with_dataset=True,
    k8s_service_type="NodePort",
    k8s_vineyard_mem="8Gi",
    k8s_engine_mem="8Gi",
    vineyard_shared_mem="8Gi",
    k8s_image_pull_policy="IfNotPresent",
    k8s_image_tag="0.26.0a20240115-x86_64",
    num_workers=params["NUM_SERVER_NODES"],
)

# Load the ogbn_arxiv graph as an example.
g = load_ogbn_arxiv(sess=sess, prefix="/dataset/ogbn_arxiv")
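
Optionally, you can verify what has been loaded before launching the learning engine. The snippet below is only a sanity check and assumes the standard graphscope Graph interface:

# Optional: inspect the vertex/edge labels and properties of the loaded
# property graph (e.g. the "paper" vertices with their feat_* columns
# and the "label" column).
print(g.schema)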

Launch the Server Engine

The gs.graphlearn_torch call below initializes the GraphLearn-Torch engine on the server pods, then launches the client pods described by client.yaml, using the client scripts found under client_folder_path.

glt_graph = gs.graphlearn_torch(
    g,
    edges=[
        ("paper", "citation", "paper"),
    ],
    node_features={
        "paper": [f"feat_{i}" for i in range(128)],
    },
    node_labels={
        "paper": "label",
    },
    edge_dir="out",
    random_node_split={
        "num_val": 0.1,
        "num_test": 0.1,
    },
    num_clients=params["NUM_CLIENT_NODES"],
    # Specify the client yaml with the client pods' configuration.
    manifest_path="./client.yaml",
    # Specify the client folder path that contains the client scripts.
    client_folder_path="./",
)

print("Exiting...")

Configure the parameters for the client pods

The clients are described by a Kubeflow PyTorchJob manifest (client.yaml). The Master replica runs the client script with node_rank 0, while each Worker derives its node_rank from the trailing index of its pod name, so ${NUM_WORKER_REPLICAS} is normally set to NUM_CLIENT_NODES - 1.

apiVersion: "kubeflow.org/v1"
kind: PyTorchJob
metadata:
  name: graphlearn-torch-client
  namespace: default
spec:
  pytorchReplicaSpecs:
    Master:
      replicas: 1
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: pytorch
              image: registry.cn-hongkong.aliyuncs.com/graphscope/graphlearn-torch-client:0.26.0a20240115-x86_64
              imagePullPolicy: IfNotPresent
              command:
                - bash
                - -c
                - |-
                  python3 /workspace/client.py --node_rank 0 --master_addr ${MASTER_ADDR} --num_server_nodes ${NUM_SERVER_NODES} --num_client_nodes ${NUM_CLIENT_NODES}
              volumeMounts:
              - mountPath: /dev/shm
                name: cache-volume
              - mountPath: /workspace
                name: client-volume
          volumes:
            - name: cache-volume
              emptyDir:
                medium: Memory
                sizeLimit: "8G"
            - name: client-volume
              configMap:
                name: graphlearn-torch-client-config
    Worker:
      replicas: ${NUM_WORKER_REPLICAS}
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: pytorch
              image: registry.cn-hongkong.aliyuncs.com/graphscope/graphlearn-torch-client:0.26.0a20240115-x86_64
              imagePullPolicy: IfNotPresent
              command:
                - bash
                - -c
                - |-
                  python3 /workspace/client.py --node_rank $((${MY_POD_NAME: -1}+1)) --master_addr ${MASTER_ADDR} --group_master ${GROUP_MASTER} --num_server_nodes ${NUM_SERVER_NODES} --num_client_nodes ${NUM_CLIENT_NODES}
              env:
                - name: GROUP_MASTER
                  value: graphlearn-torch-client-master-0
                - name: MY_POD_NAME
                  valueFrom:
                    fieldRef:
                      fieldPath: metadata.name
              volumeMounts:
              - mountPath: /dev/shm
                name: cache-volume
              - mountPath: /workspace
                name: client-volume
          volumes:
            - name: cache-volume
              emptyDir:
                medium: Memory
                sizeLimit: "8G"
            - name: client-volume
              configMap:
                name: graphlearn-torch-client-config
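
After the server-side script has created the PyTorchJob from client.yaml, you can monitor the client pods with kubectl. This is only a sketch and assumes kubectl has access to the cluster and that the job lives in the default namespace, as in the manifest above:

# Check the status of the client job defined in client.yaml.
kubectl get pytorchjob graphlearn-torch-client -n default

# Follow the training logs of the master client pod.
kubectl logs -f graphlearn-torch-client-master-0 -n default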

Write the training and testing script

Save the following client-side code as client.py; this is the script invoked by the commands in client.yaml.

Import packages

import argparse
import time
from typing import List

import torch
import torch.nn.functional as F
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel
from torch_geometric.nn import GraphSAGE

import graphscope as gs
import graphscope.learning.graphlearn_torch as glt
from graphscope.learning.gl_torch_graph import GLTorchGraph
from graphscope.learning.graphlearn_torch.typing import Split

gs.set_option(log_level="DEBUG")
gs.set_option(show_log=True)

Define the test function

@torch.no_grad()
def test(model, test_loader, dataset_name):
    model.eval()
    xs = []
    y_true = []
    for i, batch in enumerate(test_loader):
        if i == 0:
            device = batch.x.device
        batch.x = batch.x.to(torch.float32)  # cast node features to float32 for the model
        x = model.module(batch.x, batch.edge_index)[: batch.batch_size]
        xs.append(x.cpu())
        y_true.append(batch.y[: batch.batch_size].cpu())
        del batch

    xs = [t.to(device) for t in xs]
    y_true = [t.to(device) for t in y_true]
    y_pred = torch.cat(xs, dim=0).argmax(dim=-1, keepdim=True)
    y_true = torch.cat(y_true, dim=0)
    test_acc = sum((y_pred.T == y_true.T)[0]) / len(y_true.T)

    return test_acc.item()

Define the loader and training process

def run_client_proc(
    glt_graph,
    group_master: str,
    num_servers: int,
    num_clients: int,
    client_rank: int,
    server_rank_list: List[int],
    dataset_name: str,
    epochs: int,
    batch_size: int,
    training_pg_master_port: int,
):

    print("-- Initializing client ...")
    glt.distributed.init_client(
        num_servers=num_servers,
        num_clients=num_clients,
        client_rank=client_rank,
        master_addr=glt_graph.master_addr,
        master_port=glt_graph.server_client_master_port,
        num_rpc_threads=4,
        client_group_name="k8s_glt_client",
        is_dynamic=True,
    )

    # Initialize training process group of PyTorch.
    current_ctx = glt.distributed.get_context()

    torch.distributed.init_process_group(
        backend="gloo",
        rank=current_ctx.rank,
        world_size=current_ctx.world_size,
        init_method="tcp://{}:{}".format(group_master, training_pg_master_port),
    )

    device = torch.device("cpu")
    # Create distributed neighbor loader on remote server for training.
    print("-- Creating training dataloader ...")
    train_loader = glt.distributed.DistNeighborLoader(
        data=None,
        num_neighbors=[5, 3, 2],
        input_nodes=Split.train,
        batch_size=batch_size,
        shuffle=True,
        collect_features=True,
        to_device=device,
        worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
            server_rank=server_rank_list,
            num_workers=1,
            worker_devices=[torch.device("cpu")],
            worker_concurrency=1,
            buffer_size="256MB",
            prefetch_size=1,
            glt_graph=glt_graph,
            workload_type="train",
        ),
    )

    # Create distributed neighbor loader on remote server for testing.
    print("-- Creating testing dataloader ...")
    test_loader = glt.distributed.DistNeighborLoader(
        data=None,
        num_neighbors=[5, 3, 2],
        input_nodes=Split.test,
        batch_size=batch_size,
        shuffle=False,
        collect_features=True,
        to_device=device,
        worker_options=glt.distributed.RemoteDistSamplingWorkerOptions(
            server_rank=server_rank_list,
            num_workers=1,
            worker_devices=[torch.device("cpu")],
            worker_concurrency=1,
            buffer_size="256MB",
            prefetch_size=1,
            glt_graph=glt_graph,
            workload_type="test",
        ),
    )

    # Define model and optimizer.
    print("-- Initializing model and optimizer ...")
    model = GraphSAGE(
        in_channels=128,
        hidden_channels=128,
        num_layers=3,
        out_channels=40,  # ogbn-arxiv has 40 classes
    ).to(device)
    model = DistributedDataParallel(model, device_ids=None)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Train and test.
    print("-- Start training and testing ...")
    for epoch in range(0, epochs):
        model.train()
        start = time.time()
        with Join([model]):
            for batch in train_loader:
                optimizer.zero_grad()
                batch.x = batch.x.to(torch.float32)  # cast node features to float32 for the model
                out = model(batch.x, batch.edge_index)[: batch.batch_size].log_softmax(
                    dim=-1
                )
                loss = F.nll_loss(out, torch.flatten(batch.y[: batch.batch_size]))
                loss.backward()
                optimizer.step()

        end = time.time()
        print(f"-- Epoch: {epoch:03d}, Loss: {loss:04f} Epoch Time: {end - start}")
        torch.distributed.barrier()
        # Test accuracy.
        if epoch == 0 or epoch > (epochs // 2):
            test_acc = test(model, test_loader, dataset_name)
            print(f"-- Test Accuracy: {test_acc:.4f}")
            torch.distributed.barrier()

    print("-- Shutdowning ...")
    glt.distributed.shutdown_client()

    print("-- Exited ...")

Main function

Parse the arguments supplied by the commands in client.yaml and start the client process.

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Arguments for distributed training of supervised SAGE with servers."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ogbn-arxiv",
        help="The name of ogbn arxiv.",
    )
    parser.add_argument(
        "--num_server_nodes",
        type=int,
        default=2,
        help="Number of server nodes for remote sampling.",
    )
    parser.add_argument(
        "--num_client_nodes",
        type=int,
        default=1,
        help="Number of client nodes for training.",
    )
    parser.add_argument(
        "--node_rank",
        type=int,
        default=0,
        help="The node rank of the current role.",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="The number of training epochs. (client option)",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=256,
        help="Batch size for the training and testing dataloader.",
    )
    parser.add_argument(
        "--training_pg_master_port",
        type=int,
        default=9997,
        help="The port used for PyTorch's process group initialization across all training processes.",
    )
    parser.add_argument(
        "--train_loader_master_port",
        type=int,
        default=9998,
        help="The port used for RPC initialization across all sampling workers of training loader.",
    )
    parser.add_argument(
        "--test_loader_master_port",
        type=int,
        default=9999,
        help="The port used for RPC initialization across all sampling workers of testing loader.",
    )
    parser.add_argument(
        "--master_addr",
        type=str,
        default="localhost",
        help="The master address of the graphlearn server.",
    )
    parser.add_argument(
        "--group_master",
        type=str,
        default="localhost",
        help="The master address of the training process group.",
    )
    args = parser.parse_args()

    print(
        f"--- Distributed training example of supervised SAGE with server-client mode. Client {args.node_rank} ---"
    )
    print(f"* dataset: {args.dataset}")
    print(f"* total server nodes: {args.num_server_nodes}")
    print(f"* total client nodes: {args.num_client_nodes}")
    print(f"* node rank: {args.node_rank}")

    num_servers = args.num_server_nodes
    num_clients = args.num_client_nodes

    print(f"* epochs: {args.epochs}")
    print(f"* batch size: {args.batch_size}")
    print(f"* training process group master port: {args.training_pg_master_port}")
    print(f"* training loader master port: {args.train_loader_master_port}")
    print(f"* testing loader master port: {args.test_loader_master_port}")

    client_rank = args.node_rank
    print("--- Loading graph info ...")
    glt_graph = GLTorchGraph(
        [
            args.master_addr + ":9001",
            args.master_addr + ":9002",
            args.master_addr + ":9003",
            args.master_addr + ":9004",
        ]
    )
    print("--- Launching client processes ...")
    run_client_proc(
        glt_graph,
        args.group_master,
        num_servers,
        num_clients,
        client_rank,
        [server_rank for server_rank in range(num_servers)],
        args.dataset,
        args.epochs,
        args.batch_size,
        args.training_pg_master_port,
    )

Run the script

Save the server-side snippets above (parameter setup, graph loading, and engine launch) as k8s_launch.py, and keep client.py and client.yaml in the same directory so that manifest_path and client_folder_path resolve correctly. Then start the job:

python3 k8s_launch.py
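
Once training has finished, you may want to release the Kubernetes resources held by the GraphScope session. A minimal sketch, assuming the sess object created in k8s_launch.py is still accessible (for example when running the launch code interactively):

# Release the engine and vineyard pods created by gs.session(...).
sess.close()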