diff --git a/examples/multigpu/graphbolt/node_classification.py b/examples/multigpu/graphbolt/node_classification.py index 9349be4b74c4..5ef93311fe55 100644 --- a/examples/multigpu/graphbolt/node_classification.py +++ b/examples/multigpu/graphbolt/node_classification.py @@ -134,11 +134,14 @@ def create_dataloader( # [Output]: # A CopyTo object copying data in the datapipe to a specified device.\ ############################################################################ - datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"]) + if not args.cpu_sampling: + datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"]) datapipe = datapipe.sample_neighbor(graph, args.fanout) datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"]) + if args.cpu_sampling: + datapipe = datapipe.copy_to(device) - dataloader = gb.DataLoader(datapipe) + dataloader = gb.DataLoader(datapipe, args.num_workers) # Return the fully-initialized DataLoader object. return dataloader @@ -273,8 +276,9 @@ def run(rank, world_size, args, devices, dataset): ) # Pin the graph and features to enable GPU access. - dataset.graph.pin_memory_() - dataset.feature.pin_memory_() + if not args.cpu_sampling: + dataset.graph.pin_memory_() + dataset.feature.pin_memory_() train_set = dataset.tasks[0].train_set valid_set = dataset.tasks[0].validation_set @@ -386,6 +390,14 @@ def parse_args(): help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)" " identical with the number of layers in your model. Default: 15,10,5", ) + parser.add_argument( + "--num-workers", type=int, default=0, help="The number of processes." + ) + parser.add_argument( + "--cpu-sampling", + action="store_true", + help="Disables GPU sampling and utilizes the CPU for dataloading.", + ) return parser.parse_args()