From b69921bbc081b0cd7a6ba4065b6265f96df250b6 Mon Sep 17 00:00:00 2001 From: johannesring Date: Tue, 16 Jul 2024 14:42:49 +0200 Subject: [PATCH] Add type assertions and casts to fix mypy issues --- .../automatedPostprocessing/log_plotter.py | 38 +++++++++++-------- .../create_spectrograms_chromagrams.py | 16 ++++++-- 2 files changed, 34 insertions(+), 20 deletions(-) diff --git a/src/vasp/automatedPostprocessing/log_plotter.py b/src/vasp/automatedPostprocessing/log_plotter.py index fc8e79a..9b4b22e 100644 --- a/src/vasp/automatedPostprocessing/log_plotter.py +++ b/src/vasp/automatedPostprocessing/log_plotter.py @@ -19,7 +19,7 @@ import argparse import logging from pathlib import Path -from typing import Dict, Any, List, Optional, Tuple +from typing import Dict, Any, List, Optional, Tuple, cast import pickle import numpy as np @@ -401,7 +401,9 @@ def plot_multiple_variables_comparison(variable_mean: np.ndarray, variable_min: split_variable_max_data = np.array_split(variable_max, num_cycles) # Create subplots for mean, min, and max - fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=figure_size, sharex=True) + fig, axes = plt.subplots(3, 1, figsize=figure_size, sharex=True) + assert isinstance(axes, np.ndarray) and axes.shape == (3,) + ax1, ax2, ax3 = axes for cycle in range(first_cycle - 1, last_cycle): cycle_variable_mean_data = split_variable_mean_data[cycle] @@ -503,13 +505,14 @@ def plot_probe_points(time: np.ndarray, probe_points: Dict[int, Dict[str, np.nda # Create subplots based on the number of selected probe points if num_rows == 1 and num_cols == 1: # If only one probe point is selected, create a single figure - fig, axes = plt.subplots(figsize=figure_size) + fig, _ = plt.subplots(figsize=figure_size) axes = [fig.gca()] # Get the current axis as a list else: - fig, axes = plt.subplots(num_rows, num_cols, figsize=figure_size) + fig, axes_array = plt.subplots(num_rows, num_cols, figsize=figure_size) # Flatten the axes array for easier iteration - axes = axes.flatten() + assert isinstance(axes_array, np.ndarray) + axes = axes_array.flatten().tolist() for i, (probe_point, data) in enumerate(selected_probe_data.items()): ax = axes[i] @@ -525,7 +528,8 @@ def plot_probe_points(time: np.ndarray, probe_points: Dict[int, Dict[str, np.nda ax.grid(True) ax.tick_params(axis='y', which='major', labelsize=12, labelcolor='b') - ax2 = ax.twinx() + # Create a twin Axes sharing the xaxis and cast it to Axes + ax2 = cast(plt.Axes, ax.twinx()) l2, = ax2.plot(time[start:end], pressure_data, color='r') ax2.set_ylabel("Pressure [Pa]", color='r') ax2.legend([l1, l2], ["Velocity Magnitude", "Pressure"], loc="upper right") @@ -583,13 +587,14 @@ def plot_probe_points_displacement(time: np.ndarray, probe_points: Dict[int, Dic # Create subplots based on the number of selected probe points if num_rows == 1 and num_cols == 1: # If only one probe point is selected, create a single figure - fig, axes = plt.subplots(figsize=figure_size) + fig, _ = plt.subplots(figsize=figure_size) axes = [fig.gca()] else: - fig, axes = plt.subplots(num_rows, num_cols, figsize=figure_size) + fig, axes_array = plt.subplots(num_rows, num_cols, figsize=figure_size) # Flatten the axes array for easier iteration - axes = axes.flatten() + assert isinstance(axes_array, np.ndarray) + axes = axes_array.flatten().tolist() for i, (probe_point, data) in enumerate(selected_probe_data.items()): ax = axes[i] @@ -658,9 +663,9 @@ def plot_probe_points_comparison(probe_points: Dict[int, Dict[str, np.ndarray]], f"Comparing from cycle {first_cycle} to cycle {last_cycle}") # Create subplots for magnitude and pressure - fig, axs = plt.subplots(2, 1, figsize=figure_size) - - ax, ax2 = axs + fig, axes = plt.subplots(2, 1, figsize=figure_size) + assert isinstance(axes, np.ndarray) and axes.shape == (2,) + ax, ax2 = axes # Split the data into separate cycles split_magnitude_data = np.array_split(data["magnitude"], num_cycles) @@ -1011,13 +1016,14 @@ def plot_probe_points_tke(tke_data: Dict[int, Tuple[np.ndarray, np.ndarray, np.n # Create subplots for each probe point if num_rows == 1 and num_cols == 1: # If only one probe point is selected, create a single figure - fig, axes = plt.subplots(figsize=figure_size) - axes = [axes] + fig, _ = plt.subplots(figsize=figure_size) + axes = [fig.gca()] else: - fig, axes = plt.subplots(num_rows, num_cols, figsize=figure_size) + fig, axes_array = plt.subplots(num_rows, num_cols, figsize=figure_size) # Flatten the axes array for easier iteration - axes = axes.flatten() + assert isinstance(axes_array, np.ndarray) + axes = axes_array.flatten().tolist() # Add common title fig.suptitle("Turbulent Kinetic Energy (TKE) for Probe Points", fontsize=16) diff --git a/src/vasp/automatedPostprocessing/postprocessing_h5py/create_spectrograms_chromagrams.py b/src/vasp/automatedPostprocessing/postprocessing_h5py/create_spectrograms_chromagrams.py index 212f314..e6c545a 100755 --- a/src/vasp/automatedPostprocessing/postprocessing_h5py/create_spectrograms_chromagrams.py +++ b/src/vasp/automatedPostprocessing/postprocessing_h5py/create_spectrograms_chromagrams.py @@ -79,13 +79,21 @@ def create_spectrogram_composite(case_name: str, quantity: str, df: pd.DataFrame # Create composite figure if amplitude_file and flow_rate_file: - fig1, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(5, sharex=True, gridspec_kw={'height_ratios': [1, 3, 1, 1, 1]}) + fig1, axes = plt.subplots(5, sharex=True, gridspec_kw={'height_ratios': [1, 3, 1, 1, 1]}) + assert isinstance(axes, np.ndarray) and axes.shape == (5,) + ax1, ax2, ax3, ax4, ax5 = axes elif flow_rate_file: - fig1, (ax1, ax2, ax3, ax4) = plt.subplots(4, sharex=True, gridspec_kw={'height_ratios': [1, 3, 1, 1]}) + fig1, axes = plt.subplots(4, sharex=True, gridspec_kw={'height_ratios': [1, 3, 1, 1]}) + assert isinstance(axes, np.ndarray) and axes.shape == (4,) + ax1, ax2, ax3, ax4 = axes elif amplitude_file: - fig1, (ax2, ax3, ax4, ax5) = plt.subplots(4, sharex=True, gridspec_kw={'height_ratios': [3, 1, 1, 1]}) + fig1, axes = plt.subplots(4, sharex=True, gridspec_kw={'height_ratios': [3, 1, 1, 1]}) + assert isinstance(axes, np.ndarray) and axes.shape == (4,) + ax2, ax3, ax4, ax5 = axes else: - fig1, (ax2, ax3, ax4) = plt.subplots(3, sharex=True, gridspec_kw={'height_ratios': [3, 1, 1]}) + fig1, axes = plt.subplots(3, sharex=True, gridspec_kw={'height_ratios': [3, 1, 1]}) + assert isinstance(axes, np.ndarray) and axes.shape == (3,) + ax2, ax3, ax4 = axes fig1.set_size_inches(7.5, 9)