diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py index 2d3344ce301c..5ef93311fe55 100644 --- a/examples/multigpu/graphbolt/node_classification.py +++ b/examples/multigpu/graphbolt/node_classification.py @@ -126,9 +126,6 @@ def create_dataloader( shuffle=shuffle, drop_uneven_inputs=drop_uneven_inputs, ) - datapipe = datapipe.sample_neighbor(graph, args.fanout) - datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) - ############################################################################ # [Note]: # datapipe.copy_to() / gb.CopyTo() @@ -137,8 +134,14 @@ def create_dataloader( # [Output]: # A CopyTo object copying data in the datapipe to a specified device.\ ############################################################################ - datapipe = datapipe.copy_to(device) - dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers) + if not args.cpu_sampling: + datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"]) + datapipe = datapipe.sample_neighbor(graph, args.fanout) + datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) + if args.cpu_sampling: + datapipe = datapipe.copy_to(device) + + dataloader = gb.DataLoader(datapipe, args.num_workers) # Return the fully-initialized DataLoader object. return dataloader @@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset): rank=rank, ) - graph = dataset.graph - features = dataset.feature + # Pin the graph and features to enable GPU access. + if not args.cpu_sampling: + dataset.graph.pin_memory_() + dataset.feature.pin_memory_() + train_set = dataset.tasks[0].train_set valid_set = dataset.tasks[0].validation_set test_set = dataset.tasks[0].test_set args.fanout = list(map(int, args.fanout.split(","))) num_classes = dataset.tasks[0].metadata["num_classes"] - in_size = features.size("node", None, "feat")[0] + in_size = dataset.feature.size("node", None, "feat")[0] hidden_size = 256 out_size = num_classes @@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset): # Create data loaders. train_dataloader = create_dataloader( args, - graph, - features, + dataset.graph, + dataset.feature, train_set, device, drop_last=False, @@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset): ) valid_dataloader = create_dataloader( args, - graph, - features, + dataset.graph, + dataset.feature, valid_set, device, drop_last=False, @@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset): ) test_dataloader = create_dataloader( args, - graph, - features, + dataset.graph, + dataset.feature, test_set, device, drop_last=False, @@ -387,6 +393,11 @@ def parse_args(): parser.add_argument( "--num-workers", type=int, default=0, help="The number of processes." ) + parser.add_argument( + "--cpu-sampling", + action="store_true", + help="Disables GPU sampling and utilizes the CPU for dataloading.", + ) return parser.parse_args()