Fix barrier insertion after assert op #5114

Merged
merged 3 commits into triton-lang:main from assert-op on Nov 12, 2024

Conversation

anmyachev (Contributor)

This will fix the following problem:

python: /home/runner/work/triton/triton/llvm-project/llvm/include/llvm/ADT/ilist_iterator.h:168: llvm::ilist_iterator::reference llvm::ilist_iterator<llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, false, false>::operator*() const [OptionsT = llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, IsReverse = false, IsConst = false]: Assertion `!NodePtr->isKnownSentinel()' failed.
Aborted (core dumped)

The problem was found when using PyTorch on an Intel GPU.

Simplified reproducer #1:
from torch._inductor.async_compile import AsyncCompile
async_compile = AsyncCompile()

triton_per_fused_add_embedding_native_layer_norm_0 = async_compile.triton('triton_per_fused_add_embedding_native_layer_norm_0', '''
import triton
import triton.language as tl
from triton.compiler.compiler import AttrsDescriptor

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.persistent_reduction(
    size_hints=[512, 128],
    reduction_hint=ReductionHint.INNER,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*fp32', 'in_ptr2': '*fp32', 'in_ptr3': '*fp32', 'in_ptr4': '*fp32', 'in_ptr5': '*fp32', 'out_ptr2': '*fp32', 'xnumel': 'i32', 'rnumel': 'i32'}, 'device': DeviceProperties(type='xpu', index=0, cc={'driver_version': '1.3.30049', 'gpu_eu_count': 448, 'gpu_subslice_count': 56, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 448, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Data Center GPU Max 1100', 'platform_name': 'Intel(R) Level-Zero', 'sub_group_sizes': [16, 32], 'total_memory': 51539607552, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '1.3'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, multi_processor_count=None, warp_size=32), 'constants': {}, 'configs': [AttrsDescriptor.from_dict({'arg_properties': {'tt.divisibility': (0, 1, 2, 3, 4, 5, 6, 7, 8), 'tt.equal_to': ()}, 'cls': 'AttrsDescriptor'})]},
    inductor_meta={'autotune_hints': set(), 'kernel_name': 'triton_per_fused_add_embedding_native_layer_norm_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 5, 'num_reduction': 4, 'backend_hash': 'D82C2E8E2C9203D653D1A2B8A0511701E4F7567A195A5128E03B9AA7218348AA', 'are_deterministic_algorithms_enabled': True, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False}
)
@triton.jit
def triton_per_fused_add_embedding_native_layer_norm_0(in_ptr0, in_ptr1, in_ptr2, in_ptr3, in_ptr4, in_ptr5, out_ptr2, xnumel, rnumel, XBLOCK : tl.constexpr):
    xnumel = 512
    rnumel = 128
    RBLOCK: tl.constexpr = 128
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    rindex = tl.arange(0, RBLOCK)[None, :]
    roffset = 0
    rmask = tl.full([XBLOCK, RBLOCK], True, tl.int1)
    x0 = xindex
    r1 = rindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask, eviction_policy='evict_last')
    tmp7 = tl.load(in_ptr2 + (r1 + (128*x0)), xmask, other=0.0)
    tmp9 = tl.load(in_ptr3 + (r1 + (128*x0)), xmask, other=0.0)
    tmp34 = tl.load(in_ptr4 + (r1), None, eviction_policy='evict_last')
    tmp36 = tl.load(in_ptr5 + (r1), None, eviction_policy='evict_last')
    tmp1 = tl.full([XBLOCK, RBLOCK], 30000, tl.int32)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp0 < 0
    tmp4 = tl.where(tmp3, tmp2, tmp0)
    tl.device_assert(((0 <= tmp4) & (tmp4 < 30000)) | ~(xmask), "index out of bounds: 0 <= tmp4 < 30000")
''', device_str='xpu')
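
For context, `tl.device_assert` in the reproducer lowers to a `tt.assert` op, and the backend places a workgroup barrier right after it, which is the insertion this PR fixes. The `!NodePtr->isKnownSentinel()' failed` message is what dereferencing a block's end (sentinel) iterator looks like, i.e. an insertion point computed past the last op of a block. Below is a minimal, hypothetical sketch of a safe insertion pattern under that assumption; the helper name and the use of `gpu.barrier` are illustrative only and are not taken from this PR's diff:

```cpp
// Hypothetical illustration only, not the actual patch from this PR.
// OpBuilder::setInsertionPointAfter stays valid even when the assert is the
// last operation in its block, whereas advancing and dereferencing a
// Block::iterator past that op trips the ilist sentinel assertion above.
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

static void insertBarrierAfterAssert(mlir::Operation *assertOp) {
  mlir::OpBuilder builder(assertOp->getContext());
  builder.setInsertionPointAfter(assertOp); // safe at the end of a block
  builder.create<mlir::gpu::BarrierOp>(assertOp->getLoc());
}
```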

@anmyachev anmyachev requested a review from ptillet as a code owner November 11, 2024 14:53

peterbell10 (Contributor) left a comment

Please add a regression test

anmyachev (Contributor, Author)

> Please add a regression test

done

python/test/unit/test_debug.py (review comments resolved)
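
For reference, a rough sketch of what a regression test along these lines could look like; this is an approximation rather than the test that was actually added, and the kernel body and the TRITON_DEBUG toggle are assumptions:

```python
import torch
import triton
import triton.language as tl


def test_device_assert_barrier(monkeypatch, device):
    # Device-side asserts are only emitted in debug mode.
    monkeypatch.setenv("TRITON_DEBUG", "1")

    @triton.jit
    def _kernel(in_ptr):
        offsets = tl.arange(0, 8)
        values = tl.load(in_ptr + offsets)
        # The assert is the last op in the kernel body, the shape that
        # previously produced a bad barrier insertion point.
        tl.device_assert(values < 1, "values must be < 1")

    x = torch.zeros(8, dtype=torch.int32, device=device)
    _kernel[(1,)](x)
    getattr(torch, device).synchronize()
```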
@peterbell10 peterbell10 enabled auto-merge (squash) November 12, 2024 17:40
@peterbell10 peterbell10 merged commit 3e359b3 into triton-lang:main Nov 12, 2024
7 checks passed
@anmyachev anmyachev deleted the assert-op branch November 12, 2024 18:08
@@ -34,6 +33,20 @@ def _kernel(COND: tl.constexpr):
getattr(torch, device).synchronize()


def test_device_assert_barrier(monkeypatch, device):
ThomasRaoux (Collaborator)

In general, for such transformation bugs it is better to write lit tests that specifically exercise the case that caused the bug.

anmyachev (Contributor, Author)

@ThomasRaoux Thanks for the advice!

Should I try writing a lit test now?

ThomasRaoux (Collaborator)

It would be nice to replace this with a lit test if you have a chance.

anmyachev (Contributor, Author)

It seems that this issue is somehow related to multithreading. I have a lit test (unfortunately only for the XPU backend), and if I also specify --mlir-disable-threading, the issue goes away. It also seems that this is still an LLVM problem. Could you give some advice on how to rewrite the reproducer so I can file an issue for them? Or are there ideas on what else can be done about this?

Lit test
// RUN: triton-opt %s --convert-triton-intel-gpu-to-llvm

// -----

#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
// CHECK: llvm.call @__assertfail
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 32 : i32} {
  tt.func public @_kernel_qwerty(%arg0: !tt.ptr<i32>) {
    %cst = arith.constant dense<1> : tensor<8xi32, #blocked>
    %0 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32, #blocked>
    %1 = tt.splat %arg0 : !tt.ptr<i32> -> tensor<8x!tt.ptr<i32>, #blocked>
    %2 = tt.addptr %1, %0 : tensor<8x!tt.ptr<i32>, #blocked>, tensor<8xi32, #blocked>
    %3 = tt.load %2 : tensor<8x!tt.ptr<i32>, #blocked>
    %4 = arith.cmpi slt, %3, %cst : tensor<8xi32, #blocked>
    tt.assert %4, "" : tensor<8xi1, #blocked>
    tt.return
  }
}

Stack trace
RUN: at line 1: .../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt .../intel-xpu-backend-for-triton/test/Conversion/intel/tritonintelgpu_to_llvm.mlir --convert-triton-intel-gpu-to-llvm
+ .../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt .../intel-xpu-backend-for-triton/test/Conversion/intel/tritonintelgpu_to_llvm.mlir --convert-triton-intel-gpu-to-llvm
triton-opt: /home/runner/work/triton/triton/llvm-project/llvm/include/llvm/ADT/ilist_iterator.h:168: llvm::ilist_iterator::reference llvm::ilist_iterator<llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, false, false>::operator*() const [OptionsT = llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, IsReverse = false, IsConst = false]: Assertion `!NodePtr->isKnownSentinel()' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
 #0 0x000055d5cf9bf447 llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x312f447)
 #1 0x000055d5cf9bcf6e llvm::sys::RunSignalHandlers() (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x312cf6e)
 #2 0x000055d5cf9bfaff SignalHandler(int) Signals.cpp:0:0
 #3 0x00007f25ba794520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #4 0x00007f25ba7e89fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #5 0x00007f25ba7e89fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
 #6 0x00007f25ba7e89fc pthread_kill ./nptl/pthread_kill.c:89:10
 #7 0x00007f25ba794476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #8 0x00007f25ba77a7f3 abort ./stdlib/abort.c:81:7
 #9 0x00007f25ba77a71b _nl_load_domain ./intl/loadmsgcat.c:1177:9
#10 0x00007f25ba78be96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#11 0x000055d5cf90c137 mlir::Operation::updateOrderIfNecessary() (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x307c137)
#12 0x000055d5cf90bf7f mlir::Operation::isBeforeInBlock(mlir::Operation*) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x307bf7f)
#13 0x000055d5cf8e4222 mlir::DominanceInfo::properlyDominates(mlir::Value, mlir::Operation*) const (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x3054222)
#14 0x000055d5cf92b224 (anonymous namespace)::OperationVerifier::verifyOpAndDominance(mlir::Operation&) Verifier.cpp:0:0
#15 0x000055d5cf92c9be std::_Function_handler<void (), llvm::LogicalResult mlir::failableParallelForEach<mlir::Operation**, (anonymous namespace)::OperationVerifier::verifyOnExit(mlir::Operation&)::$_3>(mlir::MLIRContext*, mlir::Operation**, mlir::Operation**, (anonymous namespace)::OperationVerifier::verifyOnExit(mlir::Operation&)::$_3&&)::'lambda'()>::_M_invoke(std::_Any_data const&) Verifier.cpp:0:0
#16 0x000055d5cea285e8 std::_Function_handler<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> (), std::__future_base::_Task_setter<std::unique_ptr<std::__future_base::_Result<void>, std::__future_base::_Result_base::_Deleter>, std::thread::_Invoker<std::tuple<std::function<void ()>>>, void>>::_M_invoke(std::_Any_data const&) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x21985e8)
#17 0x000055d5cea28547 std::__future_base::_State_baseV2::_M_do_set(std::function<std::unique_ptr<std::__future_base::_Result_base, std::__future_base::_Result_base::_Deleter> ()>*, bool*) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x2198547)
#18 0x00007f25ba7ebee8 __pthread_once_slow ./nptl/pthread_once.c:118:7
#19 0x000055d5cea288fb std::__future_base::_Deferred_state<std::thread::_Invoker<std::tuple<std::function<void ()>>>, void>::_M_complete_async() (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x21988fb)
#20 0x000055d5cea289a9 void std::__invoke_impl<void, std::shared_future<void> llvm::ThreadPoolInterface::asyncImpl<void>(std::function<void ()>, llvm::ThreadPoolTaskGroup*)::'lambda'()&>(std::__invoke_other, std::shared_future<void> llvm::ThreadPoolInterface::asyncImpl<void>(std::function<void ()>, llvm::ThreadPoolTaskGroup*)::'lambda'()&) .../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x21989a9)
#21 0x000055d5cf99c1db llvm::StdThreadPool::processTasks(llvm::ThreadPoolTaskGroup*) (.../intel-xpu-backend-for-triton/python/build/cmake.linux-x86_64-cpython-3.10/bin/triton-opt+0x310c1db)
#22 0x000055d5cf99d4d7 void* llvm::thread::ThreadProxy<std::tuple<llvm::StdThreadPool::grow(int)::$_0>>(void*) ThreadPool.cpp:0:0
#23 0x00007f25ba7e6ac3 start_thread ./nptl/pthread_create.c:442:8
#24 0x00007f25ba878850 ./misc/../sysdeps/unix/sysv/linux/x86_64/clone3.S:83:0

ThomasRaoux (Collaborator)

I'm guessing that the fact it goes away without multi-threading just means there is some memory corruption, and it passes by luck otherwise. Is the barrier inserted at the right spot in this case? Checking that should be enough.

If you think it is an issue in MLIR core you can file an issue in llvm github: https://github.com/llvm/llvm-project
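
Concretely, checking that the barrier lands right after the assert could be done in a lit test with ordered CHECK lines, sketched below; the barrier op name depends on the target lowering, and `nvvm.barrier0` is an assumption for the NVIDIA path:

```mlir
// CHECK: llvm.call @__assertfail
// CHECK: nvvm.barrier0
```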

peterbell10 pushed a commit that referenced this pull request Nov 13, 2024
…d code (#5135)

I discovered it while I was looking at
#5114 (comment).

Looks like the `bar_id` attribute is no longer used. The attribute setting was removed in dd2a323.

Signed-off-by: Anatoly Myachev <anatoly.myachev@intel.com>
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024
Luosuu pushed a commit to Luosuu/triton that referenced this pull request Nov 13, 2024