
fix: addition and standardization of docstrings #176

Merged Nov 28, 2024 · 13 commits
14 changes: 14 additions & 0 deletions CONTRIBUTING.md
@@ -21,6 +21,20 @@ pre-commit install
pytest
```

## Docstring Writing Guidelines

When adding or updating functions or classes, please ensure that each has a docstring that follows this format:

- **Summary**: A brief description of what the function or class does.
- **Args**: List each argument by name with a short description of its purpose.
- **Returns**: Describe the return value and what it represents.

**Note**: After adding or updating docstrings, ensure that the code passes the following command with **no warnings**:

```bash
mkdocs build --clean --strict
```
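
For example, a docstring following these guidelines might look like this (the function itself is purely illustrative):

```python
def scale_embedding(embedding: list[float], factor: float) -> list[float]:
    """
    Scales each entry of an embedding vector by a constant factor.

    Args:
        embedding: The embedding vector to scale.
        factor: The multiplicative scaling factor.

    Returns:
        A new list containing the scaled embedding values.
    """
    return [factor * x for x in embedding]
```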


## Conventional commits and Commitizen

We use [commitizen](https://commitizen-tools.github.io/commitizen/) to manage commits.
12 changes: 12 additions & 0 deletions whittle/models/gpt/blocks/causal_self_attention.py
@@ -11,6 +11,8 @@


class CausalSelfAttention(nn.Module):
"""Extension of litgpt's `litgpt.model.CausalSelfAttention` with support to adapt to sub-network dimensionality."""

def __init__(self, config: Config, block_idx: int) -> None:
super().__init__()
shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
@@ -48,6 +50,15 @@ def set_sub_network(
sub_network_query_groups: int,
sub_network_head_size: int,
):
"""
Sets the CausalSelfAttention block to the specified sub-network dimensionality.

Args:
sub_network_n_embd: Embedding dimension of the sub-network.
sub_network_n_head: Number of attention heads in the sub-network.
sub_network_query_groups: Number of query groups for grouped-query attention (GQA).
sub_network_head_size: Size of each attention head in the sub-network.
"""
self.sub_network_n_embd = sub_network_n_embd
self.sub_network_n_head = sub_network_n_head
self.sub_network_query_groups = sub_network_query_groups
@@ -73,6 +84,7 @@ def set_sub_network(
self.sub_attention_scaler = self.config.attention_scores_scalar

def reset_super_network(self):
"""Resets the dimensionality of the current sub-network to the super-network dimensionality."""
self.sub_network_n_embd = self.config.n_embd
self.sub_network_n_head = self.config.n_head
self.sub_network_head_size = self.config.head_size
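A minimal usage sketch of the methods documented above (the `Config` fields and all dimensions are illustrative assumptions, not values from this PR):

```python
from litgpt import Config

from whittle.models.gpt.blocks.causal_self_attention import CausalSelfAttention

# Illustrative super-network: 12 heads of size 64 over a 768-dim embedding.
config = Config(n_embd=768, n_head=12, n_query_groups=12, head_size=64)
attn = CausalSelfAttention(config, block_idx=0)

# Activate a smaller sub-network within the same (shared) weights.
attn.set_sub_network(
    sub_network_n_embd=384,
    sub_network_n_head=6,
    sub_network_query_groups=6,
    sub_network_head_size=64,
)

# Restore the full super-network dimensionality.
attn.reset_super_network()
```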
22 changes: 22 additions & 0 deletions whittle/models/gpt/blocks/mlp.py
@@ -9,6 +9,8 @@


class GptNeoxMLP(litgpt.model.GptNeoxMLP):
"""An extension of litgp's `litgpt.model.GptNeoxMLP` with support to adapt to sub-network dimensionality."""

def __init__(self, config: Config) -> None:
super().__init__(config)
self.fc = Linear(config.n_embd, config.intermediate_size, bias=config.bias)
@@ -24,6 +26,13 @@ def __init__(self, config: Config) -> None:
def set_sub_network(
self, sub_network_n_embd: int, sub_network_intermediate_size: int
):
"""
Sets the dimensionality of the current sub-network MLP layers.

Args:
sub_network_n_embd: Input and output embedding dimension of the sub-network.
sub_network_intermediate_size: Hidden layer dimension of the sub-network MLP.
"""
self.sub_network_n_embd = sub_network_n_embd
self.sub_network_intermediate_size = sub_network_intermediate_size

@@ -35,6 +44,7 @@ def set_sub_network(
)

def reset_super_network(self):
"""Resets the MLP dimensions to the original super-network dimensionality."""
self.sub_network_n_embd = self.in_features
self.sub_network_intermediate_size = self.intermediate_size

@@ -43,6 +53,8 @@ def reset_super_network(self):


class LLaMAMLP(litgpt.model.LLaMAMLP):
"""An extension of litgp's `litgpt.model.LLaMAMLP` with support to adapt to sub-network dimensionality."""

def __init__(self, config: Config) -> None:
super().__init__(config)
self.fc_1 = Linear(config.n_embd, config.intermediate_size, bias=config.bias)
@@ -57,6 +69,13 @@ def __init__(self, config: Config) -> None:
def set_sub_network(
self, sub_network_n_embd: int, sub_network_intermediate_size: int
):
"""
Sets the dimensionality of the current sub-network MLP layers.

Args:
sub_network_n_embd: Input and output embedding dimension of the sub-network.
sub_network_intermediate_size: Hidden layer dimension of the sub-network MLP.
"""
self.sub_network_n_embd = sub_network_n_embd
self.sub_network_intermediate_size = sub_network_intermediate_size

@@ -71,6 +90,7 @@ def set_sub_network(
)

def reset_super_network(self):
"""Reset the input dimensionality of the current sub-network to the super-network dimensionality."""
self.sub_network_n_embd = self.in_features
self.sub_network_intermediate_size = self.intermediate_size

@@ -80,6 +100,8 @@ def reset_super_network(self):


class GemmaMLP(LLaMAMLP):
"""Implementation of the forward pass of LLaMAMLP network."""

def __init__(self, config: Config) -> None:
super().__init__(config)

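As with the attention block, a short sketch of how the MLP variants are meant to be driven (the `Config` fields and dimensions are illustrative):

```python
from litgpt import Config

from whittle.models.gpt.blocks.mlp import GptNeoxMLP

# Illustrative super-network MLP: 768-dim embedding with 4x expansion.
config = Config(n_embd=768, intermediate_size=3072)
mlp = GptNeoxMLP(config)

# Shrink to a 384-dim sub-network with a 1536-dim hidden layer.
mlp.set_sub_network(sub_network_n_embd=384, sub_network_intermediate_size=1536)

# Return to the full super-network dimensions.
mlp.reset_super_network()
```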
15 changes: 15 additions & 0 deletions whittle/models/gpt/blocks/transformer_block.py
@@ -12,6 +12,8 @@


class Block(litgpt.model.Block):
"""An extension of litgpt's Transformer Block with support to adapt to sub-network dimensionality."""

def __init__(self, config: Config, block_idx: int) -> None:
super().__init__(config, block_idx)
self.config = config
@@ -69,6 +71,16 @@ def set_sub_network(
sub_network_query_groups: int,
sub_network_head_size: int,
) -> None:
"""
Set the Block to the specified sub-network dimensionality.

Args:
sub_network_n_embd: Embedding dimension of the sub-network.
sub_network_intermediate_size: Intermediate size of the sub-network.
sub_network_num_heads: Number of attention heads in the sub-network.
sub_network_query_groups: Number of query groups in the sub-network.
sub_network_head_size: Size of each attention head in the sub-network.
"""
self.sub_network_n_embd = sub_network_n_embd
self.sub_network_intermediate_size = sub_network_intermediate_size
self.sub_network_num_heads = sub_network_num_heads
@@ -94,6 +106,9 @@ def set_sub_network(
self.post_mlp_norm.set_sub_network(self.sub_network_n_embd)

def reset_super_network(self):
"""
Resets the layers in the Block to it's original super-network dimensionality.
"""
self.sub_network_n_embd = self.config.n_embd
self.sub_network_intermediate_size = self.config.intermediate_size
self.sub_network_num_heads = self.config.n_head
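A sketch of the Block-level API, which propagates the sub-network dimensions to the attention, MLP, and normalization sub-modules in a single call (values are illustrative):

```python
from litgpt import Config

from whittle.models.gpt.blocks.transformer_block import Block

config = Config(
    n_embd=768, n_head=12, n_query_groups=12, head_size=64, intermediate_size=3072
)
block = Block(config, block_idx=0)

# One call re-dimensions attention, MLP, and norm layers together.
block.set_sub_network(
    sub_network_n_embd=384,
    sub_network_intermediate_size=1536,
    sub_network_num_heads=6,
    sub_network_query_groups=6,
    sub_network_head_size=64,
)
block.reset_super_network()
```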
17 changes: 16 additions & 1 deletion whittle/models/gpt/extract.py
@@ -1,5 +1,6 @@
from __future__ import annotations


import torch.nn as nn

from collections import OrderedDict
@@ -8,9 +9,23 @@
from whittle.models.gpt.blocks.mlp import GptNeoxMLP, LLaMAMLP
from whittle.modules.layernorm import LayerNorm
from whittle.modules.rmsnorm import RMSNorm
from litgpt import Config


def extract_sub_network(model: GPT, sub_network_config: Config) -> GPT:
"""
Extracts a sub-network from a given model based on the specified sub-network configuration.
Copies relevant layers, weights, and configurations from the full model into a sub-network model.

Args:
model: The original, full GPT model from which the sub-network is extracted.
sub_network_config: Configuration object for the sub-network, containing the necessary
architecture specifications such as embedding size, number of heads, and number of layers.

Returns:
A new sub-network model instance, initialized with parameters extracted from the original model.
"""

sub_network = GPT(sub_network_config)

state_dict = extract_linear(model.lm_head)
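A sketch of the intended extraction workflow (import paths follow the file layout shown in this PR; whether the sub-network must be activated before extraction is an assumption of this sketch, and all dimensions are illustrative):

```python
from litgpt import Config

from whittle.models.gpt.model import GPT
from whittle.models.gpt.extract import extract_sub_network

# Illustrative super- and sub-network configurations.
super_config = Config(n_embd=768, n_head=12, n_layer=12, intermediate_size=3072)
sub_config = Config(n_embd=384, n_head=6, n_layer=6, intermediate_size=1536)

model = GPT(super_config)
model.set_sub_network(
    sub_network_n_embd=384,
    sub_network_intermediate_size=1536,
    sub_network_num_heads=6,
    sub_network_n_layers=6,
)

# Copy the activated slice of the super-network into a standalone model.
sub_network = extract_sub_network(model, sub_config)
```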
22 changes: 21 additions & 1 deletion whittle/models/gpt/model.py
@@ -23,6 +23,8 @@


class GPT(nn.Module):
"""An extension of litgpt's GPT model with support to adapt to sub-network dimensionality."""

def __init__(self, config: Config) -> None:
super().__init__()
assert config.padded_vocab_size is not None
@@ -165,6 +167,18 @@ def set_sub_network(
sub_network_query_groups: int | None = None,
sub_network_head_size: int | None = None,
) -> None:
"""
Sets the GPT model to the specified sub-network dimensionality.

Args:
sub_network_n_embd: Embedding dimension of the sub-network.
sub_network_intermediate_size: Intermediate size of the sub-network.
sub_network_num_heads: Number of attention heads in the sub-network.
sub_network_n_layers: Number of layers in the sub-network.
sub_network_query_groups: Number of query groups in the sub-network. Defaults to None.
sub_network_head_size: Size of each attention head in the sub-network. Defaults to None.
"""
self.sub_network_n_embd = sub_network_n_embd
self.sub_network_intermediate_size = sub_network_intermediate_size
self.sub_network_num_heads = sub_network_num_heads
@@ -204,7 +218,10 @@ def set_sub_network(
self.sub_network_n_embd, self.config.padded_vocab_size
)

def select_sub_network(self, config: dict[str, Any]) -> None:
"""
Sets the model to the sub-network described by the given search-space configuration.
"""
self.set_sub_network(
config["embed_dim"],
config["mlp_ratio"] * config["embed_dim"],
@@ -213,6 +230,9 @@ def select_sub_network(self, config):
)

def reset_super_network(self):
"""
Resets the GPT model to the original super-network dimensionality.
"""
self.sub_network_n_embd = self.config.n_embd
self.sub_network_intermediate_size = self.config.intermediate_size
self.sub_network_num_heads = self.config.n_head
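A sketch of the model-level API, including a forward pass through the activated sub-network (this assumes whittle's `GPT.forward` behaves like litgpt's; shapes and dimensions are illustrative):

```python
import torch

from litgpt import Config

from whittle.models.gpt.model import GPT

config = Config(n_embd=768, n_head=12, n_layer=12, intermediate_size=3072)
model = GPT(config)

# Activate a sub-network; the weights stay shared with the super-network.
model.set_sub_network(
    sub_network_n_embd=384,
    sub_network_intermediate_size=1536,
    sub_network_num_heads=6,
    sub_network_n_layers=6,
)

tokens = torch.randint(0, config.padded_vocab_size, (1, 16))
logits = model(tokens)  # expected shape: (1, 16, config.padded_vocab_size)

model.reset_super_network()
```

`select_sub_network` wraps the same call, reading `embed_dim` and `mlp_ratio` (along with the head and layer counts, whose key names are elided in this diff) from a search-space dictionary.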
3 changes: 2 additions & 1 deletion whittle/models/gpt/utils.py
@@ -324,7 +324,8 @@ def map_old_state_dict_weights(state_dict: dict, mapping: Mapping, prefix: str)


def get_default_supported_precision(training: bool) -> str:
"""Return default precision that is supported by the hardware: either `bf16` or `16`.
"""
Return default precision that is supported by the hardware: either `bf16` or `16`.

Args:
training: `-mixed` or `-true` version of the precision to use
4 changes: 4 additions & 0 deletions whittle/modules/embedding.py
@@ -6,6 +6,8 @@


class Embedding(torch.nn.Embedding):
"An extension of PyTorch's torch.nn.Embedding with support to sub-sample weights corresponding to the sub-network dimensionality"

def __init__(
self,
num_embeddings: int,
@@ -34,9 +36,11 @@ def __init__(
self.sub_network_embedding_dim: int | None = embedding_dim

def set_sub_network(self, sub_network_embedding_dim: int):
"""Set the embedding dimensionality of the current sub-network."""
self.sub_network_embedding_dim = sub_network_embedding_dim

def reset_super_network(self):
"""Reset the embedding dimensionality of the current sub-network to the super-network dimensionality"""
self.sub_network_embedding_dim = self.embedding_dim

def forward(self, x: torch.Tensor) -> torch.Tensor:
4 changes: 4 additions & 0 deletions whittle/modules/layernorm.py
@@ -5,6 +5,8 @@


class LayerNorm(torch.nn.LayerNorm):
"""An extension of PyTorch's `torch.nn.LayerNorm` with support with support to sub-sample weights corresponding to the sub-network dimensionality."""

def __init__(self, in_features: int, eps: float = 1e-5):
super().__init__(in_features, eps)
self.in_features = in_features
@@ -13,9 +15,11 @@ def __init__(self, in_features: int, eps: float = 1e-5):
self.sub_network_in_features = self.in_features

def set_sub_network(self, sub_network_in_features: int):
"""Set the input dimensionality of the current sub-network."""
self.sub_network_in_features = sub_network_in_features

def reset_super_network(self):
"""Reset the input dimensionality of the current sub-network to the super-network dimensionality."""
self.sub_network_in_features = self.in_features

def forward(self, x: torch.Tensor) -> torch.Tensor:
5 changes: 4 additions & 1 deletion whittle/modules/linear.py
@@ -6,6 +6,8 @@


class Linear(nn.Linear):
"""An extension of PyTorch's torch.nn.Linear with flexible input and output dimensionality corresponding to sub-network"""

def __init__(
self,
in_features: int,
@@ -14,7 +16,6 @@ def __init__(
device=None,
dtype=None,
):
""" """
super().__init__(in_features, out_features, bias, device, dtype)

# Set the current sub-network dimensions equal to super-network
@@ -25,10 +26,12 @@ def set_sub_network(
def set_sub_network(
self, sub_network_in_features: int, sub_network_out_features: int
):
"""Set the linear transformation dimensions of the current sub-network."""
self.sub_network_in_features = sub_network_in_features
self.sub_network_out_features = sub_network_out_features

def reset_super_network(self):
"""Reset the linear transformation dimensions of the current sub-network to the super-network dimensionality."""
self.sub_network_in_features = self.in_features
self.sub_network_out_features = self.out_features

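A minimal sketch of the Linear primitive (the comment about weight slicing is an assumption; the forward implementation is not shown in this diff):

```python
import torch

from whittle.modules.linear import Linear

# Illustrative super-network layer: 768 -> 3072.
layer = Linear(768, 3072, bias=True)

# Sub-sample to a 384 -> 1536 transformation; presumably the leading
# slice of the super-network weight matrix is used in forward().
layer.set_sub_network(sub_network_in_features=384, sub_network_out_features=1536)
out = layer(torch.randn(2, 384))  # expected shape: (2, 1536)

layer.reset_super_network()
```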
35 changes: 32 additions & 3 deletions whittle/sampling/random_sampler.py
@@ -1,23 +1,45 @@
from __future__ import annotations
import warnings

from typing import Any

import numpy as np
from syne_tune.config_space import Categorical, Domain


class RandomSampler:
"""
RandomSampler samples configurations from a given search space using a random state.

Args:
config_space: The search space from which to sample.
seed: Seed for the random number generator. Defaults to None.
"""

def __init__(self, config_space: dict, seed: int | None = None):
self.config_space = config_space
self.rng = np.random.RandomState(seed)

def sample(self):
def sample(self) -> dict[str, Any]:
"""
Samples a random configuration from the search space.

Returns:
A randomly sampled configuration.
"""
config = {}
for hp_name, hparam in self.config_space.items():
if isinstance(hparam, Domain):
config[hp_name] = hparam.sample(random_state=self.rng)
return config

def get_smallest_sub_network(self):
def get_smallest_sub_network(self) -> dict[str, Any]:
"""
Gets the smallest sub-network configuration from the search space.

Returns:
The smallest sub-network configuration.
"""
config = {}
for k, v in self.config_space.items():
if isinstance(v, Domain):
@@ -33,7 +55,14 @@ def get_smallest_sub_network(self):
config[k] = v.lower
return config

def get_largest_sub_network(self):
def get_largest_sub_network(self) -> dict[str, Any]:
"""
Gets the largest sub-network configuration from the search space.

Returns:
The largest sub-network configuration.
"""

config = {}
for k, v in self.config_space.items():
if isinstance(v, Domain):
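A usage sketch (the search-space keys and ranges are illustrative; `randint` and `choice` are standard syne-tune domains):

```python
from syne_tune.config_space import choice, randint

from whittle.sampling.random_sampler import RandomSampler

# Illustrative search space over sub-network dimensions.
config_space = {
    "embed_dim": randint(128, 768),
    "mlp_ratio": choice([2, 4]),
}

sampler = RandomSampler(config_space, seed=42)
sampler.sample()                    # e.g. {"embed_dim": 512, "mlp_ratio": 4}
sampler.get_smallest_sub_network()  # e.g. {"embed_dim": 128, "mlp_ratio": 2}
sampler.get_largest_sub_network()   # e.g. {"embed_dim": 768, "mlp_ratio": 4}
```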