Expand Up @@ -5,13 +5,36 @@ from os.path import commonprefix import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np from . import config __all__ = ['suptitle', 'get_plot_axes'] # # Style parameters # _ctrlplot_rcParams = mpl.rcParams.copy() _ctrlplot_rcParams.update({ 'axes.labelsize': 'small', 'axes.titlesize': 'small', 'figure.titlesize': 'medium', 'legend.fontsize': 'x-small', 'xtick.labelsize': 'small', 'ytick.labelsize': 'small', }) # # User functions # # The functions below can be used by users to modify ctrl plots or get # information about them. # def suptitle( title, fig=None, frame='axes', **kwargs): Expand All @@ -35,7 +58,7 @@ def suptitle( Additional keywords (passed to matplotlib). """ rcParams = config._get_param('freqplot ', 'rcParams', kwargs, pop=True) rcParams = config._get_param('ctrlplot ', 'rcParams', kwargs, pop=True) if fig is None: fig = plt.gcf() Expand All @@ -61,10 +84,10 @@ def suptitle( def get_plot_axes(line_array): """Get a list of axes from an array of lines. This function can be used to return the set of axes corresponding to the line array that is returned by `time_response_plot`. This is useful for generating an axes array that can be passed to subsequent plotting calls. This function can be used to return the set of axes corresponding to the line array that is returned by `time_response_plot`. Thisis useful for generating an axes array that can be passed tosubsequent plotting calls. Parameters ---------- Expand All @@ -89,6 +112,125 @@ def get_plot_axes(line_array): # # Utility functions # # These functions are used by plotting routines to provide a consistent way # of processing and displaying information. # def _process_ax_keyword( axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False): """Utility function to process ax keyword to plotting commands. This function processes the `ax` keyword to plotting commands. If no ax keyword is passed, the current figure is checked to see if it has the correct shape. If the shape matches the desired shape, then the current figure and axes are returned. Otherwise a new figure is created with axes of the desired shape. Legacy behavior: some of the older plotting commands use a axes label to identify the proper axes for plotting. This behavior is supported through the use of the label keyword, but will only work if shape == (1, 1) and squeeze == True. """ if axs is None: fig = plt.gcf() # get current figure (or create new one) axs = fig.get_axes() # Check to see if axes are the right shape; if not, create new figure # Note: can't actually check the shape, just the total number of axes if len(axs) != np.prod(shape): with plt.rc_context(rcParams): if len(axs) != 0: # Create a new figure fig, axs = plt.subplots(*shape, squeeze=False) else: # Create new axes on (empty) figure axs = fig.subplots(*shape, squeeze=False) fig.set_layout_engine('tight') fig.align_labels() else: # Use the existing axes, properly reshaped axs = np.asarray(axs).reshape(*shape) if clear_text: # Clear out any old text from the current figure for text in fig.texts: text.set_visible(False) # turn off the text del text # get rid of it completely else: try: axs = np.asarray(axs).reshape(shape) except ValueError: raise ValueError( "specified axes are not the right shape; " f"got {axs.shape} but expecting {shape}") fig = axs[0, 0].figure # Process the squeeze keyword if squeeze and shape == (1, 1): axs = axs[0, 0] # Just return the single axes object elif squeeze: axs = axs.squeeze() return fig, axs # Turn label keyword into array indexed by trace, output, input # TODO: move to ctrlutil.py and update parameter names to reflect general use def _process_line_labels(label, ntraces, ninputs=0, noutputs=0): if label is None: return None if isinstance(label, str): label = [label] * ntraces # single label for all traces # Convert to an ndarray, if not done aleady try: line_labels = np.asarray(label) except ValueError: raise ValueError("label must be a string or array_like") # Turn the data into a 3D array of appropriate shape # TODO: allow more sophisticated broadcasting (and error checking) try: if ninputs > 0 and noutputs > 0: if line_labels.ndim == 1 and line_labels.size == ntraces: line_labels = line_labels.reshape(ntraces, 1, 1) line_labels = np.broadcast_to( line_labels, (ntraces, ninputs, noutputs)) else: line_labels = line_labels.reshape(ntraces, ninputs, noutputs) except ValueError: if line_labels.shape[0] != ntraces: raise ValueError("number of labels must match number of traces") else: raise ValueError("labels must be given for each input/output pair") return line_labels # Get labels for all lines in an axes def _get_line_labels(ax, use_color=True): labels, lines = [], [] last_color, counter = None, 0 # label unknown systems for i, line in enumerate(ax.get_lines()): label = line.get_label() if use_color and label.startswith("Unknown"): label = f"Unknown-{counter}" if last_color is None: last_color = line.get_color() elif last_color != line.get_color(): counter += 1 last_color = line.get_color() elif label[0] == '_': continue if label not in labels: lines.append(line) labels.append(label) return lines, labels # Utility function to make legend labels Expand Down Expand Up @@ -160,3 +302,83 @@ def _find_axes_center(fig, axs): ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])] return (np.sum(xlim)/2, np.sum(ylim)/2) # Internal function to add arrows to a curve def _add_arrows_to_line2D( axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8], arrowstyle='-|>', arrowsize=1, dir=1): """ Add arrows to a matplotlib.lines.Line2D at selected locations. Parameters: ----------- axes: Axes object as returned by axes command (or gca) line: Line2D object as returned by plot command arrow_locs: list of locations where to insert arrows, % of total length arrowstyle: style of the arrow arrowsize: size of the arrow Returns: -------- arrows: list of arrows Based on https://stackoverflow.com/questions/26911898/ """ # Get the coordinates of the line, in plot coordinates if not isinstance(line, mpl.lines.Line2D): raise ValueError("expected a matplotlib.lines.Line2D object") x, y = line.get_xdata(), line.get_ydata() # Determine the arrow properties arrow_kw = {"arrowstyle": arrowstyle} color = line.get_color() use_multicolor_lines = isinstance(color, np.ndarray) if use_multicolor_lines: raise NotImplementedError("multicolor lines not supported") else: arrow_kw['color'] = color linewidth = line.get_linewidth() if isinstance(linewidth, np.ndarray): raise NotImplementedError("multiwidth lines not supported") else: arrow_kw['linewidth'] = linewidth # Figure out the size of the axes (length of diagonal) xlim, ylim = axes.get_xlim(), axes.get_ylim() ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]]) diag = np.linalg.norm(ul - lr) # Compute the arc length along the curve s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2)) # Truncate the number of arrows if the curve is short # TODO: figure out a smarter way to do this frac = min(s[-1] / diag, 1) if len(arrow_locs) and frac < 0.05: arrow_locs = [] # too short; no arrows at all elif len(arrow_locs) and frac < 0.2: arrow_locs = [0.5] # single arrow in the middle # Plot the arrows (and return list if patches) arrows = [] for loc in arrow_locs: n = np.searchsorted(s, s[-1] * loc) if dir == 1 and n == 0: # Move the arrow forward by one if it is at start of a segment n = 1 # Place the head of the arrow at the desired location arrow_head = [x[n], y[n]] arrow_tail = [x[n - dir], y[n - dir]] p = mpl.patches.FancyArrowPatch( arrow_tail, arrow_head, transform=axes.transData, lw=0, **arrow_kw) axes.add_patch(p) arrows.append(p) return arrows