Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commita3b5ba0

Browse files
Merge pull request#4004 from AayushSabharwal/as/v10-modular-tearing
refactor: modularize tearing
2 parents8d6be1e +79da748 commita3b5ba0

File tree

8 files changed

+138
-61
lines changed

8 files changed

+138
-61
lines changed

‎src/structural_transformation/StructuralTransformations.jl‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ using DocStringExtensions
5959
export tearing, dae_index_lowering, check_consistency
6060
export dummy_derivative
6161
export sorted_incidence_matrix,
62-
pantelides!, pantelides_reassemble,tearing_reassemble,find_solvables!,
62+
pantelides!, pantelides_reassemble, find_solvables!,
6363
linear_subsys_adjmat!
6464
export tearing_substitution
6565
export torn_system_jacobian_sparsity
@@ -69,9 +69,9 @@ export computed_highest_diff_variables
6969
export shift2term, lower_shift_varname, simplify_shifts, distribute_shift
7070

7171
include("utils.jl")
72+
include("tearing.jl")
7273
include("pantelides.jl")
7374
include("bipartite_tearing/modia_tearing.jl")
74-
include("tearing.jl")
7575
include("symbolics_tearing.jl")
7676
include("partial_state_selection.jl")
7777
include("codegen.jl")

‎src/structural_transformation/bipartite_tearing/modia_tearing.jl‎

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,19 +62,22 @@ function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars
6262
returnnothing
6363
end
6464

65-
functionbuild_var_eq_matching(structure::SystemStructure,::Type{U}= Unassigned;
66-
varfilter::F2= v->true, eqfilter::F3= eq->true)where {U,F2, F3}
65+
functionbuild_var_eq_matching(structure::SystemStructure;
66+
varfilter::F2, eqfilter::F3)where {F2, F3}
6767
@unpack graph, solvable_graph= structure
68-
var_eq_matching=maximal_matching(graph, eqfilter, varfilter,U)
68+
var_eq_matching=maximal_matching(graph, eqfilter, varfilter,MatchedVarT)
6969
matching_len=max(length(var_eq_matching),
7070
maximum(x-> xisa Int? x:0, var_eq_matching, init=0))
7171
returncomplete(var_eq_matching, matching_len), matching_len
7272
end
7373

74-
functiontear_graph_modia(structure::SystemStructure, isder::F=nothing,
75-
::Type{U}= Unassigned;
76-
varfilter::F2= v->true,
77-
eqfilter::F3= eq->true)where {F, U, F2, F3}
74+
@kwdefstruct ModiaTearing{F, F2, F3}
75+
isder::F=nothing
76+
varfilter::F2=Returns(true)
77+
eqfilter::F3=Returns(true)
78+
end
79+
80+
function (alg::ModiaTearing)(structure::SystemStructure)
7881
# It would be possible here to simply iterate over all variables and attempt to
7982
# use tearEquations! to produce a matching that greedily selects the minimal
8083
# number of torn variables. However, we can do this process faster if we first
@@ -86,8 +89,11 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
8689
# to have optimal solutions that cannot be found by this process. We will not
8790
# find them here [TODO: It would be good to have an explicit example of this.]
8891

92+
isder= alg.isder
93+
varfilter= alg.varfilter
94+
eqfilter= alg.eqfilter
8995
@unpack graph, solvable_graph= structure
90-
var_eq_matching, matching_len=build_var_eq_matching(structure, U; varfilter, eqfilter)
96+
var_eq_matching, matching_len=build_var_eq_matching(structure; varfilter, eqfilter)
9197
full_var_eq_matching=copy(var_eq_matching)
9298
var_sccs=find_var_sccs(graph, var_eq_matching)
9399
vargraph=DiCMOBiGraph{true}(graph,0,Matching(matching_len))
@@ -126,5 +132,6 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing,
126132
tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, free_eqs,
127133
BitSet(free_vars), isder)
128134
end
129-
return var_eq_matching, full_var_eq_matching, var_sccs
135+
136+
returnTearingResult(var_eq_matching, full_var_eq_matching, var_sccs), (;)
130137
end

‎src/structural_transformation/partial_state_selection.jl‎

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
1-
struct SelectedStateend
2-
31
functiondummy_derivative_graph!(state::TransformationState, jac=nothing;
42
state_priority=nothing, log=Val(false), kwargs...)
53
state.structure.solvable_graph===nothing&&find_solvables!(state; kwargs...)
64
complete!(state.structure)
75
var_eq_matching=complete(pantelides!(state; kwargs...))
8-
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log)
6+
dummy_derivative_graph!(state.structure, var_eq_matching, jac, state_priority, log; kwargs...)
97
end
108

