Skip to content

Commit

Permalink
Revise pretrain script to use Mamba2.
Browse files Browse the repository at this point in the history
  • Loading branch information
Adibvafa committed Oct 27, 2024
1 parent 72733e3 commit 8be0a05
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from odyssey.models.cehr_bert.model import BertPretrain
from odyssey.models.cehr_big_bird.model import BigBirdPretrain
from odyssey.models.ehr_mamba.model import MambaPretrain
from odyssey.models.ehr_mamba2.model import Mamba2Pretrain
from odyssey.models.model_utils import (
get_run_id,
load_config,
Expand Down Expand Up @@ -127,6 +128,14 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:
cls_idx=tokenizer.get_class_token_id(),
**model_config,
)
elif args.model_type == "ehr_mamba2":
model = Mamba2Pretrain(
vocab_size=tokenizer.get_vocab_size(),
padding_idx=tokenizer.get_pad_token_id(),
cls_idx=tokenizer.get_class_token_id(),
eos_idx=tokenizer.get_eos_token_id(),
**model_config,
)

run_id = get_run_id(args.checkpoint_dir)

Expand Down Expand Up @@ -177,7 +186,7 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:
"--model-type",
type=str,
required=True,
help="Model type: 'cehr_bert' or 'cehr_bigbird' or 'ehr_mamba'",
help="Model type: 'cehr_bert' or 'cehr_bigbird' or 'ehr_mamba' or 'ehr_mamba2'",
)
parser.add_argument(
"--exp-name",
Expand Down Expand Up @@ -272,9 +281,9 @@ def main(args: argparse.Namespace, model_config: Dict[str, Any]) -> None:

args = parser.parse_args()

if args.model_type not in ["cehr_bert", "cehr_bigbird", "ehr_mamba"]:
if args.model_type not in ["cehr_bert", "cehr_bigbird", "ehr_mamba", "ehr_mamba2"]:
print(
"Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird' or 'ehr_mamba'."
"Invalid model type. Choose 'cehr_bert' or 'cehr_bigbird' or 'ehr_mamba' or 'ehr_mamba2'."
)
sys.exit(1)

Expand Down

0 comments on commit 8be0a05

Please sign in to comment.