diff --git a/foolbox/attacks/gradient_descent_base.py b/foolbox/attacks/gradient_descent_base.py
index 2af79418..49fcc0f0 100644
--- a/foolbox/attacks/gradient_descent_base.py
+++ b/foolbox/attacks/gradient_descent_base.py
@@ -127,7 +127,7 @@ def run(
         verify_input_bounds(x0, model)
 
         # perform a gradient ascent (targeted attack) or descent (untargeted attack)
-        if isinstance(criterion_, Misclassification):
+        if hasattr(criterion_, "labels"):
             gradient_step_sign = 1.0
             classes = criterion_.labels
         elif hasattr(criterion_, "target_classes"):
diff --git a/foolbox/criteria.py b/foolbox/criteria.py
index 71564e80..d50665aa 100644
--- a/foolbox/criteria.py
+++ b/foolbox/criteria.py
@@ -77,6 +77,7 @@ def __and__(self, other: "Criterion") -> "Criterion":
 
 
 class _And(Criterion):
+
     def __init__(self, a: Criterion, b: Criterion):
         super().__init__()
         self.a = a
@@ -141,3 +142,60 @@ def __call__(self, perturbed: T, outputs: T) -> T:
         assert classes.shape == self.target_classes.shape
         is_adv = classes == self.target_classes
         return restore_type(is_adv)
+
+
+class ConfidentClassification(Criterion):
+    """Considers those perturbed inputs adversarial whose predicted class
+    has a probability of at least ``p``.
+
+    Args:
+        p: The classification is deemed confident if the probability of the
+            predicted class is at least ``p``. Must be between 0 and 1.
+    """
+
+    def __init__(self, p: float):
+        super().__init__()
+        assert 0 <= p <= 1
+        self.p = p
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({self.p!r})"
+
+    def __call__(self, perturbed: T, outputs: T) -> T:
+        outputs_, restore_type = ep.astensor_(outputs)
+        del perturbed, outputs
+
+        is_conf = ep.softmax(outputs_).max(axis=-1) >= self.p
+        return restore_type(is_conf)
+
+
+class ConfidentMisclassification(_And):
+    """Considers those perturbed inputs adversarial whose predicted class
+    differs from the label and has a probability of at least ``p``.
+
+    Args:
+        labels: Tensor with labels of the unperturbed inputs ``(batch,)``.
+        p: The classification is deemed confident if the probability of the
+            predicted class is at least ``p``. Must be between 0 and 1.
+    """
+
+    def __init__(self, labels: Any, p: float):
+        super().__init__(Misclassification(labels), ConfidentClassification(p))
+        self.labels = self.a.labels
+
+
+class ConfidentTargetedMisclassification(_And):
+    """Considers those perturbed inputs adversarial whose predicted class
+    matches the target class with a probability of at least ``p``.
+
+    Args:
+        target_classes: Tensor with target classes ``(batch,)``.
+        p: The classification is deemed confident if the probability of the
+            predicted class is at least ``p``. Must be between 0 and 1.
+    """
+
+    def __init__(self, target_classes: Any, p: float):
+        super().__init__(
+            TargetedMisclassification(target_classes), ConfidentClassification(p)
+        )
+        self.target_classes = self.a.target_classes
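
For context: the `isinstance` check in `gradient_descent_base.py` is relaxed to duck typing so that the gradient attacks also accept the new `_And`-based criteria, which expose `labels` / `target_classes` without subclassing `Misclassification`. A minimal usage sketch of the new criterion follows; the toy `net`, `images`, and `labels` are placeholder stand-ins, not part of this change:

```python
import torch
import foolbox as fb
from foolbox.criteria import ConfidentMisclassification

# toy stand-in for a real pretrained classifier
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10)).eval()
model = fb.PyTorchModel(net, bounds=(0, 1))

images = torch.rand(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))

# an input only counts as adversarial if it is misclassified AND the
# predicted (wrong) class has a softmax probability of at least 0.9
criterion = ConfidentMisclassification(labels, p=0.9)

attack = fb.attacks.LinfPGD()
raw, clipped, is_adv = attack(model, images, criterion, epsilons=0.03)
```

Because the wrapper exposes `labels`, PGD performs the usual untargeted gradient ascent, while success is judged by the stricter combined criterion.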