[SYCL] Support both opaque and no-opaque modes for matrix tests (#7227)
This PR is intended to fix matrix test failures caused by the upcoming
pulldown: #7186. The RUN/CHECK pattern applied across the tests is sketched after the change summary below.
againull authored Nov 2, 2022
1 parent 19e9941 commit bf06295
Showing 8 changed files with 214 additions and 8 deletions.
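The file shown below follows a single pattern, which the commit title suggests is applied across the matrix tests: the original RUN line is replaced by a pair, one passing -Xclang -no-opaque-pointers and keeping the default CHECK prefix, and one passing -Xclang -opaque-pointers and verifying the IR with a separate CHECK-OPAQUE prefix. A minimal sketch of that layout follows (hypothetical file, not part of this commit; the CUDA target, GPU arch, and matrix-version flags from the real RUN lines are omitted for brevity):

// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -S -Xclang -emit-llvm %s -o -| FileCheck %s
// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE
//
// Directives with the default CHECK prefix describe the typed-pointer IR from
// the first RUN line; directives with the CHECK-OPAQUE prefix describe the
// opaque-pointer IR from the second RUN line. The kernel code underneath is
// shared by both runs.

#include <sycl/sycl.hpp>
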
33 changes: 32 additions & 1 deletion sycl/test/check_device_code/matrix/matrix-nvptx-bf16-test.cpp
@@ -1,6 +1,7 @@
// REQUIRES: cuda

// RUN: %clangxx -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
// RUN: %clangxx -Xclang -no-opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s
// RUN: %clangxx -Xclang -opaque-pointers -fsycl-device-only -fsycl-targets=nvptx64-nvidia-cuda -Xsycl-target-backend --cuda-gpu-arch=sm_80 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=3 -S -Xclang -emit-llvm %s -o -| FileCheck %s --check-prefixes=CHECK-OPAQUE

#include <sycl/sycl.hpp>

@@ -42,14 +43,19 @@ int main() {
sub_b;

// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16)
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
});

@@ -71,14 +77,19 @@ int main() {
sub_b;

// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32 } @llvm.nvvm.wmma.m16n16k16.load.b.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %17, i32 %18, i32 %19, i32 %20, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %22, float %23, float %24, float %25, float %26, float %27, float %28, float %29, i32 16)
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m16n16k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
});

@@ -99,14 +110,19 @@ int main() {
sub_b;

// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
// CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
});

@@ -127,14 +143,19 @@ int main() {
sub_b;

// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.a.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
// CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m32n8k16.load.b.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %13, i32 %14, i32 %15, i32 %16, i32 %17, i32 %18, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m32n8k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m32n8k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
});

@@ -155,14 +176,19 @@ int main() {
sub_b;

// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
// CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.row.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.row.row.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.row.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
});

@@ -183,14 +209,19 @@ int main() {
sub_b;

// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.load.c.col.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)
joint_matrix_load(sg, sub_c, accC.get_pointer(), stride);
// CHECK: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.a.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_a, accA.get_pointer(), stride);
// CHECK: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0i32(i32* %call.ascast.i.i{{.*}}.i, i32 16)
// CHECK-OPAQUE: tail call { i32, i32, i32, i32, i32, i32, i32, i32 } @llvm.nvvm.wmma.m8n32k16.load.b.col.stride.bf16.p0(ptr %call.ascast.i.i{{.*}}.i, i32 16)
joint_matrix_load(sg, sub_b, accB.get_pointer(), stride);
// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 %11, i32 %12, i32 %15, i32 %16, i32 %17, i32 %18, i32 %19, i32 %20, i32 %21, i32 %22, float %1, float %2, float %3, float %4, float %5, float %6, float %7, float %8)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m8n32k16.mma.col.col.bf16(i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}})
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
// CHECK: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1f32(float addrspace(1)* %_arg_accD, float %24, float %25, float %26, float %27, float %28, float %29, float %30, float %31, i32 16)
// CHECK-OPAQUE: tail call void @llvm.nvvm.wmma.m8n32k16.store.d.col.stride.f32.p1(ptr addrspace(1) %_arg_accD, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, float {{.*}}, i32 16)
joint_matrix_store(sg, sub_c, accD.get_pointer(), stride);
});
});
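
Two prefixes are needed because opaque pointers change both the intrinsic mangling and the operand types: with typed pointers the NVVM WMMA intrinsics carry the pointee type in their suffix (.p1f32, .p0i32) and take typed pointer arguments, while in opaque-pointer mode the suffix keeps only the address space (.p1, .p0) and the argument is a plain ptr. The first load check of this test, copied from the hunk above, shows the difference side by side:

// CHECK: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1f32(float addrspace(1)* %_arg_accC, i32 16)
// CHECK-OPAQUE: tail call { float, float, float, float, float, float, float, float } @llvm.nvvm.wmma.m16n16k16.load.c.row.stride.f32.p1(ptr addrspace(1) %_arg_accC, i32 16)

The CHECK-OPAQUE patterns for the mma and store calls also use {{.*}} for the scalar operands instead of fixed SSA value numbers, which keeps them independent of how values happen to be numbered in the opaque-pointer output.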