Help with gemm! #738

Closed
willtebbutt opened this issue Apr 18, 2023 · 4 comments
@willtebbutt

I've been trying to write a custom Enzyme rule for LinearAlgebra.BLAS.gemm! and am struggling to get it working. Opening an issue seemed like a better option than a discussion on Slack or Discourse.

I'm not worrying about the mathematical details for now, just trying to get the thing to run.

This is my attempt thus far:

using Enzyme
using .EnzymeRules
using LinearAlgebra

function EnzymeRules.augmented_primal(
    config::ConfigWidth{1},
    ::Const{typeof(BLAS.gemm!)},
    ::Type{<:Const},
    transA::Const{<:AbstractChar},
    transB::Const{<:AbstractChar},
    alpha::Const,
    A::Duplicated{<:AbstractVecOrMat{T}},
    B::Duplicated{<:AbstractVecOrMat{T}},
    beta::Const,
    C::Duplicated{<:AbstractVecOrMat{T}},
) where {T<:Union{Float32, Float64}}
    println("in the forwards-pass")
    tape = (copy(A.val), copy(B.val), C.dval)
    BLAS.gemm!(transA.val, transB.val, alpha.val, A.val, B.val, beta.val, C.val)
    primal = needs_primal(config) ? C.val : nothing
    shadow = needs_shadow(config) ? C.dval : nothing
    @show needs_primal(config), needs_shadow(config)
    return AugmentedReturn(primal, shadow, tape)
end

function EnzymeRules.reverse(
    config::ConfigWidth{1},
    ::Const{typeof(BLAS.gemm!)},
    ::Type{<:Const},
    tape,
    transA::Const{<:AbstractChar},
    transB::Const{<:AbstractChar},
    alpha::Const,
    A::Duplicated{<:AbstractVecOrMat{T}},
    B::Duplicated{<:AbstractVecOrMat{T}},
    beta::Const,
    C::Duplicated{<:AbstractVecOrMat{T}},
) where {T<:Union{Float32, Float64}}
    println("In the reverse-pass")
    B.dval .= 1.0 # dummy implementation to see what happens
    return (nothing, nothing, nothing, nothing, nothing, nothing, nothing)
end
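
For reference, the mathematics that the dummy reverse body stands in for can be sketched in plain Julia, with no Enzyme involved. This is only a minimal sketch for the 'N', 'N' case, assuming real element types; `gemm_pullback!` is a hypothetical helper name, not part of Enzyme or LinearAlgebra. It assumes `A` and `B` hold the values they had when gemm! ran, which is why the tape above stores copies of them.

```julia
using LinearAlgebra

# Hypothetical helper, not part of Enzyme or LinearAlgebra.
# For C_new = alpha * A * B + beta * C_old (the 'N', 'N' case), given the
# cotangent of C_new in dC, accumulate cotangents into dA and dB, then
# overwrite dC with the cotangent of C_old.
function gemm_pullback!(dA, dB, dC, alpha, A, B, beta)
    mul!(dA, dC, B', alpha, true)  # dA += alpha * dC * B'
    mul!(dB, A', dC, alpha, true)  # dB += alpha * A' * dC
    dC .*= beta                    # dC = beta * dC
    return nothing
end

# Small example: the scalar loss is sum(W .* C_new), so the cotangent of C_new is W.
D = 3
A, B, C0, W = randn(D, D), randn(D, 2D), randn(D, 2D), randn(D, 2D)
alpha, beta = 2.0, 0.5
loss(A) = sum(W .* (alpha * A * B + beta * C0))

dA, dB, dC = zeros(D, D), zeros(D, 2D), copy(W)
gemm_pullback!(dA, dB, dC, alpha, A, B, beta)
```

Comparing an entry of `dA` against a finite difference of `loss` is a quick way to sanity-check a sketch like this before wiring it into a rule.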

When I attempt to compute a pullback:

D = 5;
A = Duplicated(randn(D, D), zeros(D, D));
B = Duplicated(randn(D, 2D), zeros(D, 2D));
C = Duplicated(zeros(D, 2D), zeros(D, 2D));
autodiff(Reverse, BLAS.gemm!, Const, 'N', 'N', true, A, B, false, C)

I get the following error:

┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YU9ZW/src/utils.jl:50
in the forwards-pass
┌ Error: ("Calling convention mismatch", [3 x {} addrspace(10)*] addrspace(11)*,   %88 = extractvalue { [3 x {} addrspace(10)*] } %87, 0, 1, define void @julia_reverse_2479([3 x {} addrspace(10)*] addrspace(11)* nocapture nofree noundef nonnull readnone align 8 dereferenceable(24) %0, [1 x i32] addrspace(11)* nocapture nofree noundef nonnull readnone align 4 dereferenceable(4) %1, [1 x i32] addrspace(11)* nocapture nofree noundef nonnull readnone align 4 dereferenceable(4) %2, [1 x i8] addrspace(11)* nocapture nofree noundef nonnull readnone align 1 dereferenceable(1) %3, [2 x {} addrspace(10)*] addrspace(11)* nocapture nofree noundef nonnull readnone align 8 dereferenceable(16) %4, [2 x {} addrspace(10)*] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(16) %5, [1 x i8] addrspace(11)* nocapture nofree noundef nonnull readnone align 1 dereferenceable(1) %6, [2 x {} addrspace(10)*] addrspace(11)* nocapture nofree noundef nonnull readnone align 8 dereferenceable(16) %7) #21 !dbg !423 {
│ top:%8 = call {}*** @julia.get_pgcstack()
│   call void @julia_println_2482() #21, !dbg !424%9 = getelementptr inbounds [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %5, i64 0, i64 1, !dbg !425
│   %10 = load atomic {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %9 unordered, align 8, !dbg !425, !tbaa !126, !invariant.load !29, !alias.scope !129, !noalias !130, !nonnull !29, !dereferenceable !247, !align !248
│   %11 = bitcast {} addrspace(10)* %10 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)*, !dbg !430
│   %12 = addrspacecast { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(10)* %11 to { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)*, !dbg !430
│   %13 = getelementptr inbounds { i8 addrspace(13)*, i64, i16, i16, i32 }, { i8 addrspace(13)*, i64, i16, i16, i32 } addrspace(11)* %12, i64 0, i32 1, !dbg !430
│   %14 = load i64, i64 addrspace(11)* %13, align 8, !dbg !430, !tbaa !126, !range !128, !invariant.load !29, !alias.scope !129, !noalias !130
│   %.not.not = icmp eq i64 %14, 0, !dbg !444
│   br i1 %.not.not, label %L33, label %L15.preheader, !dbg !436
│ 
│ L15.preheader:                                    ; preds = %top
│   %15 = bitcast {} addrspace(10)* %10 to double addrspace(13)* addrspace(10)*%16 = addrspacecast double addrspace(13)* addrspace(10)* %15 to double addrspace(13)* addrspace(11)*%17 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %16, align 16, !tbaa !126, !invariant.load !29, !alias.scope !453, !noalias !130, !nonnull !29
│   br label %L15, !dbg !456
│ 
│ L15:                                              ; preds = %L15.preheader, %L15
│   %value_phi3 = phi i64 [ %20, %L15 ], [ 1, %L15.preheader ]
│   %18 = add nsw i64 %value_phi3, -1, !dbg !457
│   %19 = getelementptr inbounds double, double addrspace(13)* %17, i64 %18, !dbg !457
│   store double 1.000000e+00, double addrspace(13)* %19, align 8, !dbg !457, !tbaa !460, !alias.scope !37, !noalias !462
│   %.not.not9 = icmp eq i64 %value_phi3, %14, !dbg !463
│   %20 = add nuw nsw i64 %value_phi3, 1, !dbg !465
│   br i1 %.not.not9, label %L33, label %L15, !dbg !456
│ 
│ L33:                                              ; preds = %L15, %top
│   ret void, !dbg !466
│ }
│ , Tuple{ConfigWidth{1, false, false, (false, false, false, false, false, false, false, false)}, Const{typeof(LinearAlgebra.BLAS.gemm!)}, Type{Const{Matrix{Float64}}}, Const{Char}, Const{Char}, Const{Bool}, Duplicated{Matrix{Float64}}, Duplicated{Matrix{Float64}}, Const{Bool}, Duplicated{Matrix{Float64}}}, Tuple{ConfigWidth{1, false, false, (false, false, false, false, false, false, false, false)}, Const{typeof(LinearAlgebra.BLAS.gemm!)}, Type{Const{Matrix{Float64}}}, Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}, Const{Char}, Const{Char}, Const{Bool}, Duplicated{Matrix{Float64}}, Duplicated{Matrix{Float64}}, Const{Bool}, Duplicated{Matrix{Float64}}}, ; Function Attrs: mustprogress willreturn
│ define internal noundef nonnull align 16 dereferenceable(40) void @diffejulia_gemm__1274mustwrap_inner.8(i32 %0, i32 %1, i8 %2, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %3, {} addrspace(10)* %"'", {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %4, {} addrspace(10)* %"'1", i8 %5, {} addrspace(10)* noundef nonnull returned align 16 dereferenceable(40) %6, {} addrspace(10)* %"'2") local_unnamed_addr #20 {
│ entry:%7 = call {}*** @julia.get_pgcstack()
│   %8 = call {}*** @julia.get_pgcstack()
│   %9 = call {}*** @julia.get_pgcstack()
│   %10 = call {}*** @julia.get_pgcstack()
│   %11 = call {}*** @julia.get_pgcstack()
│   %12 = call {}*** @julia.get_pgcstack()
│   %13 = call {}*** @julia.get_pgcstack()
│   %14 = call {}*** @julia.get_pgcstack()
│   %15 = call {}*** @julia.get_pgcstack()
│   %16 = call {}*** @julia.get_pgcstack()
│   %17 = call {}*** @julia.get_pgcstack()
│   %18 = call {}*** @julia.get_pgcstack()
│   %19 = call {}*** @julia.get_pgcstack()
(needs_primal(config), needs_shadow(config)) = (false, false)
│   %20 = call {}*** @julia.get_pgcstack()
│   %21 = bitcast {}*** %20 to {}**%22 = getelementptr inbounds {}*, {}** %21, i64 -13%23 = getelementptr inbounds {}*, {}** %22, i64 15%24 = bitcast {}** %23 to i8**%25 = load i8*, i8** %24, align 8%26 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %22, i64 4, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710854928 to {}*) to {} addrspace(10)*))
│   %27 = bitcast {} addrspace(10)* %26 to [1 x i32] addrspace(10)*%28 = addrspacecast [1 x i32] addrspace(10)* %27 to [1 x i32] addrspace(11)*%29 = getelementptr [1 x i32], [1 x i32] addrspace(11)* %28, i64 0, i32 0
│   store i32 %0, i32 addrspace(11)* %29, align 4%30 = bitcast {}*** %19 to {}**%31 = getelementptr inbounds {}*, {}** %30, i64 -13%32 = getelementptr inbounds {}*, {}** %31, i64 15%33 = bitcast {}** %32 to i8**%34 = load i8*, i8** %33, align 8%35 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %31, i64 4, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710854928 to {}*) to {} addrspace(10)*))
│   %36 = bitcast {} addrspace(10)* %35 to [1 x i32] addrspace(10)*%37 = addrspacecast [1 x i32] addrspace(10)* %36 to [1 x i32] addrspace(11)*%38 = getelementptr [1 x i32], [1 x i32] addrspace(11)* %37, i64 0, i32 0
│   store i32 %1, i32 addrspace(11)* %38, align 4%39 = bitcast {}*** %18 to {}**%40 = getelementptr inbounds {}*, {}** %39, i64 -13%41 = getelementptr inbounds {}*, {}** %40, i64 15%42 = bitcast {}** %41 to i8**%43 = load i8*, i8** %42, align 8%44 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %40, i64 1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710318800 to {}*) to {} addrspace(10)*))
│   %45 = bitcast {} addrspace(10)* %44 to [1 x i8] addrspace(10)*%46 = addrspacecast [1 x i8] addrspace(10)* %45 to [1 x i8] addrspace(11)*%47 = getelementptr [1 x i8], [1 x i8] addrspace(11)* %46, i64 0, i32 0
│   store i8 %2, i8 addrspace(11)* %47, align 1%48 = bitcast {}*** %17 to {}**%49 = getelementptr inbounds {}*, {}** %48, i64 -13%50 = getelementptr inbounds {}*, {}** %49, i64 15%51 = bitcast {}** %50 to i8**%52 = load i8*, i8** %51, align 8%53 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %49, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6179224592 to {}*) to {} addrspace(10)*))
│   %54 = bitcast {} addrspace(10)* %53 to [2 x {} addrspace(10)*] addrspace(10)*%55 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %54 to [2 x {} addrspace(10)*] addrspace(11)*%56 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %55, i64 0, i32 0
│   store {} addrspace(10)* %3, {} addrspace(10)* addrspace(11)* %56, align 8%57 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %55, i64 0, i32 1
│   store {} addrspace(10)* %"'", {} addrspace(10)* addrspace(11)* %57, align 8
│   call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %53, {} addrspace(10)* %3, {} addrspace(10)* %"'")
│   %58 = bitcast {}*** %16 to {}**%59 = getelementptr inbounds {}*, {}** %58, i64 -13%60 = getelementptr inbounds {}*, {}** %59, i64 15%61 = bitcast {}** %60 to i8**%62 = load i8*, i8** %61, align 8%63 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %59, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6179224592 to {}*) to {} addrspace(10)*))
│   %64 = bitcast {} addrspace(10)* %63 to [2 x {} addrspace(10)*] addrspace(10)*%65 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %64 to [2 x {} addrspace(10)*] addrspace(11)*%66 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %65, i64 0, i32 0
│   store {} addrspace(10)* %4, {} addrspace(10)* addrspace(11)* %66, align 8%67 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %65, i64 0, i32 1
│   store {} addrspace(10)* %"'1", {} addrspace(10)* addrspace(11)* %67, align 8
│   call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %63, {} addrspace(10)* %4, {} addrspace(10)* %"'1")
│   %68 = bitcast {}*** %15 to {}**%69 = getelementptr inbounds {}*, {}** %68, i64 -13%70 = getelementptr inbounds {}*, {}** %69, i64 15%71 = bitcast {}** %70 to i8**%72 = load i8*, i8** %71, align 8%73 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %69, i64 1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710318800 to {}*) to {} addrspace(10)*))
│   %74 = bitcast {} addrspace(10)* %73 to [1 x i8] addrspace(10)*%75 = addrspacecast [1 x i8] addrspace(10)* %74 to [1 x i8] addrspace(11)*%76 = getelementptr [1 x i8], [1 x i8] addrspace(11)* %75, i64 0, i32 0
│   store i8 %5, i8 addrspace(11)* %76, align 1%77 = bitcast {}*** %14 to {}**%78 = getelementptr inbounds {}*, {}** %77, i64 -13%79 = getelementptr inbounds {}*, {}** %78, i64 15%80 = bitcast {}** %79 to i8**%81 = load i8*, i8** %80, align 8%82 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %78, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6179224592 to {}*) to {} addrspace(10)*))
│   %83 = bitcast {} addrspace(10)* %82 to [2 x {} addrspace(10)*] addrspace(10)*%84 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %83 to [2 x {} addrspace(10)*] addrspace(11)*%85 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %84, i64 0, i32 0
│   store {} addrspace(10)* %6, {} addrspace(10)* addrspace(11)* %85, align 8%86 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %84, i64 0, i32 1
│   store {} addrspace(10)* %"'2", {} addrspace(10)* addrspace(11)* %86, align 8
│   call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %82, {} addrspace(10)* %6, {} addrspace(10)* %"'2")
│   call void @julia_augmented_primal_2343({ [3 x {} addrspace(10)*] }* sret({ [3 x {} addrspace(10)*] }) %89, [1 x i32] addrspace(11)* %28, [1 x i32] addrspace(11)* %37, [1 x i8] addrspace(11)* %46, [2 x {} addrspace(10)*] addrspace(11)* %55, [2 x {} addrspace(10)*] addrspace(11)* %65, [1 x i8] addrspace(11)* %75, [2 x {} addrspace(10)*] addrspace(11)* %84)
│   %87 = load { [3 x {} addrspace(10)*] }, { [3 x {} addrspace(10)*] }* %89, align 8%88 = extractvalue { [3 x {} addrspace(10)*] } %87, 0
│   br label %invertentry
│ 
│ allocsForInversion:                               ; No predecessors!
│   %89 = alloca { [3 x {} addrspace(10)*] }, align 8
│ 
│ invertentry:                                      ; preds = %entry
│   %90 = bitcast {}*** %13 to {}**%91 = getelementptr inbounds {}*, {}** %90, i64 -13%92 = getelementptr inbounds {}*, {}** %91, i64 15%93 = bitcast {}** %92 to i8**%94 = load i8*, i8** %93, align 8%95 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %91, i64 4, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710854928 to {}*) to {} addrspace(10)*))
│   %96 = bitcast {} addrspace(10)* %95 to [1 x i32] addrspace(10)*
│   %97 = addrspacecast [1 x i32] addrspace(10)* %96 to [1 x i32] addrspace(11)*%98 = getelementptr [1 x i32], [1 x i32] addrspace(11)* %97, i64 0, i32 0
│   store i32 %0, i32 addrspace(11)* %98, align 4%99 = bitcast {}*** %12 to {}**%100 = getelementptr inbounds {}*, {}** %99, i64 -13%101 = getelementptr inbounds {}*, {}** %100, i64 15%102 = bitcast {}** %101 to i8**%103 = load i8*, i8** %102, align 8%104 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %100, i64 4, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710854928 to {}*) to {} addrspace(10)*))
│   %105 = bitcast {} addrspace(10)* %104 to [1 x i32] addrspace(10)*%106 = addrspacecast [1 x i32] addrspace(10)* %105 to [1 x i32] addrspace(11)*%107 = getelementptr [1 x i32], [1 x i32] addrspace(11)* %106, i64 0, i32 0
│   store i32 %1, i32 addrspace(11)* %107, align 4%108 = bitcast {}*** %11 to {}**%109 = getelementptr inbounds {}*, {}** %108, i64 -13%110 = getelementptr inbounds {}*, {}** %109, i64 15%111 = bitcast {}** %110 to i8**%112 = load i8*, i8** %111, align 8%113 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %109, i64 1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710318800 to {}*) to {} addrspace(10)*))
│   %114 = bitcast {} addrspace(10)* %113 to [1 x i8] addrspace(10)*%115 = addrspacecast [1 x i8] addrspace(10)* %114 to [1 x i8] addrspace(11)*%116 = getelementptr [1 x i8], [1 x i8] addrspace(11)* %115, i64 0, i32 0
│   store i8 %2, i8 addrspace(11)* %116, align 1%117 = bitcast {}*** %10 to {}**%118 = getelementptr inbounds {}*, {}** %117, i64 -13%119 = getelementptr inbounds {}*, {}** %118, i64 15%120 = bitcast {}** %119 to i8**%121 = load i8*, i8** %120, align 8%122 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %118, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6179224592 to {}*) to {} addrspace(10)*))
│   %123 = bitcast {} addrspace(10)* %122 to [2 x {} addrspace(10)*] addrspace(10)*%124 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %123 to [2 x {} addrspace(10)*] addrspace(11)*%125 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %124, i64 0, i32 0
│   store {} addrspace(10)* %3, {} addrspace(10)* addrspace(11)* %125, align 8%126 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %124, i64 0, i32 1
│   store {} addrspace(10)* %"'", {} addrspace(10)* addrspace(11)* %126, align 8
│   call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %122, {} addrspace(10)* %3, {} addrspace(10)* %"'")
│   %127 = bitcast {}*** %9 to {}**%128 = getelementptr inbounds {}*, {}** %127, i64 -13%129 = getelementptr inbounds {}*, {}** %128, i64 15%130 = bitcast {}** %129 to i8**%131 = load i8*, i8** %130, align 8%132 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %128, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6179224592 to {}*) to {} addrspace(10)*))
│   %133 = bitcast {} addrspace(10)* %132 to [2 x {} addrspace(10)*] addrspace(10)*%134 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %133 to [2 x {} addrspace(10)*] addrspace(11)*%135 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %134, i64 0, i32 0
│   store {} addrspace(10)* %4, {} addrspace(10)* addrspace(11)* %135, align 8%136 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %134, i64 0, i32 1
│   store {} addrspace(10)* %"'1", {} addrspace(10)* addrspace(11)* %136, align 8
│   call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %132, {} addrspace(10)* %4, {} addrspace(10)* %"'1")
│   %137 = bitcast {}*** %8 to {}**%138 = getelementptr inbounds {}*, {}** %137, i64 -13%139 = getelementptr inbounds {}*, {}** %138, i64 15%140 = bitcast {}** %139 to i8**%141 = load i8*, i8** %140, align 8%142 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %138, i64 1, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 5710318800 to {}*) to {} addrspace(10)*))
│   %143 = bitcast {} addrspace(10)* %142 to [1 x i8] addrspace(10)*%144 = addrspacecast [1 x i8] addrspace(10)* %143 to [1 x i8] addrspace(11)*%145 = getelementptr [1 x i8], [1 x i8] addrspace(11)* %144, i64 0, i32 0
│   store i8 %5, i8 addrspace(11)* %145, align 1%146 = bitcast {}*** %7 to {}**%147 = getelementptr inbounds {}*, {}** %146, i64 -13%148 = getelementptr inbounds {}*, {}** %147, i64 15%149 = bitcast {}** %148 to i8**%150 = load i8*, i8** %149, align 8%151 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %147, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 6179224592 to {}*) to {} addrspace(10)*))
│   %152 = bitcast {} addrspace(10)* %151 to [2 x {} addrspace(10)*] addrspace(10)*%153 = addrspacecast [2 x {} addrspace(10)*] addrspace(10)* %152 to [2 x {} addrspace(10)*] addrspace(11)*%154 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %153, i64 0, i32 0
│   store {} addrspace(10)* %6, {} addrspace(10)* addrspace(11)* %154, align 8%155 = getelementptr [2 x {} addrspace(10)*], [2 x {} addrspace(10)*] addrspace(11)* %153, i64 0, i32 1
│   store {} addrspace(10)* %"'2", {} addrspace(10)* addrspace(11)* %155, align 8
│   call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* %151, {} addrspace(10)* %6, {} addrspace(10)* %"'2")
│ }
│ )
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/YU9ZW/src/utils.jl:50
((nothing, nothing, nothing, nothing, nothing, nothing, nothing),)

and it doesn't look like the reverse implementation is being hit, because B.dval is still all zeros (the dummy reverse rule should have set every entry to 1.0):

julia> B.dval
5×10 Matrix{Float64}:
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0
 0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0

Any advice on what might be going on would be appreciated!

julia> versioninfo()
Julia Version 1.9.0-rc2
Commit 72aec423c2a (2023-04-01 10:41 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin21.4.0)
  CPU: 12 × Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-14.0.6 (ORCJIT, skylake)
  Threads: 6 on 12 virtual cores
Environment:
  JULIA_NUM_THREADS = 6

I have the latest commit on main checked out.

@wsmoses
Member

wsmoses commented Apr 18, 2023

That's an internal error in the rules dispatcher; will fix.

@willtebbutt
Author

Excellent -- thanks!

@wsmoses
Member

wsmoses commented Apr 19, 2023

Fixed by #743

@wsmoses wsmoses closed this as completed Apr 19, 2023
@willtebbutt
Author

Fantastic. Thanks!
