Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit0c1d638

Browse files
committed
slight code refactoring to consolidate flag matrix computation
1 parent1fe6e86 commit0c1d638

File tree

3 files changed

+69
-66
lines changed

3 files changed

+69
-66
lines changed

‎control/flatsys/flatsys.py

Lines changed: 53 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ def __init__(self,
155155
ifforwardisnotNone:self.forward=forward
156156
ifreverseisnotNone:self.reverse=reverse
157157

158+
# Save the length of the flat flag
159+
158160
defforward(self,x,u,params={}):
159161
"""Compute the flat flag given the states and input.
160162
@@ -217,10 +219,33 @@ def _flat_outfcn(self, t, x, u, params={}):
217219
returnnp.array(zflag[:][0])
218220

219221

222+
# Utility function to compute flag matrix given a basis
223+
def_basis_flag_matrix(sys,basis,flag,t,params={}):
224+
"""Compute the matrix of basis functions and their derivatives
225+
226+
This function computes the matrix ``M`` that is used to solve for the
227+
coefficients of the basis functions given the state and input. Each
228+
column of the matrix corresponds to a basis function and each row is a
229+
derivative, with the derivatives (flag) for each output stacked on top
230+
of each other.
231+
232+
"""
233+
flagshape= [len(f)forfinflag]
234+
M=np.zeros((sum(flagshape),basis.N*sys.ninputs))
235+
flag_off=0
236+
coeff_off=0
237+
fori,flag_leninenumerate(flagshape):
238+
forj,kinitertools.product(range(basis.N),range(flag_len)):
239+
M[flag_off+k,coeff_off+j]=basis.eval_deriv(j,k,t)
240+
flag_off+=flag_len
241+
coeff_off+=basis.N
242+
returnM
243+
244+
220245
# Solve a point to point trajectory generation problem for a flat system
221246
defpoint_to_point(
222247
sys,timepts,x0=0,u0=0,xf=0,uf=0,T0=0,basis=None,cost=None,
223-
constraints=None,initial_guess=None,minimize_kwargs={}):
248+
constraints=None,initial_guess=None,minimize_kwargs={},**kwargs):
224249
"""Compute trajectory between an initial and final conditions.
225250
226251
Compute a feasible trajectory for a differentially flat system between an
@@ -251,9 +276,9 @@ def point_to_point(
251276
252277
basis : :class:`~control.flatsys.BasisFamily` object, optional
253278
The basis functions to use for generating the trajectory. If not
254-
specified, the :class:`~control.flatsys.PolyFamily` basis family will be
255-
used, with the minimal number of elements required to find a feasible
256-
trajectory (twice the number of system states)
279+
specified, the :class:`~control.flatsys.PolyFamily` basis family
280+
will beused, with the minimal number of elements required to find a
281+
feasibletrajectory (twice the number of system states)
257282
258283
cost : callable
259284
Function that returns the integral cost given the current state
@@ -287,6 +312,12 @@ def point_to_point(
287312
`eval()` function, we can be used to compute the value of the state
288313
and input and a given time t.
289314
315+
Notes
316+
-----
317+
Additional keyword parameters can be used to fine tune the behavior of
318+
the underlying optimization function. See `minimize_*` keywords in
319+
:func:`OptimalControlProblem` for more information.
320+
290321
"""
291322
#
292323
# Make sure the problem is one that we can handle
@@ -296,7 +327,7 @@ def point_to_point(
296327
u0=_check_convert_array(u0, [(sys.ninputs,), (sys.ninputs,1)],
297328
'Initial input: ',squeeze=True)
298329
xf=_check_convert_array(xf, [(sys.nstates,), (sys.nstates,1)],
299-
'Final state: ',squeeze=True)
330+
'Final state: ',squeeze=True)
300331
uf=_check_convert_array(uf, [(sys.ninputs,), (sys.ninputs,1)],
301332
'Final input: ',squeeze=True)
302333

@@ -305,6 +336,12 @@ def point_to_point(
305336
Tf=timepts[-1]
306337
T0=timepts[0]iflen(timepts)>1elseT0
307338

339+
# Process keyword arguments
340+
minimize_kwargs['method']=kwargs.pop('minimize_method',None)
341+
minimize_kwargs['options']=kwargs.pop('minimize_options', {})
342+
ifkwargs:
343+
raiseTypeError("unrecognized keywords: ",str(kwargs))
344+
308345
#
309346
# Determine the basis function set to use and make sure it is big enough
310347
#
@@ -328,8 +365,7 @@ def point_to_point(
328365
# We need to compute the output "flag": [z(t), z'(t), z''(t), ...]
329366
# and then evaluate this at the initial and final condition.
330367
#
331-
# TODO: should be able to represent flag variables as 1D arrays
332-
# TODO: need inputs to fully define the flag
368+
333369
zflag_T0=sys.forward(x0,u0)
334370
zflag_Tf=sys.forward(xf,uf)
335371

@@ -340,41 +376,13 @@ def point_to_point(
340376
# essentially amounts to evaluating the basis functions and their
341377
# derivatives at the initial and final conditions.
342378

343-
# Figure out the size of the problem we are solving
344-
flag_tot=np.sum([len(zflag_T0[i])foriinrange(sys.ninputs)])
379+
# Compute the flags for the initial and final states
380+
M_T0=_basis_flag_matrix(sys,basis,zflag_T0,T0)
381+
M_Tf=_basis_flag_matrix(sys,basis,zflag_Tf,Tf)
345382

346-
# Start by creating an empty matrix that we can fill up
347-
# TODO: allow a different number of basis elements for each flat output
348-
M=np.zeros((2*flag_tot,basis.N*sys.ninputs))
349-
350-
# Now fill in the rows for the initial and final states
351-
# TODO: vectorize
352-
flag_off=0
353-
coeff_off=0
354-
355-
foriinrange(sys.ninputs):
356-
flag_len=len(zflag_T0[i])
357-
forjinrange(basis.N):
358-
forkinrange(flag_len):
359-
M[flag_off+k,coeff_off+j]=basis.eval_deriv(j,k,T0)
360-
M[flag_tot+flag_off+k,coeff_off+j]= \
361-
basis.eval_deriv(j,k,Tf)
362-
flag_off+=flag_len
363-
coeff_off+=basis.N
364-
365-
# Create an empty matrix that we can fill up
366-
Z=np.zeros(2*flag_tot)
367-
368-
# Compute the flag vector to use for the right hand side by
369-
# stacking up the flags for each input
370-
# TODO: make this more pythonic
371-
flag_off=0
372-
foriinrange(sys.ninputs):
373-
flag_len=len(zflag_T0[i])
374-
forjinrange(flag_len):
375-
Z[flag_off+j]=zflag_T0[i][j]
376-
Z[flag_tot+flag_off+j]=zflag_Tf[i][j]
377-
flag_off+=flag_len
383+
# Stack the initial and final matrix/flag for the point to point problem
384+
M=np.vstack([M_T0,M_Tf])
385+
Z=np.hstack([np.hstack(zflag_T0),np.hstack(zflag_Tf)])
378386

379387
#
380388
# Solve for the coefficients of the flat outputs
@@ -404,17 +412,7 @@ def traj_cost(null_coeffs):
404412
# Evaluate the costs at the listed time points
405413
costval=0
406414
fortintimepts:
407-
M_t=np.zeros((flag_tot,basis.N*sys.ninputs))
408-
flag_off=0
409-
coeff_off=0
410-
foriinrange(sys.ninputs):
411-
flag_len=len(zflag_T0[i])
412-
forj,kinitertools.product(
413-
range(basis.N),range(flag_len)):
414-
M_t[flag_off+k,coeff_off+j]= \
415-
basis.eval_deriv(j,k,t)
416-
flag_off+=flag_len
417-
coeff_off+=basis.N
415+
M_t=_basis_flag_matrix(sys,basis,zflag_T0,t)
418416

419417
# Compute flag at this time point
420418
zflag= (M_t @coeffs).reshape(sys.ninputs,-1)
@@ -452,17 +450,7 @@ def traj_const(null_coeffs):
452450
values= []
453451
fori,tinenumerate(timepts):
454452
# Calculate the states and inputs for the flat output
455-
M_t=np.zeros((flag_tot,basis.N*sys.ninputs))
456-
flag_off=0
457-
coeff_off=0
458-
foriinrange(sys.ninputs):
459-
flag_len=len(zflag_T0[i])
460-
forj,kinitertools.product(
461-
range(basis.N),range(flag_len)):
462-
M_t[flag_off+k,coeff_off+j]= \
463-
basis.eval_deriv(j,k,t)
464-
flag_off+=flag_len
465-
coeff_off+=basis.N
453+
M_t=_basis_flag_matrix(sys,basis,zflag_T0,t)
466454

467455
# Compute flag at this time point
468456
zflag= (M_t @coeffs).reshape(sys.ninputs,-1)
@@ -501,7 +489,7 @@ def traj_const(null_coeffs):
501489

502490
# Process the initial condition
503491
ifinitial_guessisNone:
504-
initial_guess=np.zeros(basis.N*sys.ninputs-2*flag_tot)
492+
initial_guess=np.zeros(M.shape[1]-M.shape[0])
505493
else:
506494
raiseNotImplementedError("Initial guess not yet implemented.")
507495

@@ -514,7 +502,7 @@ def traj_const(null_coeffs):
514502
else:
515503
raiseRuntimeError(
516504
"Unable to solve optimal control problem\n"+
517-
"scipy.optimize.minimize returned "+res.message)
505+
"scipy.optimize.minimize returned "+res.message)
518506

519507
#
520508
# Transform the trajectory from flat outputs to states and inputs

‎control/tests/flatsys_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,18 @@ def test_point_to_point_errors(self):
331331
traj=fs.point_to_point(
332332
flat_sys,timepts,x0,u0,xf,uf,constraints=constraint,
333333
basis=fs.PolyFamily(8))
334+
335+
# Method arguments, parameters
336+
traj_method=fs.point_to_point(
337+
flat_sys,timepts,x0,u0,xf,uf,cost=cost_fcn,
338+
basis=fs.PolyFamily(8),minimize_method='slsqp')
339+
traj_kwarg=fs.point_to_point(
340+
flat_sys,timepts,x0,u0,xf,uf,cost=cost_fcn,
341+
basis=fs.PolyFamily(8),minimize_kwargs={'method':'slsqp'})
342+
np.testing.assert_almost_equal(
343+
traj_method.eval(timepts)[0],traj_kwarg.eval(timepts)[0])
344+
345+
# Unrecognized keywords
346+
withpytest.raises(TypeError,match="unrecognized keyword"):
347+
traj_method=fs.point_to_point(
348+
flat_sys,timepts,x0,u0,xf,uf,solve_ivp_method=None)

‎doc/flatsys.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
. _flatsys-module:
1+
.. _flatsys-module:
22

33
***************************
44
Differentially flat systems

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp