Skip to content

Commit

Permalink
Merge pull request #7 from grad0s/bugfix/double_unpacking
Browse files Browse the repository at this point in the history
ctx.saved_tensors should be unpacked only once
  • Loading branch information
tgale96 authored Apr 9, 2024
2 parents d9e2e64 + f74df1d commit e5c47f6
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions stk/backend/sputnik.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,9 @@ def forward(ctx,
@staticmethod
@custom_bwd
def backward(ctx, dy):
lhs = (ctx.shape,) + ctx.saved_tensors[:-1]
rhs = ctx.saved_tensors[-1]
saved_tensors = ctx.saved_tensors
lhs = (ctx.shape,) + saved_tensors[:-1]
rhs = saved_tensors[-1]
trans_a = ctx.transpose_a
trans_b = _is_transposed(rhs)

Expand Down Expand Up @@ -201,8 +202,9 @@ def forward(ctx,
@staticmethod
@custom_bwd
def backward(ctx, dy):
lhs = ctx.saved_tensors[0]
rhs = (ctx.shape,) + ctx.saved_tensors[1:]
saved_tensors = ctx.saved_tensors
lhs = saved_tensors[0]
rhs = (ctx.shape,) + saved_tensors[1:]
trans_a = _is_transposed(lhs)
trans_b = ctx.transpose_b

Expand Down Expand Up @@ -270,8 +272,9 @@ def forward(ctx,
@staticmethod
@custom_bwd
def backward(ctx, dy):
lhs, rhs = ctx.saved_tensors[:2]
dy = (ctx.shape, dy) + ctx.saved_tensors[2:]
saved_tensors = ctx.saved_tensors
lhs, rhs = saved_tensors[:2]
dy = (ctx.shape, dy) + saved_tensors[2:]
trans_a = _is_transposed(lhs)
trans_b = _is_transposed(rhs)

Expand Down

0 comments on commit e5c47f6

Please sign in to comment.