diff --git a/.github/workflows/testing_ci.yml b/.github/workflows/testing_ci.yml
index 6ac06b14..a43660f4 100644
--- a/.github/workflows/testing_ci.yml
+++ b/.github/workflows/testing_ci.yml
@@ -42,7 +42,6 @@ jobs:
         run: |
             which python
             which pip
-            pip install --upgrade pip
             pip install numpy==1.24 torch==${{ matrix.pytorch-version }} -f https://download.pytorch.org/whl/cpu
             python -c "import torch; print('PyTorch:', torch.__version__)"

diff --git a/pypots/nn/modules/reformer/local_attention.py b/pypots/nn/modules/reformer/local_attention.py
index 86388b7c..37f1e029 100644
--- a/pypots/nn/modules/reformer/local_attention.py
+++ b/pypots/nn/modules/reformer/local_attention.py
@@ -43,10 +43,6 @@ def apply_rotary_pos_emb(q, k, freqs, scale=1):
     return q, k


-def exists(val):
-    return val is not None
-
-
 def default(value, d):
     return d if not exists(value) else value

@@ -186,7 +182,6 @@ def forward(
         ), "cannot perform window size extrapolation if xpos is not turned on"

         (
-            shape,
             autopad,
             pad_value,
             window_size,
@@ -195,7 +190,6 @@ def forward(
             look_forward,
             shared_qk,
         ) = (
-            q.shape,
             self.autopad,
             -1,
             default(window_size, self.window_size),
@@ -216,7 +210,7 @@ def forward(
                 lambda t: pad_to_multiple(t, self.window_size, dim=-2), (q, k, v)
             )

-        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
+        b, n, dim_head, device = *q.shape, q.device

         scale = default(self.scale, dim_head**-0.5)
diff --git a/pypots/nn/modules/reformer/lsh_attention.py b/pypots/nn/modules/reformer/lsh_attention.py
index 3f6f2980..40d82076 100644
--- a/pypots/nn/modules/reformer/lsh_attention.py
+++ b/pypots/nn/modules/reformer/lsh_attention.py
@@ -139,9 +139,9 @@ def merge_dims(ind_from, ind_to, tensor):

 def split_at_index(dim, index, t):
     pre_slices = (slice(None),) * dim
-    l = (*pre_slices, slice(None, index))
-    r = (*pre_slices, slice(index, None))
-    return t[l], t[r]
+    l_ = (*pre_slices, slice(None, index))
+    r_ = (*pre_slices, slice(index, None))
+    return t[l_], t[r_]


 class FullQKAttention(nn.Module):
@@ -608,10 +608,9 @@ def forward(
         **kwargs,
     ):
         device, dtype = x.device, x.dtype
-        b, t, e, h, dh, m, l_h = (
+        b, t, e, h, m, l_h = (
             *x.shape,
             self.heads,
-            self.dim_head,
             self.num_mem_kv,
             self.n_local_attn_heads,
         )
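For context, the only substantive helper touched in `lsh_attention.py` is `split_at_index`, which cuts a tensor into two pieces along one dimension; the patch merely renames its local slice tuples to `l_`/`r_` to avoid the ambiguous single-letter `l`. A minimal standalone sketch of the post-rename behavior (the `torch.randn` input shape is only illustrative, not taken from the PR):

```python
import torch


def split_at_index(dim, index, t):
    # Leave all leading dimensions untouched, then cut dimension `dim` at `index`.
    pre_slices = (slice(None),) * dim
    l_ = (*pre_slices, slice(None, index))
    r_ = (*pre_slices, slice(index, None))
    return t[l_], t[r_]


x = torch.randn(2, 10, 4)
left, right = split_at_index(1, 3, x)
print(left.shape, right.shape)  # torch.Size([2, 3, 4]) torch.Size([2, 7, 4])
```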