
Commit

Refactor: Add Numpy style documentation in dataset_classes.py
sshivam95 committed Mar 28, 2024
1 parent f12c476 commit cbe6da0
Showing 1 changed file with 157 additions and 60 deletions.
217 changes: 157 additions & 60 deletions dicee/dataset_classes.py
@@ -902,6 +902,35 @@ def __getitem__(self, idx):


class NegSampleDataset(torch.utils.data.Dataset):
"""
A dataset for training knowledge graph embedding models using negative sampling.
For each positive triple from the knowledge graph, a negative triple is generated by corrupting either
the head or the tail entity with a randomly selected entity.
Parameters
----------
train_set : np.ndarray
The training set of triples, where each triple consists of indices of the head entity, relation, and tail entity.
num_entities : int
The total number of unique entities in the knowledge graph.
num_relations : int
The total number of unique relations in the knowledge graph.
neg_sample_ratio : int, default=1
The ratio of negative samples to positive samples. Currently, it generates one negative sample per positive sample.
Attributes
----------
train_set : torch.Tensor
The training set converted to a PyTorch tensor and expanded to include a batch dimension.
length : int
The total number of triples in the training set.
num_entities : torch.Tensor
A tensor containing the total number of entities.
num_relations : torch.Tensor
A tensor containing the total number of relations.
neg_sample_ratio : torch.Tensor
A tensor containing the ratio of negative to positive samples.
"""
def __init__(self, train_set: np.ndarray, num_entities: int, num_relations: int, neg_sample_ratio: int = 1):
assert isinstance(train_set, np.ndarray)
# https://pytorch.org/docs/stable/data.html#multi-process-data-loading
@@ -913,10 +942,32 @@ def __init__(self, train_set: np.ndarray, num_entities: int, num_relations: int,
self.num_entities = torch.tensor(num_entities)
self.num_relations = torch.tensor(num_relations)

def __len__(self):
def __len__(self) -> int:
"""
Returns the total number of triples in the dataset.
Returns
-------
int
The total number of triples.
"""
return self.length

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Retrieves a pair consisting of a positive triple and a generated negative triple along with their labels.
Parameters
----------
idx : int
The index of the triple to retrieve.
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
A tuple where the first element is a tensor containing a pair of positive and negative triples,
and the second element is a tensor containing their respective labels (1 for positive, 0 for negative).
"""
# (1) Get a triple.
triple = self.train_set[idx]
# (2) Sample an entity.
@@ -937,35 +988,39 @@

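A minimal usage sketch for NegSampleDataset, assuming the class can be imported from dicee.dataset_classes and that the toy triples below are valid (head, relation, tail) indices:

import numpy as np
from torch.utils.data import DataLoader
from dicee.dataset_classes import NegSampleDataset

train_set = np.array([[0, 0, 1], [1, 1, 2], [2, 0, 3]])  # toy triples
dataset = NegSampleDataset(train_set, num_entities=4, num_relations=2)
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for x, y in loader:
    # x stacks the positive/negative triple pairs, y the corresponding 1/0 labels
    print(x.shape, y.shape)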
class TriplePredictionDataset(torch.utils.data.Dataset):
"""
Triple Dataset
D:= {(x)_i}_i ^N, where
. x:(h,r, t) \in KG is a unique h \in E and a relation r \in R and
. collact_fn => Generates negative triples
collect_fn: \forall (h,r,t) \in G obtain, create negative triples{(h,r,x),(,r,t),(h,m,t)}
y:labels are represented in torch.float16
Parameters
----------
train_set_idx
Indexed triples for the training.
entity_idxs
mapping.
relation_idxs
mapping.
form
?
store
?
label_smoothing_rate
collate_fn: batch:List[torch.IntTensor]
Returns
-------
torch.utils.data.Dataset
"""
A dataset for triple prediction using negative sampling and label smoothing.
D := {(x)_i}_{i=1}^N, where
- x := (h, r, t) \in KG, with head h \in E, relation r \in R, and tail t \in E
- collate_fn generates negative triples: \forall (h, r, t) \in G, create corrupted triples {(h, r, x), (x, r, t), (h, m, t)}
y: labels are represented in torch.float16
This dataset generates negative triples by corrupting either the head or the tail of each positive triple
from the training set. The corruption is performed by randomly replacing the head or the tail with another entity
from the entity set. The dataset supports label smoothing to soften the target labels, which can help improve
generalization.
Parameters
----------
train_set : np.ndarray
The training set consisting of triples in the form of (head, relation, tail) indices.
num_entities : int
The total number of unique entities in the knowledge graph.
num_relations : int
The total number of unique relations in the knowledge graph.
neg_sample_ratio : int, optional
The ratio of negative samples to generate for each positive sample. Default is 1.
label_smoothing_rate : float, optional
The rate of label smoothing to apply to the target labels. Default is 0.0.
Notes
-----
The `collate_fn` should be passed to the DataLoader's `collate_fn` argument to ensure proper
batch processing and negative sample generation.
"""

@timeit
def __init__(self, train_set: np.ndarray, num_entities: int, num_relations: int, neg_sample_ratio: int = 1,
@@ -982,13 +1037,47 @@ def __init__(self, train_set: np.ndarray, num_entities: int, num_relations: int,
self.num_entities = torch.tensor(num_entities)
self.num_relations = torch.tensor(num_relations)

def __len__(self):
def __len__(self) -> int:
"""
Returns the total number of triples in the dataset.
Returns
-------
int
The total number of triples.
"""
return self.length

def __getitem__(self, idx):
def __getitem__(self, idx: int) -> torch.Tensor:
"""
Retrieves a triple for the given index.
Parameters
----------
idx : int
The index of the triple to retrieve.
Returns
-------
torch.Tensor
The triple at the specified index.
"""
return self.train_set[idx]

def collate_fn(self, batch: List[torch.Tensor]):
def collate_fn(self, batch: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Custom collate function to generate a batch of positive and negative triples along with their labels.
Parameters
----------
batch : List[torch.Tensor]
A list of tensors representing triples.
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
A tuple containing a tensor of triples and a tensor of corresponding labels.
"""
batch = torch.stack(batch, dim=0)
h, r, t = batch[:, 0], batch[:, 1], batch[:, 2]
size_of_batch, _ = batch.shape
@@ -1049,32 +1138,32 @@ def collate_fn(self, batch: List[torch.Tensor]):

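Label smoothing, mentioned in the parameters above, softens the hard 1/0 targets; one common formulation (a sketch only, not necessarily the exact formula used inside this collate_fn) is:

import torch

label_smoothing_rate = 0.1
y = torch.tensor([1.0, 1.0, 0.0, 0.0])  # 1 = positive triple, 0 = corrupted triple
y_smoothed = y * (1 - label_smoothing_rate) + label_smoothing_rate / 2
# positives become 0.95 and negatives 0.05 for a rate of 0.1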
class CVDataModule(pl.LightningDataModule):
"""
Create a Dataset for cross validation
Parameters
----------
train_set_idx
Indexed triples for the training.
num_entities
entity to index mapping.
num_relations
relation to index mapping.
batch_size
int
form
?
num_workers
int for https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
Returns
-------
?
"""

def __init__(self, train_set_idx: np.ndarray, num_entities, num_relations, neg_sample_ratio, batch_size,
num_workers):
A LightningDataModule for setting up data loaders for cross-validation training of knowledge graph embedding models.
Parameters
----------
train_set_idx : np.ndarray
An array of indexed triples for training, where each triple consists of indices of the head entity, relation,
and tail entity.
num_entities : int
The total number of unique entities in the knowledge graph.
num_relations : int
The total number of unique relations in the knowledge graph.
neg_sample_ratio : int
The ratio of negative samples to positive samples for each positive triple.
batch_size : int
The number of samples in each batch of data.
num_workers : int
The number of subprocesses to use for data loading; see https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader.
"""

def __init__(self, train_set_idx: np.ndarray, num_entities: int, num_relations: int, neg_sample_ratio: int,
batch_size: int, num_workers: int):
super().__init__()
assert isinstance(train_set_idx, np.ndarray)
self.train_set_idx = train_set_idx
@@ -1085,6 +1174,14 @@ def __init__(self, train_set_idx: np.ndarray, num_entities, num_relations, neg_s
self.num_workers = num_workers

def train_dataloader(self) -> DataLoader:
"""
Creates a DataLoader for the training dataset.
Returns
-------
DataLoader
A DataLoader object that loads the training data.
"""
train_set = TriplePredictionDataset(self.train_set_idx,
num_entities=self.num_entities,
num_relations=self.num_relations,

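A minimal end-to-end sketch of the data module, assuming CVDataModule is importable from dicee.dataset_classes and that the toy triples are valid indices:

import numpy as np
from dicee.dataset_classes import CVDataModule

train_set_idx = np.array([[0, 0, 1], [1, 1, 2], [2, 0, 3]])  # toy triples
dm = CVDataModule(train_set_idx, num_entities=4, num_relations=2,
                  neg_sample_ratio=1, batch_size=2, num_workers=0)
loader = dm.train_dataloader()
for x, y in loader:
    print(x.shape, y.shape)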