diff --git a/examples/ltc_backend_bert.py b/examples/ltc_backend_bert.py index d309ba87136..a8c3d43a674 100644 --- a/examples/ltc_backend_bert.py +++ b/examples/ltc_backend_bert.py @@ -17,11 +17,9 @@ import sys from typing import List -import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import torch import torch._C import torch._lazy -import torch._lazy.ts_backend from datasets import load_dataset from datasets.dataset_dict import DatasetDict from torch.utils.data import DataLoader @@ -146,9 +144,11 @@ def main(device='lazy', full_size=False): if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": + import torch._lazy.ts_backend torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": + import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend ltc_backend._initialize() device = "lazy" diff --git a/examples/ltc_backend_mnist.py b/examples/ltc_backend_mnist.py index 65f8a4e5f37..07c8b5d1108 100644 --- a/examples/ltc_backend_mnist.py +++ b/examples/ltc_backend_mnist.py @@ -8,10 +8,8 @@ import argparse import sys -import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend import torch import torch._lazy -import torch._lazy.ts_backend import torch.nn.functional as F @@ -91,9 +89,11 @@ def forward(self, x): if args.device in ("TS", "MLIR_EXAMPLE"): if args.device == "TS": + import torch._lazy.ts_backend torch._lazy.ts_backend.init() elif args.device == "MLIR_EXAMPLE": + import ltc_backend.ltc_backend._EXAMPLE_MLIR_BACKEND as ltc_backend ltc_backend._initialize() device = "lazy"