Skip to content

Commit

Permalink
Merge branch 'GUI' of github.com:EOMYS-Public/SciDataTool into GUI
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminGabet committed Jun 8, 2022
2 parents 699a38b + 01cf168 commit b7a12ec
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 88 deletions.
21 changes: 19 additions & 2 deletions SciDataTool/Functions/Plot/plot_2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,12 +425,29 @@ def get_cumulated_array(data, **kwargs):
pickradius=5,
)
if len(Ydatas) > 1:
ax.stackplot(
stacks = ax.stackplot(
Xdatas[i_Xdatas[0]],
Ydatas[1:],
Ydatas[1:] * Ydatas[0] / 100,
colors=color_list[1:],
labels=legend_list[1:],
)
if (
len(
[
i
for i, n in enumerate(color_list)
if n in color_list[i + 1 :] and n not in color_list[:i]
]
)
> 0
):
# Add hatches if color_list too small
hatches = ["", "//", "\\", "+"]
ncolors = [i for i, n in enumerate(color_list) if n in color_list[:i]][
0
] - 1
for ii, stack in enumerate(stacks):
stack.set_hatch(hatches[ii // ncolors])
if xticks is not None:
ax.xaxis.set_ticks(xticks)
plt.xticks(rotation=90, ha="center", va="top")
Expand Down
2 changes: 2 additions & 0 deletions SciDataTool/GUI/DDataPlotter/DDataPlotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def update_plot(self):
y_min=output_range["min"],
y_max=output_range["max"],
)
self.w_plot_manager.plot_arg_dict = plot_arg_dict_2D

elif len(axes_selected) == 2:
plot_arg_dict_3D = self.plot_arg_dict.copy()
Expand Down Expand Up @@ -665,6 +666,7 @@ def update_plot(self):
ax=self.ax,
is_switch_axes=not_in_order,
)
self.w_plot_manager.plot_arg_dict = plot_arg_dict_3D

else:
print("Operation not implemented yet, plot could not be updated")
Expand Down
26 changes: 20 additions & 6 deletions SciDataTool/GUI/WPlotManager/WPlotManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, parent=None):
self.save_path = DATA_DIR # Path to directory where animation are stored
self.path_to_image = None # Path to recover the image for the animate button
self.main_widget = None
self.plot_arg_dict = {}

self.is_test = False # Used in test to disable showing the animation
self.gif_path_list = list() # List of path to the gifs created (used in test)
Expand Down Expand Up @@ -314,6 +315,11 @@ def get_file_name(self):
file_name = self.data.name
file_name = file_name.replace("{", "[").replace("}", "]").replace(".", ",")

if "contribution_axis" in self.plot_arg_dict:
file_name += "_contribution"
elif "overall_axes" in self.plot_arg_dict:
file_name += "_order_tracking"

return file_name

