Skip to content

Commit

Permalink
Adding --num_workers input parameter to the EEG_GCNN example.
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Oct 18, 2023
1 parent 9d5b897 commit 9bda3db
Showing 1 changed file with 10 additions and 4 deletions.
14 changes: 10 additions & 4 deletions examples/pytorch/eeg-gcnn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def _load_memory_mapped_array(file_name):
parser.add_argument(
"--num_nodes", type=int, default=8, help="Number of nodes in the graph"
)
parser.add_argument(
"--num_workers",
type=int,
default=4,
help="Number of epochs used to train",
)
parser.add_argument(
"--gpu_idx",
type=int,
Expand Down Expand Up @@ -149,7 +155,7 @@ def _load_memory_mapped_array(file_name):
# Dataloader========================================================================================================

# use WeightedRandomSampler to balance the training dataset
NUM_WORKERS = 4
_NUM_WORKERS = args.num_workers

labels_unique, counts = np.unique(y, return_counts=True)

Expand All @@ -172,7 +178,7 @@ def _load_memory_mapped_array(file_name):
dataset=train_dataset,
batch_size=_BATCH_SIZE,
sampler=weighted_sampler,
num_workers=NUM_WORKERS,
num_workers=_NUM_WORKERS,
pin_memory=True,
)

Expand All @@ -181,7 +187,7 @@ def _load_memory_mapped_array(file_name):
dataset=train_dataset,
batch_size=_BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
num_workers=_NUM_WORKERS,
pin_memory=True,
)

Expand All @@ -194,7 +200,7 @@ def _load_memory_mapped_array(file_name):
dataset=test_dataset,
batch_size=_BATCH_SIZE,
shuffle=False,
num_workers=NUM_WORKERS,
num_workers=_NUM_WORKERS,
pin_memory=True,
)

Expand Down

0 comments on commit 9bda3db

Please sign in to comment.