Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit234e6ec

Browse files
committed
move code around to new locations
1 parent6406868 commit234e6ec

File tree

8 files changed

+281
-270
lines changed

8 files changed

+281
-270
lines changed

‎control/ctrlplot.py

Lines changed: 227 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,36 @@
55

66
fromos.pathimportcommonprefix
77

8+
importmatplotlibasmpl
89
importmatplotlib.pyplotasplt
910
importnumpyasnp
1011

1112
from .importconfig
1213

1314
__all__= ['suptitle','get_plot_axes']
1415

16+
#
17+
# Style parameters
18+
#
19+
20+
_ctrlplot_rcParams=mpl.rcParams.copy()
21+
_ctrlplot_rcParams.update({
22+
'axes.labelsize':'small',
23+
'axes.titlesize':'small',
24+
'figure.titlesize':'medium',
25+
'legend.fontsize':'x-small',
26+
'xtick.labelsize':'small',
27+
'ytick.labelsize':'small',
28+
})
29+
30+
31+
#
32+
# User functions
33+
#
34+
# The functions below can be used by users to modify ctrl plots or get
35+
# information about them.
36+
#
37+
1538

1639
defsuptitle(
1740
title,fig=None,frame='axes',**kwargs):
@@ -35,7 +58,7 @@ def suptitle(
3558
Additional keywords (passed to matplotlib).
3659
3760
"""
38-
rcParams=config._get_param('freqplot','rcParams',kwargs,pop=True)
61+
rcParams=config._get_param('ctrlplot','rcParams',kwargs,pop=True)
3962

4063
iffigisNone:
4164
fig=plt.gcf()
@@ -61,10 +84,10 @@ def suptitle(
6184
defget_plot_axes(line_array):
6285
"""Get a list of axes from an array of lines.
6386
64-
This function can be used to return the set of axes corresponding to
65-
the line array that is returned by `time_response_plot`. This is useful for
66-
generating an axes array that can be passed to subsequent plotting
67-
calls.
87+
This function can be used to return the set of axes corresponding
88+
tothe line array that is returned by `time_response_plot`. This
89+
is useful forgenerating an axes array that can be passed to
90+
subsequent plottingcalls.
6891
6992
Parameters
7093
----------
@@ -89,6 +112,125 @@ def get_plot_axes(line_array):
89112
#
90113
# Utility functions
91114
#
115+
# These functions are used by plotting routines to provide a consistent way
116+
# of processing and displaing information.
117+
#
118+
119+
120+
def_process_ax_keyword(
121+
axs,shape=(1,1),rcParams=None,squeeze=False,clear_text=False):
122+
"""Utility function to process ax keyword to plotting commands.
123+
124+
This function processes the `ax` keyword to plotting commands. If no
125+
ax keyword is passed, the current figure is checked to see if it has
126+
the correct shape. If the shape matches the desired shape, then the
127+
current figure and axes are returned. Otherwise a new figure is
128+
created with axes of the desired shape.
129+
130+
Legacy behavior: some of the older plotting commands use a axes label
131+
to identify the proper axes for plotting. This behavior is supported
132+
through the use of the label keyword, but will only work if shape ==
133+
(1, 1) and squeeze == True.
134+
135+
"""
136+
ifaxsisNone:
137+
fig=plt.gcf()# get current figure (or create new one)
138+
axs=fig.get_axes()
139+
140+
# Check to see if axes are the right shape; if not, create new figure
141+
# Note: can't actually check the shape, just the total number of axes
142+
iflen(axs)!=np.prod(shape):
143+
withplt.rc_context(rcParams):
144+
iflen(axs)!=0:
145+
# Create a new figure
146+
fig,axs=plt.subplots(*shape,squeeze=False)
147+
else:
148+
# Create new axes on (empty) figure
149+
axs=fig.subplots(*shape,squeeze=False)
150+
fig.set_layout_engine('tight')
151+
fig.align_labels()
152+
else:
153+
# Use the existing axes, properly reshaped
154+
axs=np.asarray(axs).reshape(*shape)
155+
156+
ifclear_text:
157+
# Clear out any old text from the current figure
158+
fortextinfig.texts:
159+
text.set_visible(False)# turn off the text
160+
deltext# get rid of it completely
161+
else:
162+
try:
163+
axs=np.asarray(axs).reshape(shape)
164+
exceptValueError:
165+
raiseValueError(
166+
"specified axes are not the right shape; "
167+
f"got{axs.shape} but expecting{shape}")
168+
fig=axs[0,0].figure
169+
170+
# Process the squeeze keyword
171+
ifsqueezeandshape== (1,1):
172+
axs=axs[0,0]# Just return the single axes object
173+
elifsqueeze:
174+
axs=axs.squeeze()
175+
176+
returnfig,axs
177+
178+
179+
# Turn label keyword into array indexed by trace, output, input
180+
# TODO: move to ctrlutil.py and update parameter names to reflect general use
181+
def_process_line_labels(label,ntraces,ninputs=0,noutputs=0):
182+
iflabelisNone:
183+
returnNone
184+
185+
ifisinstance(label,str):
186+
label= [label]*ntraces# single label for all traces
187+
188+
# Convert to an ndarray, if not done aleady
189+
try:
190+
line_labels=np.asarray(label)
191+
exceptValueError:
192+
raiseValueError("label must be a string or array_like")
193+
194+
# Turn the data into a 3D array of appropriate shape
195+
# TODO: allow more sophisticated broadcasting (and error checking)
196+
try:
197+
ifninputs>0andnoutputs>0:
198+
ifline_labels.ndim==1andline_labels.size==ntraces:
199+
line_labels=line_labels.reshape(ntraces,1,1)
200+
line_labels=np.broadcast_to(
201+
line_labels, (ntraces,ninputs,noutputs))
202+
else:
203+
line_labels=line_labels.reshape(ntraces,ninputs,noutputs)
204+
exceptValueError:
205+
ifline_labels.shape[0]!=ntraces:
206+
raiseValueError("number of labels must match number of traces")
207+
else:
208+
raiseValueError("labels must be given for each input/output pair")
209+
210+
returnline_labels
211+
212+
213+
# Get labels for all lines in an axes
214+
def_get_line_labels(ax,use_color=True):
215+
labels,lines= [], []
216+
last_color,counter=None,0# label unknown systems
217+
fori,lineinenumerate(ax.get_lines()):
218+
label=line.get_label()
219+
ifuse_colorandlabel.startswith("Unknown"):
220+
label=f"Unknown-{counter}"
221+
iflast_colorisNone:
222+
last_color=line.get_color()
223+
eliflast_color!=line.get_color():
224+
counter+=1
225+
last_color=line.get_color()
226+
eliflabel[0]=='_':
227+
continue
228+
229+
iflabelnotinlabels:
230+
lines.append(line)
231+
labels.append(label)
232+
233+
returnlines,labels
92234

93235

94236
# Utility function to make legend labels
@@ -160,3 +302,83 @@ def _find_axes_center(fig, axs):
160302
ylim= [min(ll[1],ylim[0]),max(ur[1],ylim[1])]
161303

162304
return (np.sum(xlim)/2,np.sum(ylim)/2)
305+
306+
307+
# Internal function to add arrows to a curve
308+
def_add_arrows_to_line2D(
309+
axes,line,arrow_locs=[0.2,0.4,0.6,0.8],
310+
arrowstyle='-|>',arrowsize=1,dir=1):
311+
"""
312+
Add arrows to a matplotlib.lines.Line2D at selected locations.
313+
314+
Parameters:
315+
-----------
316+
axes: Axes object as returned by axes command (or gca)
317+
line: Line2D object as returned by plot command
318+
arrow_locs: list of locations where to insert arrows, % of total length
319+
arrowstyle: style of the arrow
320+
arrowsize: size of the arrow
321+
322+
Returns:
323+
--------
324+
arrows: list of arrows
325+
326+
Based on https://stackoverflow.com/questions/26911898/
327+
328+
"""
329+
# Get the coordinates of the line, in plot coordinates
330+
ifnotisinstance(line,mpl.lines.Line2D):
331+
raiseValueError("expected a matplotlib.lines.Line2D object")
332+
x,y=line.get_xdata(),line.get_ydata()
333+
334+
# Determine the arrow properties
335+
arrow_kw= {"arrowstyle":arrowstyle}
336+
337+
color=line.get_color()
338+
use_multicolor_lines=isinstance(color,np.ndarray)
339+
ifuse_multicolor_lines:
340+
raiseNotImplementedError("multicolor lines not supported")
341+
else:
342+
arrow_kw['color']=color
343+
344+
linewidth=line.get_linewidth()
345+
ifisinstance(linewidth,np.ndarray):
346+
raiseNotImplementedError("multiwidth lines not supported")
347+
else:
348+
arrow_kw['linewidth']=linewidth
349+
350+
# Figure out the size of the axes (length of diagonal)
351+
xlim,ylim=axes.get_xlim(),axes.get_ylim()
352+
ul,lr=np.array([xlim[0],ylim[0]]),np.array([xlim[1],ylim[1]])
353+
diag=np.linalg.norm(ul-lr)
354+
355+
# Compute the arc length along the curve
356+
s=np.cumsum(np.sqrt(np.diff(x)**2+np.diff(y)**2))
357+
358+
# Truncate the number of arrows if the curve is short
359+
# TODO: figure out a smarter way to do this
360+
frac=min(s[-1]/diag,1)
361+
iflen(arrow_locs)andfrac<0.05:
362+
arrow_locs= []# too short; no arrows at all
363+
eliflen(arrow_locs)andfrac<0.2:
364+
arrow_locs= [0.5]# single arrow in the middle
365+
366+
# Plot the arrows (and return list if patches)
367+
arrows= []
368+
forlocinarrow_locs:
369+
n=np.searchsorted(s,s[-1]*loc)
370+
371+
ifdir==1andn==0:
372+
# Move the arrow forward by one if it is at start of a segment
373+
n=1
374+
375+
# Place the head of the arrow at the desired location
376+
arrow_head= [x[n],y[n]]
377+
arrow_tail= [x[n-dir],y[n-dir]]
378+
379+
p=mpl.patches.FancyArrowPatch(
380+
arrow_tail,arrow_head,transform=axes.transData,lw=0,
381+
**arrow_kw)
382+
axes.add_patch(p)
383+
arrows.append(p)
384+
returnarrows

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp