From b597aa7bbc45a3572bac40b0e850ea926af84a41 Mon Sep 17 00:00:00 2001
From: "Gal Cohen (galco)"
Date: Thu, 22 Aug 2024 16:03:22 +0300
Subject: [PATCH] fix: no need to dtype A in jamba (#32924)

Co-authored-by: Gal Cohen
---
 src/transformers/models/jamba/modeling_jamba.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index c6e8d425459fe0..230536a83a145c 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -630,7 +630,7 @@ def __init__(self, config: JambaConfig, layer_idx):
         # S4D real initialization. These are not discretized!
         # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
-        A = torch.arange(1, self.ssm_state_size + 1, dtype=torch.float32)[None, :]
+        A = torch.arange(1, self.ssm_state_size + 1)[None, :]
         A = A.expand(self.intermediate_size, -1).contiguous()
         self.A_log = nn.Parameter(torch.log(A))
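
A minimal sketch (not part of the patch) of why dropping the explicit dtype is harmless here: torch.arange with integer bounds yields an int64 tensor, and torch.log promotes integer input to the default float dtype, so A_log still ends up floating point. The sizes below are hypothetical stand-ins for the values that JambaConfig would provide.

```python
import torch

ssm_state_size = 16     # hypothetical; the real value comes from JambaConfig
intermediate_size = 64  # hypothetical; the real value comes from JambaConfig

# Same construction as the patched line, without dtype=torch.float32
A = torch.arange(1, ssm_state_size + 1)[None, :]   # dtype: torch.int64
A = A.expand(intermediate_size, -1).contiguous()

# torch.log promotes the integer tensor to the default float dtype
A_log = torch.log(A)

print(A.dtype, A_log.dtype)  # torch.int64 torch.float32
```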