Commit 63dcebb: dataloader tests

MoustHolmes committed May 4, 2023
1 parent 3c40fce commit 63dcebb
Showing 8 changed files with 513 additions and 67 deletions.
2 changes: 1 addition & 1 deletion configs/data/upgrade_energy.yaml

@@ -18,7 +18,7 @@ target_cols:
 truth_table: truth
 max_token_count: 4096 #16384
 num_workers: 16
-multi_processing_reading_service_num_workers: 4
+# multi_processing_reading_service_num_workers: 4
 # data_dir: ${paths.data_dir}
 # batch_size: 256
 # # train_val_test_split: [879064,109884,109884] #[0.8, 0.1, 0.1]
19 changes: 9 additions & 10 deletions configs/experiment/opt_test.yaml

@@ -18,24 +18,23 @@ seed: 12345

 trainer:
   min_epochs: 1
-  max_epochs: 2
+  max_epochs: 3
   accelerator: gpu
   precision: 16-mixed
-  limit_train_batches: 1000
-  limit_val_batches: 100
-  limit_test_batches: 100
-
+  limit_train_batches: 20
+  limit_val_batches: 10
+  limit_test_batches: 10
 model:
   model:
-    num_layers: 17
-    nhead: 4
+    num_layers: 2
+    nhead: 2
     d_model: 64
     dim_feedforward: 256

 data:
-  max_token_count: 16384
-  num_workers: 16
-  multi_processing_reading_service_num_workers: 16
+  max_token_count: 50000
+  num_workers: 4
+

 logger:
229 changes: 229 additions & 0 deletions dataloader2_test.py

@@ -0,0 +1,229 @@
import sqlite3

import torch
from torch.nn.utils.rnn import pad_sequence

from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe


@functional_datapipe("read_csv")
class ReadCSV(IterDataPipe):
    """Yield one integer event_no per line of a selection CSV file."""

    def __init__(self, csv_file):
        self.csv_file = csv_file

    def __iter__(self):
        with open(self.csv_file, "r") as f:
            for line in f:
                yield int(line.strip())


@functional_datapipe("read_csv_dp")
class ReadCSVMultiple(IterDataPipe):
    """Like ReadCSV, but consumes CSV file paths from an upstream datapipe."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for csv_file_path in self.datapipe:
            with open(csv_file_path, "r") as f:
                for line in f:
                    yield int(line.strip())
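
# Example of the functional form registered above (file names hypothetical):
#
#   from torchdata.datapipes.iter import IterableWrapper
#   event_nos = IterableWrapper(["train_0.csv", "train_1.csv"]).read_csv_dp()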

@functional_datapipe("query_sql")
class QuerySQL(IterDataPipe):
    """Look up the pulses (features) and truth row for each incoming event_no."""

    def __init__(self, datapipe, db_path, input_cols, pulsemap, target_cols, truth_table):
        self.datapipe = datapipe
        self.db_path = db_path
        self.input_cols_str = ", ".join(input_cols)
        self.target_cols_str = ", ".join(target_cols)
        self.pulsemap = pulsemap
        self.truth_table = truth_table

    def __iter__(self):
        with sqlite3.connect(self.db_path) as conn:
            for event_no in self.datapipe:
                # features: one row per pulse -> shape [n_pulses, len(input_cols)]
                features = torch.Tensor(
                    conn.execute(
                        f"SELECT {self.input_cols_str} FROM {self.pulsemap} WHERE event_no == {event_no}"
                    ).fetchall()
                )
                # truth: one row per event -> shape [1, len(target_cols)]
                truth = torch.Tensor(
                    conn.execute(
                        f"SELECT {self.target_cols_str} FROM {self.truth_table} WHERE event_no == {event_no}"
                    ).fetchall()
                )
                yield (features, truth)
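
# Note: the event_no values interpolated into the f-strings above come from
# the selection CSVs. If those files ever held untrusted input, sqlite3's
# parameterized form would be the safer pattern, e.g.:
#
#   conn.execute(
#       f"SELECT {self.input_cols_str} FROM {self.pulsemap} WHERE event_no == ?",
#       (event_no,),
#   ).fetchall()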

def upgrade_transform_func(x):
    # Combined variant that normalizes the features and log-scales the truth
    # in one call; superseded by the separate transforms consumed by
    # TransformData below.
    features, truth = x
    features[:, 0] = torch.log10(features[:, 0]) / 2.0  # charge
    features[:, 1] /= 2e04  # dom_time
    features[:, 1] -= 1.0
    features[:, 2] /= 500.0  # dom_x
    features[:, 3] /= 500.0  # dom_y
    features[:, 4] /= 500.0  # dom_z
    features[:, 5] /= 0.05  # pmt_area
    # features[:, 6] /= 1.  # pmt_dir_x
    # features[:, 7] /= 1.  # pmt_dir_y
    # features[:, 8] /= 1.  # pmt_dir_z
    truth = torch.log10(truth)
    return (features, truth)

@functional_datapipe("transform_data")
class TransformData(IterDataPipe):
    """Apply a feature transform and an optional truth transform to each event."""

    def __init__(self, datapipe, feature_transform, truth_transform=None):
        self.datapipe = datapipe
        self.feature_transform = feature_transform
        # Default to the identity when no truth transform is given.
        if truth_transform is None:
            self.truth_transform = lambda truth: truth
        else:
            self.truth_transform = truth_transform

    def __iter__(self):
        for features, truth in self.datapipe:
            features = self.feature_transform(features)
            truth = self.truth_transform(truth)
            yield (features, truth)


def upgrade_feature_transform(features):
    features[:, 0] = torch.log10(features[:, 0]) / 2.0  # charge
    features[:, 1] /= 2e04  # dom_time
    features[:, 1] -= 1.0
    features[:, 2] /= 500.0  # dom_x
    features[:, 3] /= 500.0  # dom_y
    features[:, 4] /= 500.0  # dom_z
    features[:, 5] /= 0.05  # pmt_area
    # features[:, 6] /= 1.  # pmt_dir_x
    # features[:, 7] /= 1.  # pmt_dir_y
    # features[:, 8] /= 1.  # pmt_dir_z
    return features


def Prometheus_feature_transform(features):
    features[:, 0] /= 100.0  # dom_x
    features[:, 1] /= 100.0  # dom_y
    features[:, 2] += 350.0  # dom_z
    features[:, 2] /= 100.0
    features[:, 3] /= 1.05e04  # dom_time
    features[:, 3] -= 1.0
    features[:, 3] *= 20.0
    return features


def log10_target_transform(target):
    return torch.log10(target)
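
# Example wiring of these transforms (dataset-specific, names hypothetical):
#
#   pipe = pipe.transform_data(
#       feature_transform=Prometheus_feature_transform,
#       truth_transform=log10_target_transform,
#   )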


@functional_datapipe("pad_batch")
class PadBatch(IterDataPipe):
    """Pad each batch of variable-length events and build a padding mask."""

    def __init__(self, datapipe):
        self.datapipe = datapipe

    def __iter__(self):
        for batch in self.datapipe:
            (xx, y) = zip(*batch)
            x_lens = [len(x) for x in xx]
            xx_pad = pad_sequence(xx, batch_first=True, padding_value=0)

            # True marks padded positions beyond each event's real length.
            pad_mask = torch.zeros_like(xx_pad[:, :, 0]).type(torch.bool)
            for i, length in enumerate(x_lens):
                pad_mask[i, length:] = True

            # Each truth row is [1, n_targets], so cat gives [batch, n_targets].
            yield (xx_pad, torch.cat(y), pad_mask)


def len_fn(datapipe):
    # Token count of one sample = number of pulses in the event.
    features, _ = datapipe
    return features.shape[0]
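
# How the bucketizer uses len_fn: max_token_bucketize packs events into a
# batch until the total token (pulse) count would exceed max_token_count, and
# with include_padding=True the padded size counts toward that budget. So a
# batch whose longest event has L pulses holds roughly max_token_count // L
# events.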

def make_datapipe(
    csv_file,
    db_path,
    input_cols,
    pulsemap,
    target_cols,
    truth_table,
    max_token_count,
    feature_transform,
    truth_transform=None,
):
    # Pipeline: event_no stream -> SQL lookup -> normalize -> token-budgeted
    # batches -> padded tensors. A .sharding_filter() step could be inserted
    # after ReadCSV to shard events across workers.
    datapipe = (
        ReadCSV(csv_file)
        .query_sql(
            db_path=db_path,
            input_cols=input_cols,
            pulsemap=pulsemap,
            target_cols=target_cols,
            truth_table=truth_table,
        )
        .transform_data(
            feature_transform=feature_transform,
            truth_transform=truth_transform,
        )
        .max_token_bucketize(
            max_token_count=max_token_count,
            len_fn=len_fn,
            include_padding=True,
        )
        .pad_batch()
    )
    return datapipe

def make_train_test_val_datapipe(
    train_csv_file,
    test_csv_file,
    val_csv_file,
    db_path,
    input_cols,
    pulsemap,
    target_cols,
    truth_table,
    max_token_count,
    feature_transform,
    truth_transform=None,
):
    train_datapipe = make_datapipe(
        train_csv_file,
        db_path,
        input_cols,
        pulsemap,
        target_cols,
        truth_table,
        max_token_count,
        feature_transform,
        truth_transform,
    )
    test_datapipe = make_datapipe(
        test_csv_file,
        db_path,
        input_cols,
        pulsemap,
        target_cols,
        truth_table,
        max_token_count,
        feature_transform,
        truth_transform,
    )
    val_datapipe = make_datapipe(
        val_csv_file,
        db_path,
        input_cols,
        pulsemap,
        target_cols,
        truth_table,
        max_token_count,
        feature_transform,
        truth_transform,
    )
    return train_datapipe, test_datapipe, val_datapipe

# Smoke-test entry point: make_datapipe has no defaults, so the call only
# works once the dataset-specific arguments are filled in:
#
# datapipe = make_datapipe(
#     csv_file=...,
#     db_path=...,
#     input_cols=...,
#     pulsemap=...,
#     target_cols=...,
#     truth_table=...,
#     max_token_count=...,
#     feature_transform=upgrade_feature_transform,
# )
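
The DataLoader2 and MultiProcessingReadingService imports at the top of the
file point at how this pipe is meant to be consumed; a minimal sketch,
assuming the dataset-specific arguments above are filled in (worker count
hypothetical):

    datapipe = make_datapipe(...)  # arguments as above

    reading_service = MultiProcessingReadingService(num_workers=4)
    loader = DataLoader2(datapipe, reading_service=reading_service)

    for xx_pad, y, pad_mask in loader:
        # xx_pad: [batch, max_len, n_features]; pad_mask: True where padded
        pass

    loader.shutdown()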
24 changes: 4 additions & 20 deletions lightning_no_config_test.py

@@ -6,35 +6,20 @@
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 from torch.utils.data import random_split
-import pytorch_lightning as pl
+from lightning import LightningDataModule, LightningModule, Trainer
 import torchmetrics
 from torchmetrics import Metric


-class MyAccuracy(Metric):
-    def __init__(self):
-        super().__init__()
-        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
-        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
-
-    def update(self, preds, target):
-        preds = torch.argmax(preds, dim=1)
-        assert preds.shape == target.shape
-        self.correct += torch.sum(preds == target)
-        self.total += target.numel()
-
-    def compute(self):
-        return self.correct.float() / self.total.float()
-
-
-class NN(pl.LightningModule):
+class NN(LightningModule):
     def __init__(self, input_size, num_classes):
         super().__init__()
         self.fc1 = nn.Linear(input_size, 50)
         self.fc2 = nn.Linear(50, num_classes)
         self.loss_fn = nn.CrossEntropyLoss()
         self.accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
-        self.my_accuracy = MyAccuracy()
         self.f1_score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

@@ -44,7 +29,6 @@ def forward(self, x):

     def training_step(self, batch, batch_idx):
         loss, scores, y = self._common_step(batch, batch_idx)
-        accuracy = self.my_accuracy(scores, y)
         f1_score = self.f1_score(scores, y)
         self.log_dict({'train_loss': loss, 'train_accuracy': accuracy, 'train_f1_score': f1_score},
                       on_step=False, on_epoch=True, prog_bar=True)

@@ -78,7 +62,7 @@ def configure_optimizers(self):
         return optim.Adam(self.parameters(), lr=0.001)


-class MnistDataModule(pl.LightningDataModule):
+class MnistDataModule(LightningDataModule):
     def __init__(self, data_dir, batch_size, num_workers):
         super().__init__()
         self.data_dir = data_dir

@@ -140,7 +124,7 @@ def test_dataloader(self):

 model = NN(input_size=input_size, num_classes=num_classes)
 dm = MnistDataModule(data_dir="dataset/", batch_size=batch_size, num_workers=4)
-trainer = pl.Trainer(accelerator="gpu", devices=1, min_epochs=1, max_epochs=3, precision=16)
+trainer = Trainer(accelerator="gpu", devices=1, min_epochs=1, max_epochs=3, precision=16)
 trainer.fit(model, dm)
 trainer.validate(model, dm)
 trainer.test(model, dm)
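
The import swap is the substance of this file's change: the script moves from
the pl-prefixed pytorch_lightning namespace to direct imports from the unified
lightning package, which exposes the same classes. A minimal sketch of the
one-to-one mapping:

    # before
    import pytorch_lightning as pl
    model_cls, trainer_cls = pl.LightningModule, pl.Trainer

    # after
    from lightning import LightningModule, Trainer
    model_cls, trainer_cls = LightningModule, Trainer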