diff --git a/py4DSTEM/process/strain/strain.py b/py4DSTEM/process/strain/strain.py index 099ecdefd..621fc820e 100644 --- a/py4DSTEM/process/strain/strain.py +++ b/py4DSTEM/process/strain/strain.py @@ -64,7 +64,7 @@ def __init__(self, braggvectors: BraggVectors, name: Optional[str] = "strainmap" braggvectors.Rshape[1], ) ), - slicelabels=["exx", "eyy", "exy", "theta", "mask", "error"], + slicelabels=["e_xx", "e_yy", "e_xy", "theta", "mask", "error"], ) # set up braggvectors @@ -662,7 +662,6 @@ def show_strain( vrange_exx=None, vrange_exy=None, vrange_eyy=None, - bkgrd=True, show_cbars=None, bordercolor="k", borderwidth=1, @@ -674,6 +673,7 @@ def show_strain( cmap_theta="PRGn", mask_color="k", color_axes="k", + show_legend=True, show_gvects=True, color_gvects="r", legend_camera_length=1.6, @@ -744,381 +744,46 @@ def show_strain( returnfig : bool Toggles returning the figure """ - # Lookup table for different layouts - assert layout in ("square", "horizontal", "vertical") - layout_lookup = { - "square": ["left", "right", "left", "right"], - "horizontal": ["bottom", "bottom", "bottom", "bottom"], - "vertical": ["right", "right", "right", "right"], - } - - layout_p = layout_lookup[layout] - - # Set which colorbars to display - if show_cbars is None: - if np.all( - [ - v is None - for v in ( - vrange_exx, - vrange_eyy, - vrange_exy, - ) - ] - ): - show_cbars = ("eyy", "theta") - else: - show_cbars = ("exx", "eyy", "exy", "theta") - else: - assert np.all([v in ("exx", "eyy", "exy", "theta") for v in show_cbars]) - - # Contrast limits - if vrange_exx is None: - vrange_exx = vrange - if vrange_exy is None: - vrange_exy = vrange - if vrange_eyy is None: - vrange_eyy = vrange - for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): - assert len(vrange) == 2, "vranges must have length 2" - vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 - vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 - vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 - # theta is plotted in units of degrees - vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( - 180.0 / np.pi - ) - - # Get images - e_xx = np.ma.array( - self.get_slice("exx").data, - mask=self.get_slice("mask").data == False, # noqa: E712,E501 - ) - e_yy = np.ma.array( - self.get_slice("eyy").data, - mask=self.get_slice("mask").data == False, # noqa: E712,E501 - ) - e_xy = np.ma.array( - self.get_slice("exy").data, - mask=self.get_slice("mask").data == False, # noqa: E712,E501 - ) - theta = np.ma.array( - self.get_slice("theta").data, - mask=self.get_slice("mask").data == False, # noqa: E712 - ) - - ## Plot - - # if figsize hasn't been set, set it based on the - # chosen layout and the image shape - if figsize is None: - ratio = np.sqrt(self.rshape[1] / self.rshape[0]) - if layout == "square": - figsize = (13 * ratio, 8 / ratio) - elif layout == "horizontal": - figsize = (10 * ratio, 4 / ratio) - else: - figsize = (4 * ratio, 10 / ratio) - # set up layout - if layout == "square": - fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots( - 2, 3, figsize=figsize - ) - elif layout == "horizontal": - figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( - 1, 5, figsize=figsize - ) - else: - figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) - fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( - 5, 1, figsize=figsize - ) + from py4DSTEM.visualize import show_strain - # display images, returning cbar axis references - cax11 = show( - e_xx, - figax=(fig, ax11), - vmin=vmin_exx, - vmax=vmax_exx, - intensity_range="absolute", - cmap=cmap, - mask=self.mask, - mask_color=mask_color, - returncax=True, - ) - cax12 = show( - e_yy, - figax=(fig, ax12), - vmin=vmin_eyy, - vmax=vmax_eyy, - intensity_range="absolute", - cmap=cmap, - mask=self.mask, - mask_color=mask_color, - returncax=True, - ) - cax21 = show( - e_xy, - figax=(fig, ax21), - vmin=vmin_exy, - vmax=vmax_exy, - intensity_range="absolute", + fig, ax = show_strain( + self, + vrange=vrange, + vrange_theta=vrange_theta, + vrange_exx=vrange_exx, + vrange_exy=vrange_exy, + vrange_eyy=vrange_eyy, + show_cbars=show_cbars, + bordercolor=bordercolor, + borderwidth=borderwidth, + titlesize=titlesize, + ticklabelsize=ticklabelsize, + ticknumber=ticknumber, + unitlabelsize=unitlabelsize, cmap=cmap, - mask=self.mask, + cmap_theta=cmap_theta, mask_color=mask_color, - returncax=True, - ) - cax22 = show( - theta, - figax=(fig, ax22), - vmin=vmin_theta, - vmax=vmax_theta, - intensity_range="absolute", - cmap=cmap_theta, - mask=self.mask, - mask_color=mask_color, - returncax=True, - ) - ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) - ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) - ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) - ax22.set_title(r"$\theta$", size=titlesize) - - # Add black background - if bkgrd: - mask = np.ma.masked_where( - self.get_slice("mask").data.astype(bool), - np.zeros_like(self.get_slice("mask").data), - ) - ax11.matshow(mask, cmap="gray") - ax12.matshow(mask, cmap="gray") - ax21.matshow(mask, cmap="gray") - ax22.matshow(mask, cmap="gray") - - # add colorbars - show_cbars = np.array( - [ - "exx" in show_cbars, - "eyy" in show_cbars, - "exy" in show_cbars, - "theta" in show_cbars, - ] - ) - if np.any(show_cbars): - divider11 = make_axes_locatable(ax11) - divider12 = make_axes_locatable(ax12) - divider21 = make_axes_locatable(ax21) - divider22 = make_axes_locatable(ax22) - cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) - cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) - cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) - cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) - for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( - range(4), - show_cbars, - (cax11, cax12, cax21, cax22), - (cbax11, cbax12, cbax21, cbax22), - (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), - (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), - (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), - ("% ", " %", "% ", r" $^\circ$"), - ): - if show_cbar: - ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) - if ind < 3: - ticklabels = np.round( - np.linspace( - 100 * vmin, 100 * vmax, ticknumber, endpoint=True - ), - decimals=2, - ).astype(str) - else: - ticklabels = np.round( - np.linspace( - (180 / np.pi) * vmin, - (180 / np.pi) * vmax, - ticknumber, - endpoint=True, - ), - decimals=2, - ).astype(str) - - if tickside in ("left", "right"): - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="vertical" - ) - cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) - cbax.yaxis.set_ticks_position(tickside) - cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) - cbax.yaxis.set_label_position(tickside) - else: - cb = plt.colorbar( - cax, cax=cbax, ticks=ticks, orientation="horizontal" - ) - cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) - cbax.xaxis.set_ticks_position(tickside) - cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) - cbax.xaxis.set_label_position(tickside) - else: - cbax.axis("off") - - # Add borders - if bordercolor is not None: - for ax in (ax11, ax12, ax21, ax22): - for s in ["bottom", "top", "left", "right"]: - ax.spines[s].set_color(bordercolor) - ax.spines[s].set_linewidth(borderwidth) - ax.set_xticks([]) - ax.set_yticks([]) - - # Legend - - # for layout "square", combine vertical plots on the right end - if layout == "square": - # get gridspec object - gs = ax_legend1.get_gridspec() - # remove last two axes - ax_legend1.remove() - ax_legend2.remove() - # make new axis - ax_legend = fig.add_subplot(gs[:, -1]) - - # get the coordinate axes' directions - rotation = self.coordinate_rotation_radians - xaxis_vectx = np.cos(rotation) - xaxis_vecty = np.sin(rotation) - yaxis_vectx = np.cos(rotation + np.pi / 2) - yaxis_vecty = np.sin(rotation + np.pi / 2) - - # make the coordinate axes - ax_legend.arrow( - x=0, - y=0, - dx=xaxis_vecty, - dy=xaxis_vectx, - color=color_axes, - length_includes_head=True, - width=0.01, - head_width=0.1, - ) - ax_legend.arrow( - x=0, - y=0, - dx=yaxis_vecty, - dy=yaxis_vectx, - color=color_axes, - length_includes_head=True, - width=0.01, - head_width=0.1, - ) - ax_legend.text( - x=xaxis_vecty * 1.16, - y=xaxis_vectx * 1.16, - s="x", - fontsize=14, - color=color_axes, - horizontalalignment="center", - verticalalignment="center", - ) - ax_legend.text( - x=yaxis_vecty * 1.16, - y=yaxis_vectx * 1.16, - s="y", - fontsize=14, - color=color_axes, - horizontalalignment="center", - verticalalignment="center", + color_axes=color_axes, + show_legend=show_legend, + rotation_deg=np.rad2deg(self.coordinate_rotation_radians), + show_gvects=show_gvects, + g1=self.g1, + g2=self.g2, + color_gvects=color_gvects, + legend_camera_length=legend_camera_length, + scale_gvects=scale_gvects, + layout=layout, + figsize=figsize, + returnfig=True, ) - # make the g-vectors - if show_gvects: - # get the g-vectors directions - g1q = np.array(self.g1) - g2q = np.array(self.g2) - g1norm = np.linalg.norm(g1q) - g2norm = np.linalg.norm(g2q) - g1q /= g1norm - g2q /= g2norm - # set the lengths - g_ratio = g2norm / g1norm - if g_ratio > 1: - g1q /= g_ratio - else: - g2q *= g_ratio - g1_x, g1_y = g1q - g2_x, g2_y = g2q - - # draw the g vectors - ax_legend.arrow( - x=0, - y=0, - dx=g1_y * scale_gvects, - dy=g1_x * scale_gvects, - color=color_gvects, - length_includes_head=True, - width=0.005, - head_width=0.05, - ) - ax_legend.arrow( - x=0, - y=0, - dx=g2_y * scale_gvects, - dy=g2_x * scale_gvects, - color=color_gvects, - length_includes_head=True, - width=0.005, - head_width=0.05, - ) - ax_legend.text( - x=g1_y * scale_gvects * 1.2, - y=g1_x * scale_gvects * 1.2, - s=r"$g_1$", - fontsize=12, - color=color_gvects, - horizontalalignment="center", - verticalalignment="center", - ) - ax_legend.text( - x=g2_y * scale_gvects * 1.2, - y=g2_x * scale_gvects * 1.2, - s=r"$g_2$", - fontsize=12, - color=color_gvects, - horizontalalignment="center", - verticalalignment="center", - ) - - # find center and extent - xmin = np.min([0, 0, xaxis_vectx, yaxis_vectx]) - xmax = np.max([0, 0, xaxis_vectx, yaxis_vectx]) - ymin = np.min([0, 0, xaxis_vecty, yaxis_vecty]) - ymax = np.max([0, 0, xaxis_vecty, yaxis_vecty]) - if show_gvects: - xmin = np.min([xmin, g1_x, g2_x]) - xmax = np.max([xmax, g1_x, g2_x]) - ymin = np.min([ymin, g1_y, g2_y]) - ymax = np.max([ymax, g1_y, g2_y]) - x0 = np.mean([xmin, xmax]) - y0 = np.mean([ymin, ymax]) - xL = (xmax - x0) * legend_camera_length - yL = (ymax - y0) * legend_camera_length - - # set the extent and aspect - ax_legend.set_xlim([y0 - yL, y0 + yL]) - ax_legend.set_ylim([x0 - xL, x0 + xL]) - ax_legend.invert_yaxis() - ax_legend.set_aspect("equal") - ax_legend.axis("off") - # show/return if not returnfig: plt.show() return else: - axs = ((ax11, ax12), (ax21, ax22)) - return fig, axs + return fig, ax def show_reference_directions( self, diff --git a/py4DSTEM/visualize/vis_special.py b/py4DSTEM/visualize/vis_special.py index c1e9d6b19..125a2ce67 100644 --- a/py4DSTEM/visualize/vis_special.py +++ b/py4DSTEM/visualize/vis_special.py @@ -895,3 +895,463 @@ def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=Fals vmax = 1 return scaled_array, vmin, vmax + + +def show_strain( + data, + vrange=[-3, 3], + vrange_theta=[-3, 3], + vrange_exx=None, + vrange_exy=None, + vrange_eyy=None, + show_cbars=None, + bordercolor="k", + borderwidth=1, + titlesize=18, + ticklabelsize=10, + ticknumber=5, + unitlabelsize=16, + cmap="RdBu_r", + cmap_theta="PRGn", + mask_color="k", + color_axes="k", + show_legend=False, + rotation_deg=0, + show_gvects=True, + g1=None, + g2=None, + color_gvects="r", + legend_camera_length=1.6, + scale_gvects=0.6, + layout="square", + figsize=None, + returnfig=False, +): + """ + Display a strain map, showing the 4 strain components + (e_xx,e_yy,e_xy,theta), and masking each image with + strainmap.get_slice('mask') + + Parameters + ---------- + data : strainmap + vrange : length 2 list or tuple + The colorbar intensity range for exx,eyy, and exy. + vrange_theta : length 2 list or tuple + The colorbar intensity range for theta. + vrange_exx : length 2 list or tuple + The colorbar intensity range for exx; overrides `vrange` + for exx + vrange_exy : length 2 list or tuple + The colorbar intensity range for exy; overrides `vrange` + for exy + vrange_eyy : length 2 list or tuple + The colorbar intensity range for eyy; overrides `vrange` + for eyy + show_cbars : None or a tuple of strings + Show colorbars for the specified axes. Valid strings are + 'exx', 'eyy', 'exy', and 'theta'. + bordercolor : color + Color for the image borders + borderwidth : number + Width of the image borders + titlesize : number + Size of the image titles + ticklabelsize : number + Size of the colorbar ticks + ticknumber : number + Number of ticks on colorbars + unitlabelsize : number + Size of the units label on the colorbars + cmap : colormap + Colormap for exx, exy, and eyy + cmap_theta : colormap + Colormap for theta + mask_color : color + Color for the background mask + color_axes : color + Color for the legend coordinate axes + show_gvects : bool + Toggles displaying the g-vectors in the legend + rotation_deg : float + coordinate rotation for strainmap in degrees + g1 : tuple + g1 orientation (x,y) + g2 : tuple + g2 orientation (x,y) + color_gvects : color + Color for the legend g-vectors + legend_camera_length : number + The distance the legend is viewed from; a smaller number yields + a larger legend + scale_gvects : number + Scaling for the legend g-vectors relative to the coordinate axes + layout : int + Determines the layout of the grid which the strain components + will be plotted in. Must be in (0,1,2). 0=(2x2), 1=(1x4), 2=(4x1). + figsize : length 2 tuple of numbers + Size of the figure + returnfig : bool + Toggles returning the figure + """ + # Lookup table for different layouts + assert layout in ("square", "horizontal", "vertical") + layout_lookup = { + "square": ["left", "right", "left", "right"], + "horizontal": ["bottom", "bottom", "bottom", "bottom"], + "vertical": ["right", "right", "right", "right"], + } + + layout_p = layout_lookup[layout] + + # Set which colorbars to display + if show_cbars is None: + if np.all( + [ + v is None + for v in ( + vrange_exx, + vrange_eyy, + vrange_exy, + ) + ] + ): + show_cbars = ("eyy", "theta") + else: + show_cbars = ("exx", "eyy", "exy", "theta") + else: + assert np.all([v in ("exx", "eyy", "exy", "theta") for v in show_cbars]) + + # Contrast limits + if vrange_exx is None: + vrange_exx = vrange + if vrange_exy is None: + vrange_exy = vrange + if vrange_eyy is None: + vrange_eyy = vrange + for vrange in (vrange_exx, vrange_eyy, vrange_exy, vrange_theta): + assert len(vrange) == 2, "vranges must have length 2" + vmin_exx, vmax_exx = vrange_exx[0] / 100.0, vrange_exx[1] / 100.0 + vmin_eyy, vmax_eyy = vrange_eyy[0] / 100.0, vrange_eyy[1] / 100.0 + vmin_exy, vmax_exy = vrange_exy[0] / 100.0, vrange_exy[1] / 100.0 + # theta is plotted in units of degrees + vmin_theta, vmax_theta = vrange_theta[0] / (180.0 / np.pi), vrange_theta[1] / ( + 180.0 / np.pi + ) + + # Get images + mask = data.get_slice("mask").data == False # noqa: E712,E501 + e_xx = np.ma.array(data.get_slice("e_xx").data, mask=mask) + e_yy = np.ma.array(data.get_slice("e_yy").data, mask=mask) + e_xy = np.ma.array(data.get_slice("e_xy").data, mask=mask) + theta = np.ma.array(data.get_slice("theta").data, mask=mask) + # e_xx = data.get_slice("e_xx").data + # e_yy = data.get_slice("e_yy").data + # e_xy = data.get_slice("e_xy").data + # theta = data.get_slice("theta").data + + ## Plot + + # if figsize hasn't been set, set it based on the + # chosen layout and the image shape + if figsize is None: + ratio = np.sqrt(e_xx.shape[1] / e_xx.shape[0]) + if layout == "square": + figsize = (13 * ratio, 8 / ratio) + elif layout == "horizontal": + figsize = (10 * ratio, 4 / ratio) + else: + figsize = (4 * ratio, 10 / ratio) + + # set up layout + if show_legend: + if layout == "square": + fig, ((ax11, ax12, ax_legend1), (ax21, ax22, ax_legend2)) = plt.subplots( + 2, 3, figsize=figsize + ) + elif layout == "horizontal": + figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 1, 5, figsize=figsize + ) + else: + figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22, ax_legend) = plt.subplots( + 5, 1, figsize=figsize + ) + else: + if layout == "square": + fig, ((ax11, ax12), (ax21, ax22)) = plt.subplots(2, 2, figsize=figsize) + elif layout == "horizontal": + figsize = (figsize[0] * np.sqrt(2), figsize[1] / np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(1, 4, figsize=figsize) + else: + figsize = (figsize[0] / np.sqrt(2), figsize[1] * np.sqrt(2)) + fig, (ax11, ax12, ax21, ax22) = plt.subplots(4, 1, figsize=figsize) + + # display images, returning cbar axis references + cax11 = show( + e_xx, + figax=(fig, ax11), + vmin=vmin_exx, + vmax=vmax_exx, + intensity_range="absolute", + cmap=cmap, + mask_color=mask_color, + returncax=True, + ) + cax12 = show( + e_yy, + figax=(fig, ax12), + vmin=vmin_eyy, + vmax=vmax_eyy, + intensity_range="absolute", + cmap=cmap, + mask_color=mask_color, + returncax=True, + ) + cax21 = show( + e_xy, + figax=(fig, ax21), + vmin=vmin_exy, + vmax=vmax_exy, + intensity_range="absolute", + cmap=cmap, + mask_color=mask_color, + returncax=True, + ) + cax22 = show( + theta, + figax=(fig, ax22), + vmin=vmin_theta, + vmax=vmax_theta, + intensity_range="absolute", + cmap=cmap_theta, + mask_color=mask_color, + returncax=True, + ) + ax11.set_title(r"$\epsilon_{xx}$", size=titlesize) + ax12.set_title(r"$\epsilon_{yy}$", size=titlesize) + ax21.set_title(r"$\epsilon_{xy}$", size=titlesize) + ax22.set_title(r"$\theta$", size=titlesize) + + # add colorbars + show_cbars = np.array( + [ + "exx" in show_cbars, + "eyy" in show_cbars, + "exy" in show_cbars, + "theta" in show_cbars, + ] + ) + if np.any(show_cbars): + divider11 = make_axes_locatable(ax11) + divider12 = make_axes_locatable(ax12) + divider21 = make_axes_locatable(ax21) + divider22 = make_axes_locatable(ax22) + cbax11 = divider11.append_axes(layout_p[0], size="4%", pad=0.15) + cbax12 = divider12.append_axes(layout_p[1], size="4%", pad=0.15) + cbax21 = divider21.append_axes(layout_p[2], size="4%", pad=0.15) + cbax22 = divider22.append_axes(layout_p[3], size="4%", pad=0.15) + for ind, show_cbar, cax, cbax, vmin, vmax, tickside, tickunits in zip( + range(4), + show_cbars, + (cax11, cax12, cax21, cax22), + (cbax11, cbax12, cbax21, cbax22), + (vmin_exx, vmin_eyy, vmin_exy, vmin_theta), + (vmax_exx, vmax_eyy, vmax_exy, vmax_theta), + (layout_p[0], layout_p[1], layout_p[2], layout_p[3]), + ("% ", " %", "% ", r" $^\circ$"), + ): + if show_cbar: + ticks = np.linspace(vmin, vmax, ticknumber, endpoint=True) + if ind < 3: + ticklabels = np.round( + np.linspace(100 * vmin, 100 * vmax, ticknumber, endpoint=True), + decimals=2, + ).astype(str) + else: + ticklabels = np.round( + np.linspace( + (180 / np.pi) * vmin, + (180 / np.pi) * vmax, + ticknumber, + endpoint=True, + ), + decimals=2, + ).astype(str) + + if tickside in ("left", "right"): + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="vertical" + ) + cb.ax.set_yticklabels(ticklabels, size=ticklabelsize) + cbax.yaxis.set_ticks_position(tickside) + cbax.set_ylabel(tickunits, size=unitlabelsize, rotation=0) + cbax.yaxis.set_label_position(tickside) + else: + cb = plt.colorbar( + cax, cax=cbax, ticks=ticks, orientation="horizontal" + ) + cb.ax.set_xticklabels(ticklabels, size=ticklabelsize) + cbax.xaxis.set_ticks_position(tickside) + cbax.set_xlabel(tickunits, size=unitlabelsize, rotation=0) + cbax.xaxis.set_label_position(tickside) + else: + cbax.axis("off") + + # Add borders + if bordercolor is not None: + for ax in (ax11, ax12, ax21, ax22): + for s in ["bottom", "top", "left", "right"]: + ax.spines[s].set_color(bordercolor) + ax.spines[s].set_linewidth(borderwidth) + ax.set_xticks([]) + ax.set_yticks([]) + + # Legend + if show_legend: + # for layout "square", combine vertical plots on the right end + if layout == "square": + # get gridspec object + gs = ax_legend1.get_gridspec() + # remove last two axes + ax_legend1.remove() + ax_legend2.remove() + # make new axis + ax_legend = fig.add_subplot(gs[:, -1]) + + # get the coordinate axes' directions + rotation = np.deg2rad(rotation_deg) + xaxis_vectx = np.cos(rotation) + xaxis_vecty = np.sin(rotation) + yaxis_vectx = np.cos(rotation + np.pi / 2) + yaxis_vecty = np.sin(rotation + np.pi / 2) + + # make the coordinate axes + ax_legend.arrow( + x=0, + y=0, + dx=xaxis_vecty, + dy=xaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.arrow( + x=0, + y=0, + dx=yaxis_vecty, + dy=yaxis_vectx, + color=color_axes, + length_includes_head=True, + width=0.01, + head_width=0.1, + ) + ax_legend.text( + x=xaxis_vecty * 1.16, + y=xaxis_vectx * 1.16, + s="x", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=yaxis_vecty * 1.16, + y=yaxis_vectx * 1.16, + s="y", + fontsize=14, + color=color_axes, + horizontalalignment="center", + verticalalignment="center", + ) + + # make the g-vectors + if show_gvects: + # get the g-vectors directions + g1q = np.array(g1) + g2q = np.array(g2) + g1norm = np.linalg.norm(g1q) + g2norm = np.linalg.norm(g2q) + g1q /= g1norm + g2q /= g2norm + # set the lengths + g_ratio = g2norm / g1norm + if g_ratio > 1: + g1q /= g_ratio + else: + g2q *= g_ratio + g1_x, g1_y = g1q + g2_x, g2_y = g2q + + # draw the g vectors + ax_legend.arrow( + x=0, + y=0, + dx=g1_y * scale_gvects, + dy=g1_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.arrow( + x=0, + y=0, + dx=g2_y * scale_gvects, + dy=g2_x * scale_gvects, + color=color_gvects, + length_includes_head=True, + width=0.005, + head_width=0.05, + ) + ax_legend.text( + x=g1_y * scale_gvects * 1.2, + y=g1_x * scale_gvects * 1.2, + s=r"$g_1$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + ax_legend.text( + x=g2_y * scale_gvects * 1.2, + y=g2_x * scale_gvects * 1.2, + s=r"$g_2$", + fontsize=12, + color=color_gvects, + horizontalalignment="center", + verticalalignment="center", + ) + + # find center and extent + xmin = np.min([0, 0, xaxis_vectx, yaxis_vectx]) + xmax = np.max([0, 0, xaxis_vectx, yaxis_vectx]) + ymin = np.min([0, 0, xaxis_vecty, yaxis_vecty]) + ymax = np.max([0, 0, xaxis_vecty, yaxis_vecty]) + if show_gvects: + xmin = np.min([xmin, g1_x, g2_x]) + xmax = np.max([xmax, g1_x, g2_x]) + ymin = np.min([ymin, g1_y, g2_y]) + ymax = np.max([ymax, g1_y, g2_y]) + x0 = np.mean([xmin, xmax]) + y0 = np.mean([ymin, ymax]) + xL = (xmax - x0) * legend_camera_length + yL = (ymax - y0) * legend_camera_length + + # set the extent and aspect + ax_legend.set_xlim([y0 - yL, y0 + yL]) + ax_legend.set_ylim([x0 - xL, x0 + xL]) + ax_legend.invert_yaxis() + ax_legend.set_aspect("equal") + ax_legend.axis("off") + + # show/return + if not returnfig: + plt.show() + return + else: + axs = ((ax11, ax12), (ax21, ax22)) + return fig, axs