diff --git a/alibi/datasets/default.py b/alibi/datasets/default.py index 603d8f8b2..00b84d5be 100644 --- a/alibi/datasets/default.py +++ b/alibi/datasets/default.py @@ -142,7 +142,8 @@ def load_cats(target_size: tuple = (299, 299), return_X_y: bool = False) -> Unio # data img = tar.extractfile(member).read() # type: ignore[union-attr] img = PIL.Image.open(BytesIO(img)) - img = np.expand_dims(img.resize(target_size), axis=0) + img = np.array(img.resize(target_size)) + img = np.expand_dims(img, axis=0) images.append(img) # labels diff --git a/alibi/explainers/pd_variance.py b/alibi/explainers/pd_variance.py index 44d7fcf36..1c69881d0 100644 --- a/alibi/explainers/pd_variance.py +++ b/alibi/explainers/pd_variance.py @@ -815,7 +815,7 @@ def plot_pd_variance(exp: Explanation, f"Available values are: {exp.meta['params']['target_names']}.") if isinstance(target, numbers.Integral) \ - and (target > len(exp.meta['params']['target_names'])): # type: ignore[operator] + and (target > len(exp.meta['params']['target_names'])): raise IndexError(f"Target index out of range. Received {target}. " f"The number of targets is {len(exp.meta['params']['target_names'])}.") diff --git a/alibi/utils/visualization.py b/alibi/utils/visualization.py index d60ad6965..bd5ac0602 100644 --- a/alibi/utils/visualization.py +++ b/alibi/utils/visualization.py @@ -196,6 +196,7 @@ def visualize_image_attr( heat_map = None # Show original image if ImageVisualizationMethod[method] == ImageVisualizationMethod.original_image: + assert original_image is not None plt_axis.imshow(original_image) else: # Choose appropriate signed attributions and normalize.