
Training equivariant transformer with OptimizedDistance #203

Closed
FranklinHu1 opened this issue Jul 17, 2023 · 9 comments
@FranklinHu1

Hello,

I am currently trying to train the equivariant transformer model using the OptimizedDistance module by replacing the call to Distance() with OptimizedDistance() in torchmd-net/torchmdnet/models/torchmd_et.py. I want to train on a system with periodic boundary conditions. However, when I try running the training, I get the following traceback:

Traceback (most recent call last):
  File "/home/frankhu/torchmd-net/torchmdnet/scripts/train.py", line 189, in <module>
    main()
  File "/home/frankhu/torchmd-net/torchmdnet/scripts/train.py", line 137, in main
    model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)
  File "/home/frankhu/torchmd-net/torchmdnet/module.py", line 29, in __init__
    self.model = create_model(self.hparams, prior_model, mean, std)
  File "/home/frankhu/torchmd-net/torchmdnet/models/model.py", line 70, in create_model
    representation_model = TorchMD_ET(
  File "/home/frankhu/torchmd-net/torchmdnet/models/torchmd_et.py", line 118, in __init__
    self.distance = OptimizedDistance(
  File "/home/frankhu/torchmd-net/torchmdnet/models/utils.py", line 199, in __init__
    from torchmdnet.neighbors import get_neighbor_pairs_kernel
  File "/home/frankhu/torchmd-net/torchmdnet/neighbors/__init__.py", line 15, in <module>
    compile_extension()
  File "/home/frankhu/torchmd-net/torchmdnet/neighbors/__init__.py", line 11, in compile_extension
    cpp_extension.load(
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1269, in load
    return _jit_compile(
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/utils/cpp_extension.py", line 1453, in _jit_compile
    version = JIT_EXTENSION_VERSIONER.bump_version_if_changed(
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/utils/_cpp_extension_versioner.py", line 45, in bump_version_if_changed
    hash_value = hash_source_files(hash_value, source_files)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/utils/_cpp_extension_versioner.py", line 15, in hash_source_files
    with open(filename) as file:
FileNotFoundError: [Errno 2] No such file or directory: '/home/frankhu/torchmd-net/torchmdnet/neighbors/backwards.cu'

I saw in a previous commit that this file was removed, but it seems like the model cannot proceed with training without it. For reference, here is the change I made within torchmd_et.py:

self.distance = OptimizedDistance(
            cutoff_lower,
            cutoff_upper,
            max_num_pairs = -max_num_neighbors,
            return_vecs = False,
            loop = False,
            strategy = 'brute',
            include_transpose = True,
            resize_to_fit = False,
            check_errors = False,
            box = torch.diag(torch.tensor(pbc_box))
        )

I am running on one Nvidia H100 GPU. Any help/clarification would be greatly appreciated.

Thank you!

@RaulPPelaez
Collaborator

Hi!
I accidentally left this old file in the list of sources for the extension compilation. The CI missed it because the file is GPU-only, and my local tests missed it because the file was still present in my local copy.
Sorry for the inconvenience; I am fixing this in #204.
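
For context, the neighbors extension is JIT-compiled at import time with torch.utils.cpp_extension.load, so a stale filename in the sources list breaks the import before any compilation happens. Roughly (a simplified sketch, not the exact contents of neighbors/__init__.py, and the filenames here are only illustrative):

import os
from torch.utils import cpp_extension

def compile_extension():
    src_dir = os.path.dirname(__file__)
    # Every file listed here is hashed before compiling; a deleted file such as
    # backwards.cu therefore raises FileNotFoundError during the hashing step.
    sources = [
        os.path.join(src_dir, "neighbors.cpp"),
        os.path.join(src_dir, "neighbors_cuda.cu"),
    ]
    cpp_extension.load(name="torchmdnet_neighbors", sources=sources, is_python_module=False)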

@RaulPPelaez
Collaborator

#204 is merged, please pull and try again.

@RaulPPelaez
Collaborator

RaulPPelaez commented Jul 18, 2023

Also, be careful with the parameters: the current ET expects the Distance module to return vectors (return_vecs=True) and include self loops (loop=True). You should replace the Distance line with:

        self.distance = OptimizedDistance(
            cutoff_lower,
            cutoff_upper,
            max_num_pairs=-max_num_neighbors,
            return_vecs=True,
            loop=True,
            check_errors=False,  # Note that this will silently leave neighbors out of the list if there are too many
            box=torch.diag(torch.tensor(pbc_box)),
        )
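
The call site in torchmd_et.py can stay as it is; with these settings the module returns the same triplet the original Distance did. A minimal sketch, assuming the (edge_index, edge_weight, edge_vec) return order:

        edge_index, edge_weight, edge_vec = self.distance(pos, batch)
        # return_vecs=True guarantees edge_vec is a tensor rather than None,
        # which the ET attention update relies on.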

@FranklinHu1
Author

Perfect, thank you so much! I ultimately need this model to compile down to a TorchScript module for use with OpenMM, so are there any additional changes I should make to ensure that it will work? Or is replacing Distance with OptimizedDistance sufficient?

@RaulPPelaez
Collaborator

That should be it! But if you find any issues just let us know.
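
For reference, generating the TorchScript module usually looks something like this. A sketch, assuming you load from a training checkpoint with load_model; adjust the paths and options to your setup:

import torch
from torchmdnet.models.model import load_model

# Load the trained checkpoint; derivative=True makes the model return forces
# alongside the energy, which is what OpenMM-Torch consumes.
model = load_model("model.ckpt", derivative=True)

# Script the module and save it so it can later be wrapped in a TorchForce.
torch.jit.script(model).save("generated_mod.pt")

Keep in mind that TorchForce calls the module with positions and expects an energy back, so depending on your setup you may also need a thin wrapper module that fixes the atomic numbers and forwards only the positions.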

@FranklinHu1
Author

I tried training the equivariant transformer with the OptimizedDistance replaced as suggested. I am now running into the following error whenever the model tries to run through the test set after training:

Traceback (most recent call last):
  File "/home/frankhu/torchmd-net/torchmdnet/scripts/train.py", line 189, in <module>
    main()
  File "/home/frankhu/torchmd-net/torchmdnet/scripts/train.py", line 185, in main
    trainer.test(model, data)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 936, in test
    return self._call_and_handle_interrupt(self._test_impl, model, dataloaders, ckpt_path, verbose, datamodule)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 721, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 983, in _test_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1222, in _run
    self._log_hyperparams()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1277, in _log_hyperparams
    raise MisconfigurationException(
pytorch_lightning.utilities.exceptions.MisconfigurationException: Error while merging hparams: the keys ['dtype'] are present in both the LightningModule's and LightningDataModule's hparams but have different values.

For context, I am training the model using a dataset organized in an HDF5 format. Looking at the input.yaml file that is generated, it says that the dtype is float32. My data is generated and organized using numpy, and I have a datatype of np.float32 for all of my entries except for the types entry, which has a datatype of np.int8.
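
For concreteness, each group in my HDF5 file is written roughly like this (a sketch of my own layout; the dataset names and shapes are just how I organize it):

import h5py
import numpy as np

with h5py.File("dataset.h5", "w") as f:
    group = f.create_group("system_0")
    # Everything is stored as float32 except the atomic types, which are int8.
    group.create_dataset("types", data=np.array([8, 1, 1], dtype=np.int8))
    group.create_dataset("pos", data=np.zeros((1000, 3, 3), dtype=np.float32))
    group.create_dataset("energy", data=np.zeros((1000, 1), dtype=np.float32))
    group.create_dataset("forces", data=np.zeros((1000, 3, 3), dtype=np.float32))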

Note that this does not occur during the fitting stage, and the model can successfully train for any number of epochs. This only occurs after training is complete. In that sense, it does not seem like a critical issue.

Thanks!

@FranklinHu1
Author

I found a more serious issue with training the equivariant transformers using the OptimizedDistance class.

My goal with the equivariant transformer is to train a model using the implemented periodic boundary conditions such that it can be used with the TorchForce plugin from openmm-torch to drive dynamics in openmm. However, after generating a TorchScript module from the trained model, when I try to load my module as a force:

from openmmtorch import TorchForce
force = TorchForce("generated_mod.pt") #My generated TorchScript module

I run into the following error when trying to generate a TorchForce object:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/frankhu/mambaforge/envs/openmm_env/lib/python3.11/site-packages/openmmtorch.py", line 239, in __init__
    _openmmtorch.TorchForce_swiginit(self, _openmmtorch.new_TorchForce(*args))
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Exception:
Unknown builtin op: torchmdnet_neighbors::get_neighbor_pairs.
Could not find any similar ops to torchmdnet_neighbors::get_neighbor_pairs. This op may not exist or may not be currently supported in TorchScript.
:
  File "code/__torch__/torchmdnet/models/utils.py", line 45
    box0 = self.box
    use_periodic = self.use_periodic
    edge_index, edge_vec, edge_weight, num_pairs = ops.torchmdnet_neighbors.get_neighbor_pairs(strategy, pos, batch0, box0, use_periodic, cutoff_lower, cutoff_upper, max_pairs, loop, include_transpose)
                                                   ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    check_errors = self.check_errors
    if check_errors:

Looking back at the documentation for OptimizedDistance, it says that for the operation to be placed inside a CUDA graph, resize_to_fit and check_errors both have to be False. I'm not sure whether the operation needs to be CUDA graph compatible to work properly with TorchScript and TorchForce (I assume it does, but please let me know if I am wrong), but I tried setting both of those arguments to False in the Distance line in torchmd_et.py. When I set resize_to_fit to False, I get the following error:

Traceback (most recent call last):
File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 719, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1234, in _run
    results = self._run_stage()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1321, in _run_stage
    return self._run_train()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1343, in _run_train
    self._run_sanity_check()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1411, in _run_sanity_check
    val_loop.run()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 154, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 127, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 222, in _evaluation_step
    output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1763, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 347, in validation_step
    return self.model(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
    return self.module.validation_step(*inputs, **kwargs)
  File "/home/frankhu/torchmd-net/torchmdnet/module.py", line 77, in validation_step
    return self.step(batch, mse_loss, "val")
  File "/home/frankhu/torchmd-net/torchmdnet/module.py", line 92, in step
    y, neg_dy = self(
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/frankhu/torchmd-net/torchmdnet/module.py", line 69, in forward
    return self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/frankhu/torchmd-net/torchmdnet/models/model.py", line 262, in forward
    x = self.output_model.pre_reduce(x, v, z, pos, batch)
  File "/home/frankhu/torchmd-net/torchmdnet/models/output_modules.py", line 98, in pre_reduce
    x, v = layer(x, v)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/frankhu/torchmd-net/torchmdnet/models/utils.py", line 499, in forward
    if not mask.all():
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/frankhu/torchmd-net/torchmdnet/scripts/train.py", line 189, in <module>
    main()
  File "/home/frankhu/torchmd-net/torchmdnet/scripts/train.py", line 180, in main
    trainer.fit(model, data)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 768, in fit
    self._call_and_handle_interrupt(
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 736, in _call_and_handle_interrupt
    self._teardown()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1298, in _teardown
    self.strategy.teardown()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/strategies/ddp.py", line 474, in teardown
    self.lightning_module.cpu()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/pytorch_lightning/core/mixins/device_dtype_mixin.py", line 147, in cpu
    return super().cpu()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 954, in cpu
    return self._apply(lambda t: t.cpu())
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 797, in _apply
    module._apply(fn)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 797, in _apply
    module._apply(fn)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 797, in _apply
    module._apply(fn)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 820, in _apply
    param_applied = fn(param)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.10/site-packages/torch/nn/modules/module.py", line 954, in <lambda>
    return self._apply(lambda t: t.cpu())
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

The slurm output file also has the following messages printed out:

/home/conda/feedstock_root/build_artifacts/pytorch-recipe_1680572619157/work/aten/src/ATen/native/cuda/Indexing.cu:1146: indexSelectLargeIndex: block: [335,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
(the same assertion is repeated for threads [65,0,0] through [95,0,0] of block [335,0,0])

For reference, my Distance line looks as follows:

self.distance = OptimizedDistance(
            cutoff_lower,
            cutoff_upper,
            max_num_pairs = -max_num_neighbors,
            return_vecs = True,
            loop = True,
            resize_to_fit = False, #Setting this to True makes training work, but False causes crashing
            check_errors = False,
            box = torch.diag(torch.tensor(pbc_box))
        )

Any guidance on how to debug this would be greatly appreciated as I am not very familiar with programming in CUDA or interfacing with CUDA at a low level.

Thank you!

@RaulPPelaez
Collaborator

Your first error is probably the one described here: #205
We are working on it.

CUDA graphs are a separate feature and are not required at all for OpenMM-Torch compatibility; the error above is telling you that some operation in ET (in particular mask.all()) is not CUDA graph compatible. We have not had success so far in training with CUDA graphs.
Setting resize_to_fit to False would require severe modifications to ET. In that case OptimizedDistance pads its output with special values so that it always has the same shape; ET does not handle this padding and thus fails when it interprets the padded entries as legal values.
Setting check_errors to False will silently leave pairs out of the list if max_num_pairs is too low, in exchange for being a bit faster. It is not really worth it if you are not using CUDA graphs, which you currently cannot do with ET.
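
Just to illustrate what handling the padding would mean, something along these lines would have to run after every call to the distance module (an untested sketch, assuming padded pairs are marked with -1 in edge_index):

        # With resize_to_fit=False the output has a fixed shape and the unused
        # slots are placeholders, so they must be filtered out (or masked through
        # every message-passing step) before ET can interpret the pair list.
        mask = edge_index[0] >= 0
        edge_index = edge_index[:, mask]
        edge_weight = edge_weight[mask]
        edge_vec = edge_vec[mask]

Of course, slicing like this brings back dynamic shapes and defeats the purpose of the fixed-size output, which is why I suggest simply leaving resize_to_fit at its default.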
The configuration that will work for ET is this:

        self.distance = OptimizedDistance(
            cutoff_lower,
            cutoff_upper,
            max_num_pairs=-max_num_neighbors,
            return_vecs=True,
            loop=True,
            box=torch.diag(torch.tensor(pbc_box)),
        )

The worrying error you are getting is this:

Could not find any similar ops to torchmdnet_neighbors::get_neighbor_pairs. This op may not exist or may not be currently supported in TorchScript.

I assumed that by TorchScripting the module, all operations would be placed inside the .pt file. Apparently not!
While this is not ideal, I currently do not see any workaround beyond modifying your OpenMM-Torch script to load the required library first:

from openmmtorch import TorchForce
import torchmdnet.models.utils #This line will register the library to torch.ops
force = TorchForce("generated_mod.pt") #My generated TorchScript module

I am sorry for this inconvenience, I will take a look to see if something can be done.

@FranklinHu1
Author

Thanks for the clarifications, that was very helpful! I found it helpful to consolidate everything into one virtual environment, because there were specific imports needed to make the TorchForce module work, specifically:

from openmmtorch import TorchForce
import torchmdnet.neighbors #Addresses the missing operations error
import torch_cluster, torch_geometric #Required for some other operations within ET

force = TorchForce("generated_mod.pt") #This now works

With this set up, I can now run dynamics within openmm. Thanks for all your help!
