diff --git a/examples/microViT.py b/examples/microViT.py
index 0d3b890daf..5f4a3538f6 100644
--- a/examples/microViT.py
+++ b/examples/microViT.py
@@ -148,7 +148,7 @@ def forward(self, x):
         x = self.patch_emb(x)
 
         # flatten patches into sequence
-        x = x.flatten(2, 3).transpose(1, 2)  # B HW C
+        x = x.flatten(2, 3).transpose(1, 2).contiguous()  # B HW C
 
         if self.hparams.classifier == Classifier.TOKEN:
             # prepend classification token
diff --git a/xformers/components/residual.py b/xformers/components/residual.py
index 3a7e4ac43c..5f639700d1 100644
--- a/xformers/components/residual.py
+++ b/xformers/components/residual.py
@@ -11,7 +11,7 @@
 import torch.nn as nn
 
 # NOTE: The Triton layernorm can be activated/deactivated from here
-_is_triton_available = False  # torch.cuda.is_available()
+_is_triton_available = torch.cuda.is_available()
 
 if _is_triton_available:
     try:
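
A note on the first hunk: `transpose` returns a strided view rather than moving data, so later ops that require contiguous memory (e.g. `view`, or fused kernels that assume a dense layout) can fail or trigger implicit copies. A minimal standalone sketch of what `.contiguous()` changes, using made-up shapes rather than anything from microViT:

```python
import torch

B, C, H, W = 2, 16, 4, 4
x = torch.randn(B, C, H, W)

seq = x.flatten(2, 3).transpose(1, 2)  # shape (B, H*W, C)
print(seq.is_contiguous())             # False: transpose only swaps strides

# seq.view(B, -1) would raise a RuntimeError here, because .view()
# needs a contiguous layout. .contiguous() materializes a dense copy:
seq = seq.contiguous()
print(seq.is_contiguous())             # True
print(seq.view(B, -1).shape)           # now works: (B, H*W*C)
```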
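
On the second hunk: the change restores the runtime gate, so the Triton layernorm path is only considered when CUDA is actually available, with the `try:` below it catching the case where Triton itself is not installed. A sketch of that guard pattern, assuming a plain `import triton` as a stand-in for whatever the real `try` block imports:

```python
import torch

# Only consider the Triton path when a CUDA device is present.
_is_triton_available = torch.cuda.is_available()

if _is_triton_available:
    try:
        # hypothetical import standing in for the fused Triton layernorm
        import triton  # noqa: F401
    except ImportError:
        # Triton not installed: fall back to the plain PyTorch layernorm
        _is_triton_available = False
```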