diff --git a/detectron2/modeling/meta_arch/rcnn.py b/detectron2/modeling/meta_arch/rcnn.py index b01c362baa..e5f66d1a3d 100644 --- a/detectron2/modeling/meta_arch/rcnn.py +++ b/detectron2/modeling/meta_arch/rcnn.py @@ -64,8 +64,8 @@ def __init__( if vis_period > 0: assert input_format is not None, "input_format is required for visualization!" - self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1)) - self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1)) + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) assert ( self.pixel_mean.shape == self.pixel_std.shape ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!" diff --git a/detectron2/modeling/meta_arch/retinanet.py b/detectron2/modeling/meta_arch/retinanet.py index 74150ff56e..654f04cd5b 100644 --- a/detectron2/modeling/meta_arch/retinanet.py +++ b/detectron2/modeling/meta_arch/retinanet.py @@ -137,8 +137,8 @@ def __init__( self.vis_period = vis_period self.input_format = input_format - self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1)) - self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1)) + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) """ In Detectron1, loss is normalized by number of foreground samples in the batch. diff --git a/detectron2/modeling/meta_arch/semantic_seg.py b/detectron2/modeling/meta_arch/semantic_seg.py index 2957ae1c08..025344952e 100644 --- a/detectron2/modeling/meta_arch/semantic_seg.py +++ b/detectron2/modeling/meta_arch/semantic_seg.py @@ -35,8 +35,8 @@ def __init__(self, cfg): super().__init__() self.backbone = build_backbone(cfg) self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape()) - self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1)) - self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1)) + self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False) @property def device(self): diff --git a/projects/Panoptic-DeepLab/panoptic_deeplab/panoptic_seg.py b/projects/Panoptic-DeepLab/panoptic_deeplab/panoptic_seg.py index 29dcb27ad1..5062664e54 100644 --- a/projects/Panoptic-DeepLab/panoptic_deeplab/panoptic_seg.py +++ b/projects/Panoptic-DeepLab/panoptic_deeplab/panoptic_seg.py @@ -44,8 +44,8 @@ def __init__(self, cfg): self.backbone = build_backbone(cfg) self.sem_seg_head = build_sem_seg_head(cfg, self.backbone.output_shape()) self.ins_embed_head = build_ins_embed_branch(cfg, self.backbone.output_shape()) - self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1)) - self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1)) + self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False) self.meta = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]) self.stuff_area = cfg.MODEL.PANOPTIC_DEEPLAB.STUFF_AREA self.threshold = cfg.MODEL.PANOPTIC_DEEPLAB.CENTER_THRESHOLD diff --git a/projects/TensorMask/tensormask/arch.py b/projects/TensorMask/tensormask/arch.py index 2746b7eb4e..6f5d815656 100644 --- a/projects/TensorMask/tensormask/arch.py +++ b/projects/TensorMask/tensormask/arch.py @@ -348,8 +348,8 @@ def __init__(self, cfg): ) # box transform self.box2box_transform = Box2BoxTransform(weights=cfg.MODEL.TENSOR_MASK.BBOX_REG_WEIGHTS) - self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1)) - self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1)) + self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(cfg.MODEL.PIXEL_STD).view(-1, 1, 1), False) @property def device(self):