Skip to content

Commit

Permalink
add cpu sampling as an option
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin committed Jan 17, 2024
1 parent a067a89 commit 1786338
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions examples/multigpu/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,14 @@ def create_dataloader(
# [Output]:
# A CopyTo object copying data in the datapipe to a specified device.\
############################################################################
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
if not args.cpu_sampling:
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
if args.cpu_sampling:
datapipe = datapipe.copy_to(device)

dataloader = gb.DataLoader(datapipe)
dataloader = gb.DataLoader(datapipe, args.num_workers)

# Return the fully-initialized DataLoader object.
return dataloader
Expand Down Expand Up @@ -273,8 +276,9 @@ def run(rank, world_size, args, devices, dataset):
)

# Pin the graph and features to enable GPU access.
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()
if not args.cpu_sampling:
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
Expand Down Expand Up @@ -386,6 +390,14 @@ def parse_args():
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 15,10,5",
)
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--cpu-sampling",
action="store_true",
help="Disables GPU sampling and utilizes the CPU for dataloading.",
)
return parser.parse_args()


Expand Down

0 comments on commit 1786338

Please sign in to comment.