5
5
6
6
from os .path import commonprefix
7
7
8
+ import matplotlib as mpl
8
9
import matplotlib .pyplot as plt
9
10
import numpy as np
10
11
11
12
from .import config
12
13
13
14
__all__ = ['suptitle' ,'get_plot_axes' ]
14
15
16
+ #
17
+ # Style parameters
18
+ #
19
+
20
+ _ctrlplot_rcParams = mpl .rcParams .copy ()
21
+ _ctrlplot_rcParams .update ({
22
+ 'axes.labelsize' :'small' ,
23
+ 'axes.titlesize' :'small' ,
24
+ 'figure.titlesize' :'medium' ,
25
+ 'legend.fontsize' :'x-small' ,
26
+ 'xtick.labelsize' :'small' ,
27
+ 'ytick.labelsize' :'small' ,
28
+ })
29
+
30
+
31
+ #
32
+ # User functions
33
+ #
34
+ # The functions below can be used by users to modify ctrl plots or get
35
+ # information about them.
36
+ #
37
+
15
38
16
39
def suptitle (
17
40
title ,fig = None ,frame = 'axes' ,** kwargs ):
@@ -35,7 +58,7 @@ def suptitle(
35
58
Additional keywords (passed to matplotlib).
36
59
37
60
"""
38
- rcParams = config ._get_param ('freqplot ' ,'rcParams' ,kwargs ,pop = True )
61
+ rcParams = config ._get_param ('ctrlplot ' ,'rcParams' ,kwargs ,pop = True )
39
62
40
63
if fig is None :
41
64
fig = plt .gcf ()
@@ -61,10 +84,10 @@ def suptitle(
61
84
def get_plot_axes (line_array ):
62
85
"""Get a list of axes from an array of lines.
63
86
64
- This function can be used to return the set of axes corresponding to
65
- the line array that is returned by `time_response_plot`. This is useful for
66
- generating an axes array that can be passed to subsequent plotting
67
- calls.
87
+ This function can be used to return the set of axes corresponding
88
+ to the line array that is returned by `time_response_plot`. This
89
+ is useful for generating an axes array that can be passed to
90
+ subsequent plotting calls.
68
91
69
92
Parameters
70
93
----------
@@ -89,6 +112,125 @@ def get_plot_axes(line_array):
89
112
#
90
113
# Utility functions
91
114
#
115
+ # These functions are used by plotting routines to provide a consistent way
116
+ # of processing and displaing information.
117
+ #
118
+
119
+
120
+ def _process_ax_keyword (
121
+ axs ,shape = (1 ,1 ),rcParams = None ,squeeze = False ,clear_text = False ):
122
+ """Utility function to process ax keyword to plotting commands.
123
+
124
+ This function processes the `ax` keyword to plotting commands. If no
125
+ ax keyword is passed, the current figure is checked to see if it has
126
+ the correct shape. If the shape matches the desired shape, then the
127
+ current figure and axes are returned. Otherwise a new figure is
128
+ created with axes of the desired shape.
129
+
130
+ Legacy behavior: some of the older plotting commands use a axes label
131
+ to identify the proper axes for plotting. This behavior is supported
132
+ through the use of the label keyword, but will only work if shape ==
133
+ (1, 1) and squeeze == True.
134
+
135
+ """
136
+ if axs is None :
137
+ fig = plt .gcf ()# get current figure (or create new one)
138
+ axs = fig .get_axes ()
139
+
140
+ # Check to see if axes are the right shape; if not, create new figure
141
+ # Note: can't actually check the shape, just the total number of axes
142
+ if len (axs )!= np .prod (shape ):
143
+ with plt .rc_context (rcParams ):
144
+ if len (axs )!= 0 :
145
+ # Create a new figure
146
+ fig ,axs = plt .subplots (* shape ,squeeze = False )
147
+ else :
148
+ # Create new axes on (empty) figure
149
+ axs = fig .subplots (* shape ,squeeze = False )
150
+ fig .set_layout_engine ('tight' )
151
+ fig .align_labels ()
152
+ else :
153
+ # Use the existing axes, properly reshaped
154
+ axs = np .asarray (axs ).reshape (* shape )
155
+
156
+ if clear_text :
157
+ # Clear out any old text from the current figure
158
+ for text in fig .texts :
159
+ text .set_visible (False )# turn off the text
160
+ del text # get rid of it completely
161
+ else :
162
+ try :
163
+ axs = np .asarray (axs ).reshape (shape )
164
+ except ValueError :
165
+ raise ValueError (
166
+ "specified axes are not the right shape; "
167
+ f"got{ axs .shape } but expecting{ shape } " )
168
+ fig = axs [0 ,0 ].figure
169
+
170
+ # Process the squeeze keyword
171
+ if squeeze and shape == (1 ,1 ):
172
+ axs = axs [0 ,0 ]# Just return the single axes object
173
+ elif squeeze :
174
+ axs = axs .squeeze ()
175
+
176
+ return fig ,axs
177
+
178
+
179
+ # Turn label keyword into array indexed by trace, output, input
180
+ # TODO: move to ctrlutil.py and update parameter names to reflect general use
181
+ def _process_line_labels (label ,ntraces ,ninputs = 0 ,noutputs = 0 ):
182
+ if label is None :
183
+ return None
184
+
185
+ if isinstance (label ,str ):
186
+ label = [label ]* ntraces # single label for all traces
187
+
188
+ # Convert to an ndarray, if not done aleady
189
+ try :
190
+ line_labels = np .asarray (label )
191
+ except ValueError :
192
+ raise ValueError ("label must be a string or array_like" )
193
+
194
+ # Turn the data into a 3D array of appropriate shape
195
+ # TODO: allow more sophisticated broadcasting (and error checking)
196
+ try :
197
+ if ninputs > 0 and noutputs > 0 :
198
+ if line_labels .ndim == 1 and line_labels .size == ntraces :
199
+ line_labels = line_labels .reshape (ntraces ,1 ,1 )
200
+ line_labels = np .broadcast_to (
201
+ line_labels , (ntraces ,ninputs ,noutputs ))
202
+ else :
203
+ line_labels = line_labels .reshape (ntraces ,ninputs ,noutputs )
204
+ except ValueError :
205
+ if line_labels .shape [0 ]!= ntraces :
206
+ raise ValueError ("number of labels must match number of traces" )
207
+ else :
208
+ raise ValueError ("labels must be given for each input/output pair" )
209
+
210
+ return line_labels
211
+
212
+
213
+ # Get labels for all lines in an axes
214
+ def _get_line_labels (ax ,use_color = True ):
215
+ labels ,lines = [], []
216
+ last_color ,counter = None ,0 # label unknown systems
217
+ for i ,line in enumerate (ax .get_lines ()):
218
+ label = line .get_label ()
219
+ if use_color and label .startswith ("Unknown" ):
220
+ label = f"Unknown-{ counter } "
221
+ if last_color is None :
222
+ last_color = line .get_color ()
223
+ elif last_color != line .get_color ():
224
+ counter += 1
225
+ last_color = line .get_color ()
226
+ elif label [0 ]== '_' :
227
+ continue
228
+
229
+ if label not in labels :
230
+ lines .append (line )
231
+ labels .append (label )
232
+
233
+ return lines ,labels
92
234
93
235
94
236
# Utility function to make legend labels
@@ -160,3 +302,83 @@ def _find_axes_center(fig, axs):
160
302
ylim = [min (ll [1 ],ylim [0 ]),max (ur [1 ],ylim [1 ])]
161
303
162
304
return (np .sum (xlim )/ 2 ,np .sum (ylim )/ 2 )
305
+
306
+
307
+ # Internal function to add arrows to a curve
308
+ def _add_arrows_to_line2D (
309
+ axes ,line ,arrow_locs = [0.2 ,0.4 ,0.6 ,0.8 ],
310
+ arrowstyle = '-|>' ,arrowsize = 1 ,dir = 1 ):
311
+ """
312
+ Add arrows to a matplotlib.lines.Line2D at selected locations.
313
+
314
+ Parameters:
315
+ -----------
316
+ axes: Axes object as returned by axes command (or gca)
317
+ line: Line2D object as returned by plot command
318
+ arrow_locs: list of locations where to insert arrows, % of total length
319
+ arrowstyle: style of the arrow
320
+ arrowsize: size of the arrow
321
+
322
+ Returns:
323
+ --------
324
+ arrows: list of arrows
325
+
326
+ Based on https://stackoverflow.com/questions/26911898/
327
+
328
+ """
329
+ # Get the coordinates of the line, in plot coordinates
330
+ if not isinstance (line ,mpl .lines .Line2D ):
331
+ raise ValueError ("expected a matplotlib.lines.Line2D object" )
332
+ x ,y = line .get_xdata (),line .get_ydata ()
333
+
334
+ # Determine the arrow properties
335
+ arrow_kw = {"arrowstyle" :arrowstyle }
336
+
337
+ color = line .get_color ()
338
+ use_multicolor_lines = isinstance (color ,np .ndarray )
339
+ if use_multicolor_lines :
340
+ raise NotImplementedError ("multicolor lines not supported" )
341
+ else :
342
+ arrow_kw ['color' ]= color
343
+
344
+ linewidth = line .get_linewidth ()
345
+ if isinstance (linewidth ,np .ndarray ):
346
+ raise NotImplementedError ("multiwidth lines not supported" )
347
+ else :
348
+ arrow_kw ['linewidth' ]= linewidth
349
+
350
+ # Figure out the size of the axes (length of diagonal)
351
+ xlim ,ylim = axes .get_xlim (),axes .get_ylim ()
352
+ ul ,lr = np .array ([xlim [0 ],ylim [0 ]]),np .array ([xlim [1 ],ylim [1 ]])
353
+ diag = np .linalg .norm (ul - lr )
354
+
355
+ # Compute the arc length along the curve
356
+ s = np .cumsum (np .sqrt (np .diff (x )** 2 + np .diff (y )** 2 ))
357
+
358
+ # Truncate the number of arrows if the curve is short
359
+ # TODO: figure out a smarter way to do this
360
+ frac = min (s [- 1 ]/ diag ,1 )
361
+ if len (arrow_locs )and frac < 0.05 :
362
+ arrow_locs = []# too short; no arrows at all
363
+ elif len (arrow_locs )and frac < 0.2 :
364
+ arrow_locs = [0.5 ]# single arrow in the middle
365
+
366
+ # Plot the arrows (and return list if patches)
367
+ arrows = []
368
+ for loc in arrow_locs :
369
+ n = np .searchsorted (s ,s [- 1 ]* loc )
370
+
371
+ if dir == 1 and n == 0 :
372
+ # Move the arrow forward by one if it is at start of a segment
373
+ n = 1
374
+
375
+ # Place the head of the arrow at the desired location
376
+ arrow_head = [x [n ],y [n ]]
377
+ arrow_tail = [x [n - dir ],y [n - dir ]]
378
+
379
+ p = mpl .patches .FancyArrowPatch (
380
+ arrow_tail ,arrow_head ,transform = axes .transData ,lw = 0 ,
381
+ ** arrow_kw )
382
+ axes .add_patch (p )
383
+ arrows .append (p )
384
+ return arrows