diff --git a/models/transformers/mpt.py b/models/transformers/mpt.py index 951258b..f0e084d 100644 --- a/models/transformers/mpt.py +++ b/models/transformers/mpt.py @@ -1,9 +1,10 @@ - # labels: name::mpt author::transformers task::Generative_AI license::apache-2.0 from turnkeyml.parser import parse from transformers import MptModel, AutoConfig import torch +torch.manual_seed(0) + # Parsing command-line arguments pretrained, batch_size, max_seq_length = parse( ["pretrained", "batch_size", "max_seq_length"] @@ -25,4 +26,3 @@ # Call model model(**inputs) - \ No newline at end of file