-
Notifications
You must be signed in to change notification settings - Fork 4
/
data.py
40 lines (32 loc) · 1.44 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import numpy as np
import scipy.io as scio
import torch.utils.data as data
import pickle
import torch
from dataprocess.arilprocess import ARILData
from dataprocess.wiarprocess import WiARData
from dataprocess.hthidata import HTHIData
from dataprocess_gaussian.arilprocess_gaussian import ARILDataGaussian
from dataprocess_gaussian.hthiprocess_gaussian import HTHIDataGaussian
from dataprocess_gaussian.wiarprocess_guassian import WiARDataGaussian
def getdataloader(dataset_name, filepath, batch_size, trainortest, shuffle=True,
                  detection_gaussian=False, drop_last=True):
    """Build a ``DataLoader`` over one of the supported datasets.

    Args:
        dataset_name: One of ``'ARIL'``, ``'WiAR'`` or ``'HTHI'``.
        filepath: Forwarded unchanged as the first argument of the selected
            dataset class constructor.
        batch_size: Number of samples per batch.
        trainortest: Forwarded unchanged as the second argument of the
            dataset class constructor; its accepted values are defined by
            the dataset classes, not here.
        shuffle: Passed to ``DataLoader``; whether to reshuffle each epoch.
        detection_gaussian: Selects the ``*Gaussian`` dataset class (from the
            ``dataprocess_gaussian`` package) when it is ``True`` or the
            string ``"Yes"`` (case-insensitive, surrounding whitespace
            ignored). Any other value selects the plain dataset class.
        drop_last: Passed to ``DataLoader``. When true (the default, matching
            the previous hard-coded behavior) the final incomplete batch is
            discarded; pass ``False`` when every sample must be visited,
            e.g. during evaluation.

    Returns:
        A ``torch.utils.data.DataLoader`` wrapping the selected dataset.

    Raises:
        ValueError: If ``dataset_name`` is not a supported dataset.
    """
    if dataset_name not in ('ARIL', 'WiAR', 'HTHI'):
        raise ValueError("Dataset name error, expected to enter WiAR, ARIL, or HTHI")

    # The default is the bool False, while the historical contract compared
    # against the string "Yes". Honour both forms so that passing True does
    # not silently fall back to the plain dataset.
    if isinstance(detection_gaussian, str):
        use_gaussian = detection_gaussian.strip().lower() == "yes"
    else:
        use_gaussian = bool(detection_gaussian)

    # (plain class, Gaussian-variant class) for each supported dataset name.
    dataset_classes = {
        'ARIL': (ARILData, ARILDataGaussian),
        'WiAR': (WiARData, WiARDataGaussian),
        'HTHI': (HTHIData, HTHIDataGaussian),
    }
    plain_cls, gaussian_cls = dataset_classes[dataset_name]
    dataset_cls = gaussian_cls if use_gaussian else plain_cls
    dataset = dataset_cls(filepath, trainortest)

    return data.DataLoader(dataset=dataset, batch_size=batch_size,
                           shuffle=shuffle, drop_last=drop_last)