Merge branch 'sdavis_trn' into 'master'

transformer

See merge request machine-learning/bonito!166
iiSeymour committed Mar 28, 2024
2 parents bef9276 + 983c9be commit e33a860
Showing 7 changed files with 264 additions and 14 deletions.
23 changes: 18 additions & 5 deletions README.md
@@ -1,11 +1,11 @@
# Bonito

[![PyPI version](https://badge.fury.io/py/ont-bonito.svg)](https://badge.fury.io/py/ont-bonito)
[![py38](https://img.shields.io/badge/python-3.8-brightgreen.svg)](https://img.shields.io/badge/python-3.8-brightgreen.svg)
[![py39](https://img.shields.io/badge/python-3.9-brightgreen.svg)](https://img.shields.io/badge/python-3.9-brightgreen.svg)
[![py310](https://img.shields.io/badge/python-3.10-brightgreen.svg)](https://img.shields.io/badge/python-3.10-brightgreen.svg)
[![py311](https://img.shields.io/badge/python-3.11-brightgreen.svg)](https://img.shields.io/badge/python-3.11-brightgreen.svg)
[![cu117](https://img.shields.io/badge/cuda-11.7-blue.svg)](https://img.shields.io/badge/cuda-11.7-blue.svg)
[![cu118](https://img.shields.io/badge/cuda-11.8-blue.svg)](https://img.shields.io/badge/cuda-11.8-blue.svg)

Bonito is an open source research basecaller for Oxford Nanopore reads.

@@ -30,6 +30,15 @@ $ bonito download --models --show # show all available models
$ bonito download --models # download all available models
```

## Transformer Models

The `bonito.transformer` package requires
[flash-attn](https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features).

It must be installed manually, as `flash-attn`'s packaging prevents it from being listed as a normal dependency.

Setting `CUDA_HOME` to the relevant library directory will help avoid CUDA version mismatches between packages.
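
For example, a manual install typically looks like this (the `CUDA_HOME` path is illustrative; point it at your local CUDA toolkit):

```bash
$ export CUDA_HOME=/usr/local/cuda-11.8
$ pip install flash-attn --no-build-isolation
```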

## Modified Bases

Modified base calling is handled by [Remora](https://github.com/nanoporetech/remora).
@@ -49,7 +58,7 @@ $ bonito basecaller dna_r9.4.1 --save-ctc --reference reference.mmi /data/reads
$ bonito train --directory /data/training/ctc-data /data/training/model-dir
```

In addition to training a new model from scratch, you can also easily fine-tune one of the pretrained models.

```bash
bonito train --epochs 1 --lr 5e-4 --pretrained dna_r10.4.1_e8.2_400bps_hac@v4.0.0 --directory /data/training/ctc-data /data/training/fine-tuned-model
@@ -62,7 +71,7 @@ $ bonito download --training
$ bonito train /data/training/model-dir
```

All training calls use Automatic Mixed Precision to speed up training. To disable this, pass the `--no-amp` flag.
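
For example, to train without AMP:

```bash
$ bonito train --no-amp /data/training/model-dir
```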

## Developer Quickstart

@@ -72,9 +81,13 @@ $ cd bonito
$ python3 -m venv venv3
$ source venv3/bin/activate
(venv3) $ pip install --upgrade pip
(venv3) $ pip install -e .
(venv3) $ pip install -e .[cu118] --extra-index-url https://download.pytorch.org/whl/cu118
```

The `ont-bonito[cu118]` and `ont-bonito[cu121]` optional dependencies can be used, along
with the corresponding `--extra-index-url`, to ensure the PyTorch package matches the
local CUDA setup.
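
For example, to match a CUDA 12.1 setup:

```bash
(venv3) $ pip install -e .[cu121] --extra-index-url https://download.pytorch.org/whl/cu121
```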

## Interface

- `bonito view` - view a model architecture for a given `.toml` file and the number of parameters in the network.
79 changes: 78 additions & 1 deletion bonito/nn.py
@@ -2,6 +2,8 @@
Bonito nn modules.
"""

from collections import OrderedDict

import torch
from torch.nn import Module
from torch.nn.init import orthogonal_
@@ -94,6 +96,81 @@ def to_dict(self, include_weights=False):
def __repr__(self):
return torch.nn.ModuleList.__repr__(self)


@register
class Stack(Serial):
@classmethod
def from_dict(cls, model_dict, layer_types=None):
return cls([from_dict(model_dict["layer"], layer_types) for _ in range(model_dict["depth"])])

def to_dict(self, include_weights=False):
if include_weights:
raise NotImplementedError
layer_dicts = [to_dict(layer) for layer in self]
for layer_dict in layer_dicts[1:]:
assert layer_dict == layer_dicts[0], "all layers should be the same"
return {"layer": layer_dicts[0], "depth": len(self)}


@register
class NamedSerial(torch.nn.Sequential):
@classmethod
def from_dict(cls, model_dict, layer_types=None):
return cls({k: from_dict(v, layer_types) for k, v in model_dict.items()})

def __init__(self, layers):
# Sequential throws error if given dict
super().__init__(OrderedDict(layers.items()))

def to_dict(self, include_weights=False):
if include_weights:
raise NotImplementedError
return {k: to_dict(v) for k, v in self.named_children()}
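
# Illustrative usage (layer names here are hypothetical): NamedSerial({"upsample": ...,
# "encoder": ...}) behaves like torch.nn.Sequential but keeps the given names, and
# to_dict() serialises it back to a {name: layer_dict} mapping.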


class MakeContiguous(Module):
def __init__(self):
super().__init__()

def forward(self, x):
return x.contiguous()


@register
class LinearUpsample(Module):
"""
Applies a linear transformation to upsample the sequence length by ``scale_factor``.
"""

def __init__(self, d_model, scale_factor, batch_first=True):
super().__init__()
self.d_model = d_model
self.scale_factor = scale_factor
self.batch_first = batch_first
self.linear = torch.nn.Linear(d_model, self.scale_factor * d_model)

def forward(self, src):
if not self.batch_first:
src = src.permute([1, 0, 2])
N, L, E = src.shape
h = self.linear(src).reshape(N, self.scale_factor * L, E)
if not self.batch_first:
h = h.permute([1, 0, 2])
return h

def output_stride(self, input_stride):
return input_stride // self.scale_factor

def to_dict(self, include_weights=False):
if include_weights:
raise NotImplementedError
return {
"d_model": self.d_model,
"scale_factor": self.scale_factor,
"batch_first": self.batch_first
}
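
# Shape sketch (illustrative numbers): with d_model=64 and scale_factor=2, an input
# of shape (N, L, 64) is projected to (N, L, 128) and reshaped to (N, 2*L, 64),
# doubling the sequence length while keeping the feature size.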


@register
class Reverse(Module):

@@ -248,7 +325,7 @@ def extra_repr(self):
if self.permute:
rep += f', permute={self.permute}'
return rep


@register
class Permute(Module):
2 changes: 2 additions & 0 deletions bonito/transformer/__init__.py
@@ -0,0 +1,2 @@
from .model import Model
from .basecall import basecall
1 change: 1 addition & 0 deletions bonito/transformer/basecall.py
@@ -0,0 +1 @@
from bonito.crf import basecall
154 changes: 154 additions & 0 deletions bonito/transformer/model.py
@@ -0,0 +1,154 @@
import logging
import types
from functools import lru_cache

logger = logging.getLogger(__name__)

import torch
import torch.nn.functional as F
try:
from flash_attn import flash_attn_qkvpacked_func
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.modules.mlp import GatedMlp
from flash_attn.ops.triton.layer_norm import RMSNorm
except ImportError:
logger.warning(
"please install flash-attn to use the transformer module: "
"`pip install flash-attn --no-build-isolation`"
)

from bonito.crf.model import SeqdistModel
from bonito.nn import from_dict, register, LinearCRFEncoder, MakeContiguous, Module, Permute, Serial


def deepnorm_params(depth):
"""
Returns the DeepNorm (https://arxiv.org/abs/2203.00555) alpha and beta parameters.
"""
alpha = round((2*depth)**0.25, 7)
beta = round((8*depth)**(-1/4), 7)
return alpha, beta
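
# Worked example (for illustration): deepnorm_params(18) returns
# alpha = (2*18)**0.25 = 2.4494897 and beta = (8*18)**-0.25 = 0.2886751.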


@lru_cache(maxsize=2)
def sliding_window_mask(seq_len, window, device):
band = torch.full((seq_len, seq_len), fill_value=1.0)
band = torch.triu(band, diagonal=-window[0])
band = band * torch.tril(band, diagonal=window[1])
band = band.to(torch.bool).to(device)
return band
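
# For example (illustrative call), sliding_window_mask(4, (1, 1), "cpu") yields a
# tridiagonal boolean mask, allowing each position to attend one step back and one
# step forward:
# [[1, 1, 0, 0],
#  [1, 1, 1, 0],
#  [0, 1, 1, 1],
#  [0, 0, 1, 1]]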


class MultiHeadAttention(Module):
def __init__(self, d_model, nhead, qkv_bias=False, out_bias=True, rotary_dim=None, attn_window=None):
super().__init__()
assert d_model % nhead == 0, "d_model must be divisible by nhead"

self.d_model = d_model
self.nhead = nhead
self.head_dim = d_model // nhead
self.rotary_dim = self.head_dim if rotary_dim is None else rotary_dim

self.Wqkv = torch.nn.Linear(d_model, 3 * d_model, bias=qkv_bias)
self.out_proj = torch.nn.Linear(d_model, d_model, bias=out_bias)

self.rotary_emb = RotaryEmbedding(self.rotary_dim, interleaved=False)
self.attn_window = (-1, -1) if attn_window is None else tuple(attn_window)

def attn_func(self, qkv):
if torch.cuda.get_device_capability(qkv.device)[0] >= 8 and (torch.is_autocast_enabled() or qkv.dtype == torch.half):
attn_output = flash_attn_qkvpacked_func(qkv, window_size=self.attn_window)
else:
q, k, v = torch.chunk(qkv.permute(0, 2, 3, 1, 4), chunks=3, dim=1)
mask = sliding_window_mask(qkv.shape[1], self.attn_window, q.device)
attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
attn_output = attn_output.permute(0, 1, 3, 2, 4)
return attn_output

def forward(self, x):
N, T, _ = x.shape

qkv = self.Wqkv(x).view(N, T, 3, self.nhead, self.head_dim)

qkv = self.rotary_emb(qkv)

attn_output = self.attn_func(qkv).reshape(N, T, self.d_model)

out = self.out_proj(attn_output)

return out
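
# Shape walk-through (assumed example sizes): with d_model=512 and nhead=8, an
# input of shape (N, T, 512) gives qkv of shape (N, T, 3, 8, 64); flash-attn (or
# the SDPA fallback) then produces per-head outputs that reshape back to (N, T, 512).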


@register
class TransformerEncoderLayer(Module):
def __init__(self, d_model, nhead, dim_feedforward, deepnorm_alpha, deepnorm_beta, attn_window=None):
super().__init__()
self.kwargs = {
"d_model": d_model,
"nhead": nhead,
"dim_feedforward": dim_feedforward,
"deepnorm_alpha": deepnorm_alpha,
"deepnorm_beta": deepnorm_beta,
"attn_window": attn_window
}

self.self_attn = MultiHeadAttention(
d_model=d_model,
nhead=nhead,
qkv_bias=False,
out_bias=True,
attn_window=attn_window
)
self.ff = GatedMlp(
d_model,
hidden_features=dim_feedforward,
activation=F.silu,
bias1=False,
bias2=False,
multiple_of=1,
)
self.norm1 = RMSNorm(d_model)
self.norm2 = RMSNorm(d_model)

self.register_buffer("deepnorm_alpha", torch.tensor(deepnorm_alpha))
self.reset_parameters()

def reset_parameters(self):
db = self.kwargs["deepnorm_beta"]
d_model = self.kwargs["d_model"]
torch.nn.init.xavier_normal_(self.ff.fc1.weight, gain=db)
torch.nn.init.xavier_normal_(self.ff.fc2.weight, gain=db)
torch.nn.init.xavier_normal_(self.self_attn.out_proj.weight, gain=db)
torch.nn.init.xavier_normal_(self.self_attn.Wqkv.weight[2*d_model:], gain=db)
torch.nn.init.xavier_normal_(self.self_attn.Wqkv.weight[:2*d_model], gain=1)

def forward(self, x):
x = self.norm1(self.self_attn(x), self.deepnorm_alpha*x)
x = self.norm2(self.ff(x), self.deepnorm_alpha*x)
return x

def to_dict(self, include_weights=False):
if include_weights:
raise NotImplementedError
return self.kwargs


def use_koi(self, **kwargs):
# koi needs modified LinearCRFLayer settings
def _expand_blanks(m):
if isinstance(m, LinearCRFEncoder):
m.expand_blanks = False
self.encoder.apply(_expand_blanks)
self.encoder = Serial([
self.encoder,
Permute([1, 0, 2]),
MakeContiguous(),
])


def Model(config):
model_config = {k: v for k, v in config["model"].items() if k != "package"}
model = from_dict(model_config)
model.config = config
model.use_koi = types.MethodType(use_koi, model)
return model
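
# Minimal usage sketch (the config file name is hypothetical):
#   import toml
#   config = toml.load("dna_transformer.toml")
#   model = Model(config)
#   model.use_koi()  # optional: adapt the encoder for koi decoding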
7 changes: 4 additions & 3 deletions requirements.txt
@@ -3,17 +3,18 @@ mappy==2.24
toml==0.10.2
tqdm>4,<5
scipy==1.10.1
networkx~=3.1.0
numpy~=1.24.2
pysam==0.21.0
parasail==1.3.4
pandas>1,<2
requests~=2.28.2
requests~=2.31.0
ont-koi==0.3.0
ont-remora==2.1.3
ont-fast5-api==3.3.0
pod5==0.3.6
fast-ctc-decode==0.3.5
python-dateutil==2.8.2
edlib==1.3.9
# cuda requirements
torch~=2.0.0
torch~=2.2.1
wheel
12 changes: 7 additions & 5 deletions setup.py
@@ -1,7 +1,6 @@
import os
import re
from setuptools import setup, find_packages
from setuptools.command.install import install


__pkg_name__ = 'bonito'
@@ -35,12 +34,15 @@
author='Oxford Nanopore Technologies, Ltd',
author_email='support@nanoporetech.com',
url='https://github.com/nanoporetech/bonito',
entry_points = {
extras_require={
# --extra-index-url https://download.pytorch.org/whl/cu118
"cu118": ["torch==2.2.1+cu118"],
# --extra-index-url https://download.pytorch.org/whl/cu121
"cu121": ["torch==2.2.1+cu121"],
},
entry_points={
'console_scripts': [
'{0} = {0}:main'.format(__pkg_name__)
]
},
dependency_links=[
'https://download.pytorch.org/whl/cu113',
]
)
