Error using VOS optimized inferences in vos_inference.py #501

Open · MohammedSB opened this issue Dec 18, 2024 · 14 comments

@MohammedSB

Hi,

When setting the --use_vos_optimized_video_predictor flag, I get the following error:

torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.

This is my environment; I built it from SAM 2's requirements:

antlr4-python3-runtime   4.9.3
filelock                 3.16.1
fsspec                   2024.10.0
hydra-core               1.3.2
iopath                   0.1.10
Jinja2                   3.1.4
MarkupSafe               3.0.2
mpmath                   1.3.0
natsort                  8.4.0
networkx                 3.4.2
numpy                    2.2.0
nvidia-cublas-cu12       12.4.5.8
nvidia-cuda-cupti-cu12   12.4.127
nvidia-cuda-nvrtc-cu12   12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12        9.1.0.70
nvidia-cufft-cu12        11.2.1.3
nvidia-curand-cu12       10.3.5.147
nvidia-cusolver-cu12     11.6.1.9
nvidia-cusparse-cu12     12.3.1.170
nvidia-nccl-cu12         2.21.5
nvidia-nvjitlink-cu12    12.4.127
nvidia-nvtx-cu12         12.4.127
omegaconf                2.3.0
opencv-python            4.10.0.84
packaging                24.2
pandas                   2.2.3
pillow                   11.0.0
pip                      24.2
portalocker              3.0.0
python-dateutil          2.9.0.post0
pytz                     2024.2
PyYAML                   6.0.2
SAM-2                    1.0         /path/sam2
scipy                    1.14.1
setuptools               75.6.0
six                      1.17.0
sympy                    1.13.1
torch                    2.5.1
torchvision              0.20.1
tqdm                     4.67.1
triton                   3.1.0
typing_extensions        4.12.2
tzdata                   2024.2

My OS is Ubuntu 22.04.5 LTS.

Would really appreciate the help!

@tonydavis629

Same issue

@chayryali
Contributor

@MohammedSB @tonydavis629 Can you share the full trace?

@MohammedSB
Author

./inference_video.sh: line 2: source: /mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/: is a directory
Running inference for dataset /SUN-SEG/testing-easy with model checkpoint /model-weights/data_medsam2/exp_log/v2_Only3D_tiny512-75e_box10_10lr_video10xRatio/checkpoints/checkpoint_10.pt
Image encoder compilation is enabled. First forward pass will be slow.
Compiling all components for VOS setting. First time may be very slow.
using only the first frame's mask in input_mask_dir as input to the SAM 2 model
running VOS prediction on 119 videos:
['case35_2', 'case20_6', 'case85_7', 'case32_5', 'case66_8', 'case91_2', 'case12_3', 'case91_1', 'case50_2', 'case2_3', 'case12_1', 'case20_9', 'case19', 'case54_2', 'case71_2', 'case11', 'case56_2', 'case95_4', 'case44_1', 'case14_1', 'case7_2', 'case68_3', 'case67', 'case51_8', 'case47_2', 'case24_5', 'case37_3', 'case34_2', 'case89', 'case96_2', 'case80_7', 'case23', 'case74_1', 'case36_4', 'case63_2', 'case83_3', 'case32_6', 'case80_6', 'case66_5', 'case80_5', 'case2_9', 'case6_2', 'case80_3', 'case66_3', 'case32_2', 'case86_1', 'case51_7', 'case79', 'case95_2', 'case40', 'case94', 'case51_9', 'case91_4', 'case81_1', 'case95_3', 'case37_1', 'case86_2', 'case24_2', 'case56_1', 'case90_1', 'case35_8', 'case50_3', 'case51_1', 'case14_5', 'case2_5', 'case47_4', 'case13_1', 'case24_1', 'case24_3', 'case39_3', 'case5_6', 'case80_8', 'case37_2', 'case8_1', 'case66_6', 'case51_5', 'case64', 'case91_3', 'case2_8', 'case80_2', 'case91_5', 'case91_6', 'case47_3', 'case51_4', 'case80_1', 'case2_6', 'case72_3', 'case8_2', 'case24_4', 'case97_3', 'case68_5', 'case99', 'case14_4', 'case66_9', 'case68_2', 'case47_5', 'case5_4', 'case29_2', 'case74_2', 'case32_3', 'case14_2', 'case20_7', 'case3_2', 'case13_3', 'case36_2', 'case60', 'case25_2', 'case51_6', 'case80_4', 'case31', 'case75_1', 'case71_1', 'case3_1', 'case68_1', 'case96_1', 'case97_1', 'case48', 'case9', 'case34_1']

1/119 - running on case35_2
frame loading (JPEG): 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 90/90 [00:02<00:00, 44.75it/s]
propagate in video:   1%|█                                                                                                   | 1/90 [00:00<00:54,  1.63it/s]
Traceback (most recent call last):
  File "/mnt/pool/home/jma/Documents/sam2/tools/vos_inference.py", line 508, in <module>
    main()
  File "/mnt/pool/home/jma/Documents/sam2/tools/vos_inference.py", line 479, in main
    vos_inference(
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/tools/vos_inference.py", line 224, in vos_inference
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 57, in generator_context
    response = gen.send(request)
               ^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/sam2_video_predictor.py", line 603, in propagate_in_video
    current_out, pred_masks = self._run_single_frame_inference(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/sam2_video_predictor.py", line 758, in _run_single_frame_inference
    ) = self._get_image_feature(inference_state, frame_idx, batch_size)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/sam2_video_predictor.py", line 714, in _get_image_feature
    backbone_out = self.forward_image(image)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/sam2_video_predictor.py", line 1018, in forward_image
    backbone_out = self.image_encoder(img_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 952, in _compile
    raise InternalTorchDynamoError(
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2279, in CALL
    self._call(inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2273, in _call
    self.call_function(fn, args, kwargs)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py", line 899, in call_function
    return variables.UserFunctionVariable(fn, source=source).call_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 385, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 582, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 1680, in CALL_FUNCTION_EX
    self.call_function(fn, argsvars.items, kwargsvars)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 830, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 324, in call_function
    return super().call_function(tx, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py", line 111, in call_function
    return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 836, in inline_user_function_return
    return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3011, in inline_call
    return cls.inline_call_(parent, func, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 3139, in inline_call_
    tracer.run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2102, in CONTAINS_OP
    self.push(right.call_method(self, "__contains__", [left], {}))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
    return getattr(self.realize(), name)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py", line 338, in call_method
    return ConstantVariable.create(args[0] in self)
                                   ^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/dicts.py", line 180, in __contains__
    and not isinstance(self.items[Hashable(vt)], variables.DeletedVariable)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/base.py", line 110, in __instancecheck__
    instance = instance.realize()
               ^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 63, in realize
    self._cache.realize()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py", line 29, in realize
    self.vt = VariableBuilder(tx, self.source)(self.value)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 377, in __call__
    vt = self._wrap(value)
         ^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 543, in _wrap
    return type_dispatch(self, value)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 1593, in wrap_tensor
    tensor_variable = wrap_fx_proxy(
                      ^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2037, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2149, in wrap_fx_proxy_cls
    example_value = wrap_to_fake_tensor_and_record(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2709, in wrap_to_fake_tensor_and_record
    fake_e = wrap_fake_exception(
             ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/utils.py", line 1574, in wrap_fake_exception
    return fn()
           ^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py", line 2710, in <lambda>
    lambda: tx.fake_mode.from_tensor(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 2238, in from_tensor
    return self.fake_tensor_converter.from_real_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py", line 375, in from_real_tensor
    out = self.meta_converter(
          ^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_subclasses/meta_utils.py", line 1637, in __call__
    t_desc = self.describer.describe_tensor(t, trace=trace)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_subclasses/meta_utils.py", line 245, in describe_tensor
    storage = self.describe_storage(t.untyped_storage(), trace=trace)
                                    ^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. Stack trace: File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/backbones/image_encoder.py", line 31, in forward
    features, pos = self.neck(self.trunk(sample))
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/backbones/image_encoder.py", line 132, in forward
    pos[i] = self.position_encoding(x_out).to(x_out.dtype)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/position_encoding.py", line 130, in forward
    return self._pe(B, x.device, *cache_key)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/position_encoding.py", line 123, in _pe
    self.cache[cache_key] = pos[0].clone(). To prevent overwriting, clone the tensor outside of torch.compile() or call torch.compiler.cudagraph_mark_step_begin() before each model invocation.

from user code:
   File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/backbones/image_encoder.py", line 31, in forward
    features, pos = self.neck(self.trunk(sample))
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/backbones/image_encoder.py", line 132, in forward
    pos[i] = self.position_encoding(x_out).to(x_out.dtype)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/position_encoding.py", line 130, in forward
    return self._pe(B, x.device, *cache_key)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/position_encoding.py", line 92, in _pe
    if cache_key in self.cache:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
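
For reference, the workaround the error message suggests, torch.compiler.cudagraph_mark_step_begin(), is used like this (a minimal toy sketch, not the SAM 2 predictor code):

```python
import torch

# Toy stand-in for a compiled component; "reduce-overhead" enables CUDA
# graphs, which reuse output buffers across runs.
model = torch.nn.Linear(16, 16).cuda()
compiled = torch.compile(model, mode="reduce-overhead")

for _ in range(4):
    # Mark the start of each inference step so outputs of the previous
    # CUDA-graph replay are not read after being overwritten.
    torch.compiler.cudagraph_mark_step_begin()
    out = compiled(torch.randn(4, 16, device="cuda"))
    # Anything kept across steps should be cloned outside the compiled region.
    kept = out.clone()
```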

@Adibvafa

Adibvafa commented Dec 19, 2024

Same issue, only when optimized is set to True.

@chayryali
Contributor

@MohammedSB Thanks for sharing the trace. Wondering if your installation and code are up to date with main? Specifically, the trace says line 123 is self.cache[cache_key] = pos[0].clone(), but this is not consistent with what's on main, which is self.cache[cache_key] = pos[0], i.e. there is no clone operation inside the compiled part of the code.
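
For context, the failure mode the error message describes looks like this in isolation; per the message, the clone has to happen outside of torch.compile(), presumably because a clone traced inside the compiled function still returns a CUDA-graph-managed buffer. A toy sketch, not the sam2 code:

```python
import torch

cache = {}

@torch.compile(mode="reduce-overhead")  # enables CUDA graphs
def pe(x):
    return x.sin()  # toy stand-in for the positional-encoding computation

# Under CUDA-graph trees the first couple of calls warm up / record.
for _ in range(3):
    out = pe(torch.randn(8, device="cuda"))

cache["k"] = out                       # holds a CUDA-graph output buffer
_ = pe(torch.randn(8, device="cuda"))  # the next replay reuses that buffer
print(cache["k"])                      # RuntimeError: accessing tensor output
                                       # of CUDAGraphs that has been overwritten

# The fix the message describes: snapshot outside the compiled region first,
# i.e. cache["k"] = out.clone() before calling pe() again.
```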

@MohammedSB
Author

@chayryali Sorry, you get the same issue even if you don't have the clone. I added the clone in an attempt to fix it. You can ignore that; you still get the exact same issue without it.

@chayryali
Contributor

chayryali commented Dec 19, 2024

@MohammedSB Ok, can you share a trace without this additional clone?

Also, does compiling only the image encoder (using the setting in the config compile_image_encoder: True and without using --use_vos_optimized_video_predictor) run into any errors?
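
That test can also be run without editing the yaml, e.g. via a hydra override when building the predictor (a sketch; the config and checkpoint names here are the public SAM 2.1 tiny ones, substitute your own):

```python
from sam2.build_sam import build_sam2_video_predictor

# Turn on only image-encoder compilation (the compile_image_encoder flag in
# the model config), without --use_vos_optimized_video_predictor.
predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_t.yaml",
    "checkpoints/sam2.1_hiera_tiny.pt",
    hydra_overrides_extra=["++model.compile_image_encoder=True"],
)
```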

Also, what is the resolution of your video frames? The default assumes 1024 (see here) and warms up the cache accordingly.

Can other folks (@Adibvafa @tonydavis629) also share some more information (resolution, traces, etc.) if possible?

@MohammedSB
Author

MohammedSB commented Dec 19, 2024

@chayryali I am using 512 image resolution. Changing the default from 1024 -> 512 actually fixed that issue, but now I get another one:

./inference_video.sh: line 2: source: /mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/: is a directory
Running inference for dataset US_CAMUS_Test with model checkpoint /model-weights/data_medsam2/exp_log/v2_Only3D_tiny512-75e_box10_10lr_video_baseRatio/checkpoints/checkpoint_10.pt
Image encoder compilation is enabled. First forward pass will be slow.
Compiling all components for VOS setting. First time may be very slow.
using only the first frame's mask in input_mask_dir as input to the SAM 2 model
running VOS prediction on 100 videos:
['patient0239_4CH', 'patient0201_4CH', 'patient0276_2CH', 'patient0243_2CH', 'patient0194_4CH', 'patient0242_4CH', 'patient0052_4CH', 'patient0258_2CH', 'patient0217_4CH', 'patient0225_2CH', 'patient0251_2CH', 'patient0214_4CH', 'patient0219_4CH', 'patient0273_4CH', 'patient0254_2CH', 'patient0223_2CH', 'patient0215_2CH', 'patient0219_2CH', 'patient0226_4CH', 'patient0215_4CH', 'patient0027_4CH', 'patient0246_4CH', 'patient0234_4CH', 'patient0248_2CH', 'patient0194_2CH', 'patient0220_4CH', 'patient0052_2CH', 'patient0240_4CH', 'patient0254_4CH', 'patient0191_4CH', 'patient0228_2CH', 'patient0237_2CH', 'patient0260_4CH', 'patient0261_4CH', 'patient0252_2CH', 'patient0027_2CH', 'patient0231_4CH', 'patient0263_2CH', 'patient0273_2CH', 'patient0246_2CH', 'patient0221_2CH', 'patient0263_4CH', 'patient0248_4CH', 'patient0239_2CH', 'patient0241_2CH', 'patient0051_4CH', 'patient0238_2CH', 'patient0269_4CH', 'patient0224_2CH', 'patient0262_2CH', 'patient0238_4CH', 'patient0218_4CH', 'patient0240_2CH', 'patient0187_4CH', 'patient0266_2CH', 'patient0262_4CH', 'patient0241_4CH', 'patient0224_4CH', 'patient0197_4CH', 'patient0226_2CH', 'patient0275_2CH', 'patient0217_2CH', 'patient0251_4CH', 'patient0213_4CH', 'patient0187_2CH', 'patient0237_4CH', 'patient0231_2CH', 'patient0260_2CH', 'patient0227_4CH', 'patient0214_2CH', 'patient0252_4CH', 'patient0221_4CH', 'patient0047_4CH', 'patient0201_2CH', 'patient0218_2CH', 'patient0208_2CH', 'patient0208_4CH', 'patient0199_4CH', 'patient0225_4CH', 'patient0197_2CH', 'patient0242_2CH', 'patient0227_2CH', 'patient0047_2CH', 'patient0223_4CH', 'patient0234_2CH', 'patient0243_4CH', 'patient0189_2CH', 'patient0191_2CH', 'patient0051_2CH', 'patient0276_4CH', 'patient0228_4CH', 'patient0266_4CH', 'patient0199_2CH', 'patient0269_2CH', 'patient0213_2CH', 'patient0220_2CH', 'patient0258_4CH', 'patient0261_2CH', 'patient0189_4CH', 'patient0275_4CH']

1/100 - running on patient0239_4CH
frame loading (JPEG): 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19/19 [00:00<00:00, 104.96it/s]
propagate in video:   0%|                                                                                                                          | 0/19 [00:00<?, ?it/s]/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/lowering.py:1713: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
  warnings.warn(
propagate in video:   5%|██████                                                                                                            | 1/19 [00:04<01:26,  4.78s/it]
Traceback (most recent call last):
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1446, in _call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/__init__.py", line 2234, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1253, in compile_fx
    return compile_fx(
           ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1521, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/backends/common.py", line 72, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1071, in aot_module_simplified
    compiled_fn = dispatch_and_compile()
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 1056, in dispatch_and_compile
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 522, in create_aot_dispatcher_function
    return _create_aot_dispatcher_function(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_functorch/aot_autograd.py", line 759, in _create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 179, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1350, in fw_compiler_base
    return _fw_compiler_base(model, example_inputs, is_inference)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 1421, in _fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/home/jma/anaconda3/lib/python3.12/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 475, in compile_fx_inner
    return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/repro/after_aot.py", line 85, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/compile_fx.py", line 661, in _compile_fx_inner
    compiled_graph = FxGraphCache.load(
                     ^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1324, in load
    compiled_graph = FxGraphCache._lookup_graph(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_inductor/codecache.py", line 1062, in _lookup_graph
    shape_env.evaluate_guards_expression(candidate.guards_expr, hints)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4266, in evaluate_guards_expression
    return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))})
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 1, in <module>
NameError: name 'OpaqueUnaryFn_sqrt' is not defined

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/mnt/pool/home/jma/Documents/sam2/tools/vos_inference.py", line 508, in <module>
    main()
  File "/mnt/pool/home/jma/Documents/sam2/tools/vos_inference.py", line 479, in main
    vos_inference(
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/tools/vos_inference.py", line 224, in vos_inference
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 57, in generator_context
    response = gen.send(request)
               ^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/sam2_video_predictor.py", line 603, in propagate_in_video
    current_out, pred_masks = self._run_single_frame_inference(
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/sam2_video_predictor.py", line 762, in _run_single_frame_inference
    current_out = self.track_step(
                  ^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/sam2_base.py", line 837, in track_step
    current_out, sam_outputs, _, _ = self._track_step(
                                     ^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/sam2_base.py", line 763, in _track_step
    pix_feat = self._prepare_memory_conditioned_features(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/sam2/sam2/modeling/sam2_base.py", line 669, in _prepare_memory_conditioned_features
    pix_feat_with_mem = self.memory_attention(
                        ^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
           ^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 219, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py", line 634, in transform
    tracer.run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2796, in run
    super().run()
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 983, in run
    while self.step():
          ^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 895, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2987, in RETURN_VALUE
    self._return(inst)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py", line 2972, in _return
    self.output.compile_subgraph(
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1142, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1369, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1416, in call_user_compiler
    return self._call_user_compiler(gm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/pool/home/jma/Documents/mohammed/medsam2_mohammed/lib/python3.12/site-packages/torch/_dynamo/output_graph.py", line 1465, in _call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 'OpaqueUnaryFn_sqrt' is not defined

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
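
For what it's worth, the NameError is raised while FxGraphCache re-evaluates guards on a cached graph (see the codecache.py frames above), so one thing to try is disabling the inductor FX graph cache. A guess at a stopgap, not a fix; it forces a full recompile each run:

```python
import torch._inductor.config as inductor_config

# Skip the on-disk FX graph cache whose guard lookup raises the NameError.
# Equivalent environment variable: TORCHINDUCTOR_FX_GRAPH_CACHE=0
inductor_config.fx_graph_cache = False
```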

Also note, I am using ViT-tiny.

@chayryali
Contributor

Ok, you will also need to change the RoPE cache sizes (feat_sizes) in the self- and cross-attention, here and here, from [64, 64] to [32, 32], since the resolution has been changed from 1024 to 512.
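
Put together, for 512-resolution inputs that would look something like this (a sketch; the vos_optimized argument and the override key paths follow the public sam2 configs on main and may differ in a fork like MedSAM2):

```python
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    "configs/sam2.1/sam2.1_hiera_t.yaml",
    "checkpoints/sam2.1_hiera_tiny.pt",
    vos_optimized=True,  # what --use_vos_optimized_video_predictor passes through
    hydra_overrides_extra=[
        "++model.image_size=512",
        # memory-attention RoPE caches should match image_size / 16 = 32
        "++model.memory_attention.layer.self_attention.feat_sizes=[32,32]",
        "++model.memory_attention.layer.cross_attention.feat_sizes=[32,32]",
    ],
)
```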

@tonydavis629

I changed the rope feat_sizes to [32, 32] at 1024 resolution, but it seems to be repeating the Triton benchmarking. This is with compile_image_encoder: true. I am expecting to see the propagate-in-video step run, but it doesn't seem to get to that point after several minutes.

Image encoder compilation is enabled. First forward pass will be slow.
Compiling all components for VOS setting. First time may be very slow.
frame loading (JPEG):  25%|████████████████████████████████████████████████████████████████████████████████████████▍                                                                                                                                                                                                                                                                    | 559/2205 [02:04<00:58, 28.08it/s]AUTOTUNE mm(65536x144, 144x576)
  mm 0.0504 ms 100.0% 
  triton_mm_58 0.0598 ms 84.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_61 0.0601 ms 83.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_59 0.0603 ms 83.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_62 0.0649 ms 77.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_54 0.0700 ms 72.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_55 0.0706 ms 71.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_51 0.0723 ms 69.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_56 0.0733 ms 68.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_60 0.0740 ms 68.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 3.0024 seconds and 0.0044 seconds precompiling
frame loading (JPEG):  29%|██████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                                                                                                                                                      | 649/2205 [02:07<00:55, 27.90it/s]AUTOTUNE mm(16384x288, 288x1152)
  mm 0.0325 ms 100.0% 
  triton_mm_232 0.0352 ms 92.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_225 0.0412 ms 78.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_233 0.0413 ms 78.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_229 0.0418 ms 77.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_226 0.0430 ms 75.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_230 0.0430 ms 75.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_227 0.0448 ms 72.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_234 0.0461 ms 70.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_231 0.0515 ms 63.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 2.2958 seconds and 0.0025 seconds precompiling
frame loading (JPEG):  34%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                                                                                                                                                                                                        | 739/2205 [02:11<00:59, 24.82it/s]AUTOTUNE convolution(1x144x256x256, 256x144x1x1)
  convolution 0.0319 ms 100.0% 
  triton_convolution2d_3755 0.0404 ms 78.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_3753 0.0413 ms 77.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_3756 0.0415 ms 76.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_3754 0.0415 ms 76.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_3750 0.0420 ms 75.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_3751 0.0453 ms 70.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_3752 0.0634 ms 50.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=1, num_warps=8
  conv1x1_via_mm 0.1655 ms 19.3% 
SingleProcess AUTOTUNE benchmarking takes 1.0881 seconds and 0.0022 seconds precompiling
frame loading (JPEG):  37%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋                                                                                                                                                                                                                          | 826/2205 [02:15<00:49, 27.92it/s]AUTOTUNE mm(4096x576, 576x2304)
  mm 0.0268 ms 100.0% 
  triton_mm_707 0.0302 ms 89.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_708 0.0311 ms 86.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_702 0.0325 ms 82.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_709 0.0338 ms 79.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_700 0.0371 ms 72.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_704 0.0379 ms 70.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_701 0.0379 ms 70.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_705 0.0382 ms 70.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_698 0.0428 ms 62.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 2.3036 seconds and 0.0049 seconds precompiling
frame loading (JPEG):  54% | 1189/2205 [02:29<00:39, 25.60it/s]
AUTOTUNE convolution(1x3x1024x1024, 144x3x7x7)
  triton_convolution2d_6 0.0975 ms 100.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=4, STRIDE_W=4, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_1 0.1098 ms 88.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=4, STRIDE_W=4, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_3 0.1505 ms 64.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=4, STRIDE_W=4, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_4 0.1576 ms 61.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=4, STRIDE_W=4, UNROLL=False, num_stages=2, num_warps=4
  convolution 0.1861 ms 52.4% 
  triton_convolution2d_5 0.2380 ms 41.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=4, STRIDE_W=4, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_0 0.2880 ms 33.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=4, STRIDE_W=4, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_2 0.3750 ms 26.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=4, STRIDE_W=4, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9860 seconds and 0.0011 seconds precompiling
frame loading (JPEG): 100% | 2205/2205 [03:06<00:00, 11.81it/s]
AUTOTUNE mm(1024x1152, 1152x4608)
  mm 0.0285 ms 100.0% 
  triton_mm_3462 0.0285 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_3463 0.0296 ms 96.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_3457 0.0303 ms 94.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_3464 0.0307 ms 92.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_3458 0.0353 ms 80.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_3456 0.0363 ms 78.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_3455 0.0381 ms 74.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_3459 0.0404 ms 70.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_3453 0.0410 ms 69.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 2.2929 seconds and 48.8446 seconds precompiling
AUTOTUNE mm(16384x1152, 1152x288)
  triton_mm_633 0.0379 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8
  triton_mm_626 0.0385 ms 98.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  mm 0.0391 ms 96.7% 
  triton_mm_632 0.0400 ms 94.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_631 0.0424 ms 89.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_627 0.0453 ms 83.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
  triton_mm_622 0.0471 ms 80.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_628 0.0480 ms 78.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_629 0.0490 ms 77.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_625 0.0509 ms 74.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 2.3090 seconds and 0.0030 seconds precompiling

With compile_image_encoder: true, I get the warning

/.venv/lib/python3.11/site-packages/torch/_inductor/lowering.py:1713: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.

and indeed the video mask propagation makes no progress.

@chayryali
Contributor

@tonydavis629 Changing feat_sizes to [32, 32] was appropriate for @MohammedSB since their resolution is 512. [64, 64] is correct for a resolution of 1024.

Compiling for the first time can take several minutes; this is expected. The compiled artifacts are cached locally and reused on later runs, though.
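
For reference, the pattern in these numbers is feat_sizes = image resolution / 16 (512 -> [32, 32], 1024 -> [64, 64]). A minimal sketch of that rule, assuming a feature stride of 16; the helper name is hypothetical and not part of the SAM 2 API:

# Hypothetical helper, not part of SAM 2: derive feat_sizes from the
# input resolution, assuming the stride-16 features implied by the
# numbers above (512 -> [32, 32], 1024 -> [64, 64]).
def feat_sizes_for(image_size: int, stride: int = 16) -> list[int]:
    side = image_size // stride
    return [side, side]

assert feat_sizes_for(512) == [32, 32]
assert feat_sizes_for(1024) == [64, 64]
assert feat_sizes_for(2048) == [128, 128]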

@tonydavis629

Ok, it did take about 5 minutes, thanks.

What are the appropriate feat_sizes for resolutions of 2048 and 4096? [128, 128] did not work for 2048 (it fails with the same torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run error as above).

@chayryali
Contributor

@tonydavis629 Did you also change the image resolution here? Since changing the resolution seems to be of general interest, I'll push an update to automate these settings.
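
For anyone hitting the same error, here is a sketch of changing both settings together when building the predictor. The hydra override keys below are assumptions about the config layout, so verify them against your model yaml; the paths and checkpoint name are examples only:

# Sketch under assumptions, not an official recipe: build the
# VOS-optimized predictor at a non-default resolution. The override
# keys are guesses at the config layout; check your sam2 yaml.
from sam2.build_sam import build_sam2_video_predictor

predictor = build_sam2_video_predictor(
    config_file="configs/sam2.1/sam2.1_hiera_t.yaml",  # example config
    ckpt_path="checkpoints/sam2.1_hiera_tiny.pt",      # example checkpoint
    vos_optimized=True,  # what --use_vos_optimized_video_predictor enables
    hydra_overrides_extra=[
        "++model.image_size=2048",
        # feat_sizes must track image_size / 16 (see above); this key
        # path is an assumption
        "++model.memory_attention.layer.cross_attention.feat_sizes=[128,128]",
    ],
)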

@tonydavis629

Nice, that works! Thanks so much.