def export(self, save_file_path=False):
Expand All @@ -324,10 +330,12 @@ def export(self, save_file_path=False):
a WPlotManager object
"""
# Getting the inputs of the user to export the plot + for the name of the csv file
param_list = [
*self.w_axis_manager.get_axes_selected(),
*self.w_axis_manager.get_operation_selected(),
]
[
data,
axes_selected,
data_selection,
output_range,
] = self.get_plot_info()

if self.default_file_path is None:
file_name = self.get_file_name()
Expand All @@ -350,9 +358,15 @@ def export(self, save_file_path=False):
if save_file_path not in ["", False]:
save_path = dirname(save_file_path)
file_name = basename(save_file_path).split(".")[0]
is_2D = True if len(axes_selected) == 1 else False
try:
self.data.export_along(
*param_list, save_path=save_path, file_name=file_name
data.export_along(
*[*axes_selected, *data_selection],
unit=output_range["unit"],
save_path=save_path,
file_name=file_name,
is_2D=is_2D,
plot_options=self.plot_arg_dict
)
except Exception as e:
# Displaying the error inside abox instead of the console
Expand Down
192 changes: 122 additions & 70 deletions SciDataTool/Methods/DataND/export_along.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
def export_along(
self,
*args,
is_2D=True,
unit="SI",
is_norm=False,
axis_data=[],
save_path=None,
file_name=None,
file_format="csv",
is_multiple_files=False,
plot_options={}
):
"""Exports the sliced or interpolated version of the data, using conversions and symmetries if needed, in a file.
Parameters
Expand Down Expand Up @@ -52,36 +52,56 @@ def export_along(
args = tuple(arg_list)

# Get requested data
is_fft = False
for arg in args:
if "freqs" in arg or "wavenumber" in arg:
is_fft = True
if is_fft:
results = self.get_magnitude_along(
*args, unit=unit, is_norm=is_norm, axis_data=axis_data
if is_2D:
Xdata, Ydatas, title, xlabel, ylabel, legends = self.plot_2D_Data(
*args, **plot_options, unit=unit, is_export=True
)
axes_list_new = []
if "for " in title:
slices = title.split("for ")[1]
else:
slices = ""
else:
results = self.get_along(*args, unit=unit, is_norm=is_norm, axis_data=axis_data)
axes_list = results["axes_list"]

# Remove slice axes
axes_list_new = []
slices = ""
for axis in axes_list:
if axis.unit == "SI":
axis.unit = unit_dict[axis.name]
if len(results[axis.name]) == 1:
slices += axis.name + "=" + str(results[axis.name][0])
elif isinstance(results[axis.name], str):
slices += axis.name + "=" + results[axis.name]
if "is_norm" in plot_options:
is_norm = plot_options["is_norm"]
else:
is_norm = False
if "axis_data" in plot_options:
axis_data = plot_options["axis_data"]
else:
axis_data = None
is_fft = False
for arg in args:
if "freqs" in arg or "wavenumber" in arg:
is_fft = True
if is_fft:
results = self.get_magnitude_along(
*args, unit=unit, is_norm=is_norm, axis_data=axis_data
)
else:
axes_list_new.append(axis)
for axis in results["axes_dict_other"]:
results = self.get_along(
*args, unit=unit, is_norm=is_norm, axis_data=axis_data
)
axes_list = results["axes_list"]

# Remove slice axes
axes_list_new = []
slices = ""
for axis in axes_list:
if axis.unit == "SI":
axis.unit = unit_dict[axis.name]
if len(results[axis.name]) == 1:
slices += axis.name + "=" + str(results[axis.name][0])
elif isinstance(results[axis.name], str):
slices += axis.name + "=" + results[axis.name]
else:
axes_list_new.append(axis)
for axis in results["axes_dict_other"]:
if slices != "":
slices = slices + ", "
slices += axis + "=" + str(results["axes_dict_other"][axis][0])
if slices != "":
slices = slices + ", "
slices += axis + "=" + str(results["axes_dict_other"][axis][0])
if slices != "":
slices = "sliced at " + slices
slices = "sliced at " + slices

# Default file_name
if file_name is None:
Expand All @@ -90,14 +110,14 @@ def export_along(
if file_format == "csv":
# Write csv files
# Format: first axis in column, second in row, third in file
if len(axes_list_new) == 3 and is_multiple_files:
if is_2D or len(axes_list_new) < 3:
nfiles = 1
elif len(axes_list_new) == 3 and is_multiple_files:
nfiles = len(axes_list_new[2].values)
elif len(axes_list_new) == 3:
raise Exception(
"cannot export more than 2 dimensions in single csv file. Activate is_multiple_files to write in multiple csv files."
)
elif len(axes_list_new) < 3:
nfiles = 1
else:
raise Exception("cannot export more than 3 dimensions in csv file")

Expand Down Expand Up @@ -135,48 +155,80 @@ def export_along(
meta_data = [self.symbol, self.name, "[" + unit + "]", slices_i]
csvWriter.writerow(meta_data)

# Second line: axes + second axis values
if len(axes_list_new) == 1:
A2_cell = axes_list_new[0].name + "[" + axes_list_new[0].unit + "]"
second_line = [A2_cell]
else:
A2_cell = (
axes_list_new[0].name
+ "["
+ axes_list_new[0].unit
+ "]"
+ "/"
+ axes_list_new[1].name
+ "["
+ axes_list_new[1].unit
+ "]"
)
second_line = format_matrix(
np.insert(
results[axes_list_new[1].name].astype("<U64"),
0,
A2_cell,
),
if is_2D:
# Second line: axes + second axis values
if len(Ydatas) > 1:
A2_cell = xlabel + "/" + legends[-1].split("=")[0]
second_line = format_matrix(
np.insert(
np.array(legends).astype("<U64"),
0,
A2_cell,
),
CHAR_LIST,
)
else:
A2_cell = xlabel
second_line = [A2_cell]
csvWriter.writerow(second_line)

# Rest of file: first axis + matrix
if len(Ydatas) == 1:
# Transpose if 1D array
field = np.array(Ydatas[0]).T
else:
field = np.array(Ydatas).T
matrix = format_matrix(
np.column_stack((np.array(Xdata[0]).T, field)).astype("str"),
CHAR_LIST,
)
csvWriter.writerow(second_line)

# Rest of file: first axis + matrix
if len(results[self.symbol].shape) == 1:
# Transpose if 1D array
field = results[self.symbol].T
elif nfiles > 1:
# Slice third axis
field = np.take(results[self.symbol], i, axis=2)
csvWriter.writerows(matrix)

else:
field = results[self.symbol]
matrix = format_matrix(
np.column_stack((results[axes_list_new[0].name].T, field)).astype(
"str"
),
CHAR_LIST,
)
csvWriter.writerows(matrix)
# Second line: axes + second axis values
if len(axes_list_new) == 1:
A2_cell = (
axes_list_new[0].name + "[" + axes_list_new[0].unit + "]"
)
second_line = [A2_cell]
else:
A2_cell = (
axes_list_new[0].name
+ "["
+ axes_list_new[0].unit
+ "]"
+ "/"
+ axes_list_new[1].name
+ "["
+ axes_list_new[1].unit
+ "]"
)
second_line = format_matrix(
np.insert(
results[axes_list_new[1].name].astype("<U64"),
0,
A2_cell,
),
CHAR_LIST,
)
csvWriter.writerow(second_line)

# Rest of file: first axis + matrix
if len(results[self.symbol].shape) == 1:
# Transpose if 1D array
field = results[self.symbol].T
elif nfiles > 1:
# Slice third axis
field = np.take(results[self.symbol], i, axis=2)
else:
field = results[self.symbol]
matrix = format_matrix(
np.column_stack(
(results[axes_list_new[0].name].T, field)
).astype("str"),
CHAR_LIST,
)
csvWriter.writerows(matrix)

else:
raise Exception("export format not supported")
Expand Down
17 changes: 10 additions & 7 deletions SciDataTool/Methods/DataND/plot_2D_Data.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def plot_2D_Data(
is_frame_legend=True,
is_indlabels=False,
annotations=None,
is_export=False,
):
"""Plots a field as a function of time
Expand Down Expand Up @@ -606,13 +607,11 @@ def plot_2D_Data(
cont = Ydatas[i + 1] / OVL
else:
cont = Ydatas[i + 1] ** 2 / OVL ** 2
contrib_array[i, :] = cont * OVL
contrib_array[i, :] = cont * 100
# Remove small contributions
Iloads = where(
sqrt(np_sum(contrib_array ** 2, 1)) > self.normalizations["ref"].ref
)[0]
# Iloads = where(sqrt(np_sum(contrib_array ** 2, 1)) > 1e-2)[0]
# Sort in decreasing order
Isort = argsort(-sqrt(np_sum(contrib_array[Iloads, :] ** 2, 1)))
Isort = argsort(-sqrt(np_sum(contrib_array ** 2, 1)), axis=0)
Ydatas = [Ydatas[0]]
legends = [r"100% (overall $" + symbol + "$)"]
new_color_list = [new_color_list[0]]
Expand All @@ -625,13 +624,13 @@ def plot_2D_Data(
if Isort[i] in axes_list[contrib_index].indices:
Ydatas.append(contrib_array[Isort[i], :])
legends.append(contrib_axis.values[Isort[i]])
new_color_list.append(color_list[i])
new_color_list.append(color_list[i % (len(color_list))])
elif "all" in selection or all(
[s in contrib_axis.values[Isort[i]] for s in selection]
):
Ydatas.append(contrib_array[Isort[i], :])
legends.append(contrib_axis.values[Isort[i]])
new_color_list.append(color_list[i])
new_color_list.append(color_list[i % (len(color_list))])
if "all" not in selection:
title = " ".join(selection).capitalize()
else:
Expand Down Expand Up @@ -730,6 +729,8 @@ def plot_2D_Data(
# Deactivate the option
fund_harm = None

if is_export:
return Xdatas, Ydatas, title, xlabel, ylabel, legends
plot_2D(
Xdatas,
Ydatas,
Expand Down Expand Up @@ -820,6 +821,8 @@ def plot_2D_Data(
except Exception:
pass

if is_export:
return Xdatas, Ydatas, title, xlabel, ylabel, legends
plot_2D(
Xdatas,
Ydatas,
Expand Down
Loading

0 comments on commit b7a12ec

Please sign in to comment.