119
struct DummyDerivativeSummary
@@ -15,7 +13,8 @@ end
1513

1614
functiondummy_derivative_graph!(
1715
structure::SystemStructure, var_eq_matching, jac=nothing,
18-
state_priority=nothing,::Val{log}=Val(false))where {log}
16+
state_priority=nothing,::Val{log}=Val(false);
17+
tearing_alg::TearingAlgorithm=DummyDerivativeTearing(), kwargs...)where {log}
1918
@unpack eq_to_diff, var_to_diff, graph= structure
2019
diff_to_eq=invview(eq_to_diff)
2120
diff_to_var=invview(var_to_diff)
@@ -173,8 +172,9 @@ function dummy_derivative_graph!(
173172
@warn"The number of dummy derivatives ($n_dummys) does not match the number of differentiated equations ($n_diff_eqs)."
174173
end
175174

176-
ret=tearing_with_dummy_derivatives(structure,BitSet(dummy_derivatives))
177-
(ret...,DummyDerivativeSummary(var_dummy_scc, var_state_priority))
175+
tearing_result, extra=tearing_alg(structure,BitSet(dummy_derivatives))
176+
extra= (; extra..., ddsummary=DummyDerivativeSummary(var_dummy_scc, var_state_priority))
177+
return tearing_result, extra
178178
end
179179

180180
functionis_present(structure, v)::Bool
@@ -201,7 +201,9 @@ function isdiffed((structure, dummy_derivatives), v)::Bool
201201
diff_to_var[v]!==nothing&&is_some_diff(structure, dummy_derivatives, v)
202202
end
203203

204-
functiontearing_with_dummy_derivatives(structure, dummy_derivatives)
204+
struct DummyDerivativeTearing<:TearingAlgorithmend
205+
206+
function (::DummyDerivativeTearing)(structure::SystemStructure, dummy_derivatives::Union{BitSet, Tuple{}}= ())
205207
@unpack var_to_diff= structure
206208
# We can eliminate variables that are not selected (differential
207209
# variables). Selected unknowns are differentiated variables that are not
@@ -213,18 +215,18 @@ function tearing_with_dummy_derivatives(structure, dummy_derivatives)
213215
can_eliminate[v]=true
214216
end
215217
end
216-
var_eq_matching, full_var_eq_matching,
217-
var_sccs=tear_graph_modia(structure,
218-
Base.Fix1(isdiffed, (structure, dummy_derivatives)),
219-
Union{Unassigned, SelectedState};
220-
varfilter= Base.Fix1(getindex, can_eliminate))
218+
modia_tearing=ModiaTearing(;
219+
isder= Base.Fix1(isdiffed, (structure, dummy_derivatives)),
220+
varfilter=Base.Fix1(getindex, can_eliminate)
221+
)
222+
tearing_result, _=modia_tearing(structure)
221223

222224
for vin𝑑vertices(structure.graph)
223225
is_present(structure, v)||continue
224226
dv= var_to_diff[v]
225227
(dv===nothing||!is_some_diff(structure, dummy_derivatives, dv))&&continue
226-
var_eq_matching[v]=SelectedState()
228+
tearing_result.var_eq_matching[v]=SelectedState()
227229
end
228230

229-
returnvar_eq_matching, full_var_eq_matching, var_sccs,can_eliminate
231+
returntearing_result, (;can_eliminate)
230232
end

‎src/structural_transformation/symbolics_tearing.jl‎

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,9 +1047,15 @@ differential variables.
10471047
- `var_sccs`: The topologically sorted strongly connected components of the system
10481048
according to `full_var_eq_matching`.
10491049
"""
1050-
functiontearing_reassemble(state::TearingState, var_eq_matching::Matching,
1051-
full_var_eq_matching::Matching, var_sccs::Vector{Vector{Int}}; simplify=false, mm,
1052-
array_hack=true, fully_determined=true)
1050+
@kwdefstruct DefaultReassembleAlgorithm<:ReassembleAlgorithm
1051+
simplify::Bool=false
1052+
array_hack::Bool=true
1053+
end
1054+
1055+
function (alg::DefaultReassembleAlgorithm)(state::TearingState, tearing_result::TearingResult, mm::Union{SparseMatrixCLIL, Nothing}; fully_determined::Bool=true, kw...)
1056+
@unpack simplify, array_hack= alg
1057+
@unpack var_eq_matching, full_var_eq_matching, var_sccs= tearing_result
1058+
10531059
extra_eqs_vars=get_extra_eqs_vars(
10541060
state, var_eq_matching, full_var_eq_matching, fully_determined)
10551061
neweqs=collect(equations(state))
@@ -1314,25 +1320,25 @@ end
13141320
ndims=ndims(arr)
13151321
end
13161322

1317-
functiontearing(state::TearingState; kwargs...)
1323+
functiontearing(state::TearingState; tearing_alg::TearingAlgorithm=DummyDerivativeTearing(),
1324+
kwargs...)
13181325
state.structure.solvable_graph===nothing&&find_solvables!(state; kwargs...)
13191326
complete!(state.structure)
1320-
tearing_with_dummy_derivatives(state.structure, ())
1327+
tearing_alg(state.structure)
13211328
end
13221329

13231330
"""
1324-
tearing(sys; simplify=false)
1331+
tearing(sys)
13251332
13261333
Tear the nonlinear equations in system. When `simplify=true`, we simplify the
13271334
new residual equations after tearing. End users are encouraged to call [`mtkcompile`](@ref)
13281335
instead, which calls this function internally.
13291336
"""
13301337
functiontearing(sys::AbstractSystem, state=TearingState(sys); mm=nothing,
1331-
simplify=false, array_hack=true, fully_determined=true, kwargs...)
1332-
var_eq_matching, full_var_eq_matching, var_sccs, can_eliminate=tearing(state)
1333-
invalidate_cache!(tearing_reassemble(
1334-
state, var_eq_matching, full_var_eq_matching, var_sccs; mm,
1335-
simplify, array_hack, fully_determined))
1338+
reassemble_alg::ReassembleAlgorithm=DefaultReassembleAlgorithm(),
1339+
fully_determined=true, kwargs...)
1340+
tearing_result, extras=tearing(state; kwargs...)
1341+
invalidate_cache!(reassemble_alg(state, tearing_result, mm; fully_determined))
13361342
end
13371343

13381344
"""
@@ -1341,8 +1347,9 @@ end
13411347
Perform index reduction and use the dummy derivative technique to ensure that
13421348
the system is balanced.
13431349
"""
1344-
functiondummy_derivative(sys, state=TearingState(sys); simplify=false,
1345-
mm=nothing, array_hack=true, fully_determined=true, kwargs...)
1350+
functiondummy_derivative(sys, state=TearingState(sys);
1351+
reassemble_alg::ReassembleAlgorithm=DefaultReassembleAlgorithm(),
1352+
mm=nothing, fully_determined=true, kwargs...)
13461353
jac=let state= state
13471354
(eqs, vars)->begin
13481355
symeqs=EquationsView(state)[eqs]
@@ -1364,10 +1371,7 @@ function dummy_derivative(sys, state = TearingState(sys); simplify = false,
13641371
p
13651372
end
13661373
end
1367-
var_eq_matching, full_var_eq_matching, var_sccs,
1368-
can_eliminate, summary=dummy_derivative_graph!(
1369-
state, jac; state_priority,
1370-
kwargs...)
1371-
tearing_reassemble(state, var_eq_matching, full_var_eq_matching, var_sccs;
1372-
simplify, mm, array_hack, fully_determined)
1374+
tearing_result, extras=dummy_derivative_graph!(
1375+
state, jac; state_priority, kwargs...)
1376+
reassemble_alg(state, tearing_result, mm; fully_determined)
13731377
end

‎src/structural_transformation/tearing.jl‎

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,65 @@ function free_equations(graph, vars_scc, var_eq_matching, varfilter::F) where {F
8383
end
8484
findall(!, seen_eqs)
8585
end
86+
87+
struct SelectedStateend
88+
const MatchingT{T}= Matching{T, Vector{Union{T, Int}}}
89+
const MatchedVarT= Union{Unassigned, SelectedState}
90+
const VarEqMatchingT= MatchingT{MatchedVarT}
91+
92+
"""
93+
$TYPEDEF
94+
95+
A struct containing the results of tearing.
96+
97+
# Fields
98+
99+
$TYPEDFIELDS
100+
"""
101+
struct TearingResult
102+
"""
103+
The variable-equation matching. Differential variables are matched to `SelectedState`.
104+
The derivative of a differential variable is matched to the corresponding differential
105+
equation. Solved variables are matched to the equation they are solved from. Algebraic
106+
variables are matched to `unassigned`.
107+
"""
108+
var_eq_matching::VarEqMatchingT
109+
"""
110+
The variable-equation matching prior to tearing. This is the maximal matching used to
111+
compute `var_sccs` (see below). For generating the torn system, `var_eq_matching` is
112+
the source of truth. This should only be used to identify algebraic equations in each
113+
SCC.
114+
"""
115+
full_var_eq_matching::VarEqMatchingT
116+
"""
117+
The partitioning of variables into strongly connected components (SCCs). The SCCs are
118+
sorted in dependency order, so each SCC depends on variables in previous SCCs.
119+
"""
120+
var_sccs::Vector{Vector{Int}}
121+
end
122+
123+
"""
124+
$TYPEDEF
125+
126+
Supertype for all tearing algorithms. A tearing algorithm takes as input the
127+
`SystemStructure` along with any other necessary arguments.
128+
129+
The output of a tearing algorithm must be a `TearingResult` and a `NamedTuple` of
130+
any additional data computed in the process that may be useful for further processing.
131+
"""
132+
abstract type TearingAlgorithmend
133+
134+
"""
135+
$TYPEDEF
136+
137+
Supertype for all reassembling algorithms. A reassembling algorithm takes as input the
138+
`TearingState`, `TearingResult` and integer incidence matrix `mm::SparseMatrixCLIL`. The
139+
matrix `mm` may be `nothing`. The algorithm must also accept arbitrary keyword arguments.
140+
The following keyword arguments will always be provided:
141+
- `fully_determined::Bool`: flag indicating whether the system is fully determined.
142+
143+
The output of a reassembling algorithm must be the torn system.
144+
145+
A reassemble algorithm must also implement `with_fully_determined`
146+
"""
147+
abstract type ReassembleAlgorithmend

‎src/systems/systems.jl‎

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,14 @@ function mtkcompile(
3131
sys::AbstractSystem; additional_passes= [], simplify=false, split=true,
3232
allow_symbolic=false, allow_parameter=true, conservative=false, fully_determined=true,
3333
inputs= Any[], outputs= Any[],
34-
disturbance_inputs= Any[],
34+
disturbance_inputs= Any[], array_hack=true,
3535
kwargs...)
3636
isscheduled(sys)&&throw(RepeatedStructuralSimplificationError())
37-
newsys′=__mtkcompile(sys; simplify,
37+
reassemble_alg=get(kwargs,:reassemble_alg,
38+
StructuralTransformations.DefaultReassembleAlgorithm(; simplify, array_hack))
39+
newsys′=__mtkcompile(sys;
3840
allow_symbolic, allow_parameter, conservative, fully_determined,
39-
inputs, outputs, disturbance_inputs, additional_passes,
41+
inputs, outputs, disturbance_inputs, additional_passes, reassemble_alg,
4042
kwargs...)
4143
if newsys′isa Tuple
4244
@assertlength(newsys′)==2
@@ -59,7 +61,7 @@ function mtkcompile(
5961
end
6062
end
6163

62-
function__mtkcompile(sys::AbstractSystem; simplify=false,
64+
function__mtkcompile(sys::AbstractSystem;
6365
inputs= Any[], outputs= Any[],
6466
disturbance_inputs= Any[],
6567
sort_eqs=true,
@@ -72,7 +74,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
7274
return sys
7375
end
7476
ifisempty(equations(sys))&&!is_time_dependent(sys)&&!_iszero(cost(sys))
75-
returnsimplify_optimization_system(sys; kwargs..., sort_eqs, simplify)
77+
returnsimplify_optimization_system(sys; kwargs..., sort_eqs)
7678
end
7779

7880
sys, statemachines=extract_top_level_statemachines(sys)
@@ -94,7 +96,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
9496
end
9597
ifisempty(brown_vars)
9698
returnmtkcompile!(
97-
state;simplify,inputs, outputs, disturbance_inputs, kwargs...)
99+
state; inputs, outputs, disturbance_inputs, kwargs...)
98100
else
99101
Is= Int[]
100102
Js= Int[]
@@ -129,7 +131,7 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
129131
if!iszero(new_idxs[i])&&
130132
invview(var_to_diff)[i]===nothing]
131133
ode_sys=mtkcompile(
132-
sys;simplify,inputs, outputs, disturbance_inputs, kwargs...)
134+
sys; inputs, outputs, disturbance_inputs, kwargs...)
133135
eqs=equations(ode_sys)
134136
sorted_g_rows=zeros(Num,length(eqs),size(g,2))
135137
for (i, eq)inenumerate(eqs)

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp