Speech commands explanatory model #1869

Merged
armory/metrics/poisoning.py: 74 changes (44 additions, 30 deletions)
@@ -11,11 +11,12 @@
# An armory user may request one of these models under 'adhoc'/'explanatory_model'
EXPLANATORY_MODEL_CONFIGS = explanatory_model_configs = {
"speech_commands_explanatory_model": {
"module": "armory.baseline_models.tf_graph.audio_resnet50",
"name": "get_unwrapped_model",
"data_modality": "audio",
"activation_layer": "avg_pool",
"data_modality": "audio",
"model_framework": "tensorflow",
"module": "armory.baseline_models.tf_graph.audio_resnet50",
"name": "get_unwrapped_model",
"preprocess_kwargs": {},
"weights_file": "speech_commands_explanatory_model_resnet50_bean.h5",
},
"cifar10_explanatory_model": {
@@ -26,14 +27,14 @@
},
"module": "armory.baseline_models.pytorch.resnet18_bean_regularization",
"name": "get_model",
"resize_image": False,
"preprocess_kwargs": {},
"weights_file": "cifar10_explanatory_model_resnet18_bean.pt",
},
"gtsrb_explanatory_model": {
"model_kwargs": {},
"module": "armory.baseline_models.pytorch.micronnet_gtsrb_bean_regularization",
"name": "get_model",
"resize_image": False,
"preprocess_kwargs": {},
"weights_file": "gtsrb_explanatory_model_micronnet_bean.pt",
},
"resisc10_explanatory_model": {
@@ -44,6 +45,9 @@
},
"module": "armory.baseline_models.pytorch.resnet18_bean_regularization",
"name": "get_model",
"preprocess_kwargs": {
"resize_image": True,
},
"weights_file": "resisc10_explanatory_model_resnet18_bean.pt",
},
}
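
The comment at the top of this dict says an armory user may request one of these models under 'adhoc'/'explanatory_model'. A minimal sketch of what that request might look like in an experiment config, written as a Python dict; everything beyond the 'adhoc'/'explanatory_model' key path is an assumption rather than something taken from this PR:

# Hypothetical fragment of an armory experiment config. Only the
# 'adhoc'/'explanatory_model' key path is attested by the comment in this file;
# whether the value is a registry key (as sketched here) or a full config dict
# is not shown in this diff.
adhoc_fragment = {
    "adhoc": {
        "explanatory_model": "speech_commands_explanatory_model",
    },
}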
@@ -58,27 +62,41 @@ def __init__(
data_modality="image",
model_framework="pytorch",
activation_layer=None,
resize_image=True,
size=(224, 224),
resample=Image.BILINEAR,
device=DEVICE,
preprocess_kwargs={},
):
"""
explanatory_model: A callable pytorch or tensorflow model used to produce
activations for silhouette analysis
data_modality: one of "image" or "audio" (more options to be added as needed)
model_framework: "pytorch" or "tensorflow"
activation_layer: name of the layer of the model from which to draw activations
(currently only for tensorflow models).
If None, uses the final output layer.
preprocess_kwargs: keyword arguments for the preprocessing function
"""
if not callable(explanatory_model):
raise ValueError(f"explanatory_model {explanatory_model} is not callable")
if model_framework not in ("pytorch", "tensorflow"):
raise ValueError(
f"model_framework should be 'pytorch' or 'tensorflow', not '{model_framework}'"
)
self.explanatory_model = explanatory_model
self.data_modality = data_modality
self.model_framework = model_framework
self.activation_layer = activation_layer
self.resize_image = bool(resize_image)
self.size = size
self.resample = resample
self.device = device

if self.model_framework == "tensorflow" and self.activation_layer is not None:
self.explanatory_model = tf.keras.Model(
explanatory_model.layers[0].input,
explanatory_model.get_layer(self.activation_layer).output,
)
self.preprocess_kwargs = preprocess_kwargs

if self.activation_layer is not None:
if self.model_framework == "tensorflow":
# Set explanatory_model to return activations from internal layer
self.explanatory_model = tf.keras.Model(
explanatory_model.layers[0].input,
explanatory_model.get_layer(self.activation_layer).output,
)
else:
raise ValueError(
"Currently, 'activation_layer' can only be specified for a tensorflow model, not pytorch."
)

@classmethod
def from_config(cls, model_config, **kwargs):
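
Before the from_config hunks below, a hedged illustration of the activation_layer branch added to __init__ above. The standalone Keras ResNet50 here is only a stand-in; in armory the model would come from the configured module and weights file (for speech commands, get_unwrapped_model with 'avg_pool' as the activation layer):

import tensorflow as tf

# Stand-in model; armory would load the real one from the configured module/weights.
base = tf.keras.applications.ResNet50(weights=None)

# Same construction as in __init__: rebuild the model so it returns activations
# from the named internal layer instead of the final classification output.
truncated = tf.keras.Model(
    base.layers[0].input,
    base.get_layer("avg_pool").output,
)
# truncated(x) now yields the pooled features rather than class logits.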
@@ -100,9 +118,6 @@ def from_config(cls, model_config, **kwargs):
raise ValueError(f"config key {k} is required")
module, name, weights_file = (model_config.pop(k) for k in keys)
model_kwargs = model_config.pop("model_kwargs", {})
data_modality = model_config.pop("data_modality", "image")
model_framework = model_config.pop("model_framework", "pytorch")
activation_layer = model_config.pop("activation_layer", None)

weights_path = maybe_download_weights_from_s3(
weights_file, auto_expand_tars=True
@@ -113,9 +128,6 @@

return cls(
explanatory_model,
data_modality,
model_framework,
activation_layer,
**model_config,
)
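
A usage sketch for the simplified from_config above, assuming the enclosing class is this module's ExplanatoryModel and that the weights file can be fetched from S3:

from armory.metrics.poisoning import EXPLANATORY_MODEL_CONFIGS, ExplanatoryModel

# from_config pops keys out of the dict it receives, so hand it a copy.
config = dict(EXPLANATORY_MODEL_CONFIGS["cifar10_explanatory_model"])
model = ExplanatoryModel.from_config(config)

With the explicit data_modality/model_framework/activation_layer pops removed, those keys now simply flow through **model_config into __init__, which is why the speech_commands entry above carries them directly.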

@@ -151,7 +163,7 @@ def get_activations(self, x, batch_size: int = None):

@staticmethod
def _preprocess_image(
x, resize_image=True, size=(224, 224), resample=Image.BILINEAR, device=DEVICE
x, resize_image=False, size=(224, 224), resample=Image.BILINEAR, device=DEVICE
):
if np.issubdtype(x.dtype, np.floating):
if x.min() < 0.0 or x.max() > 1.0:
@@ -185,10 +197,12 @@ def preprocess(self, x):
if self.data_modality == "image":
return type(self)._preprocess_image(
x,
self.resize_image,
self.size,
resample=self.resample,
device=self.device,
**self.preprocess_kwargs,
)
elif self.data_modality == "audio":
return x

else:
raise ValueError(
f"There is no preprocessing function for data_modality '{self.data_modality}'. Please set data_modality to 'image' or 'audio', or implement preprocessing for data_modality '{self.data_modality}'"
)