Commit 4b4546c

type hint fixes for adaptive/tests/test_learners.py

1 parent cc296f4 commit 4b4546c

File tree: 1 file changed, +21 −110 lines changed

adaptive/tests/test_learners.py

Lines changed: 21 additions & 110 deletions
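
The diff below collapses long `Union[Type[...], ...]` annotations into the shared base class. A minimal sketch of the idea, using hypothetical stand-in classes rather than the real adaptive learners: under PEP 484, `Type[BaseLearner]` is satisfied by the base class and every subclass, so a single annotation replaces the enumerated union.

# Minimal sketch (not part of the commit) of the annotation pattern applied below.
# BaseLearner, Learner1D and Learner2D here are hypothetical stand-ins for the
# classes imported from adaptive.learner.
from typing import Type, Union


class BaseLearner:
    pass


class Learner1D(BaseLearner):
    pass


class Learner2D(BaseLearner):
    pass


# Before: every accepted class is spelled out in a Union of Type[...] entries.
def learn_with_verbose(learner_type: Union[Type[Learner1D], Type[Learner2D]]) -> str:
    return learner_type.__name__


# After: Type[BaseLearner] covers all current and future subclasses.
def learn_with_concise(learner_type: Type[BaseLearner]) -> str:
    return learner_type.__name__


print(learn_with_concise(Learner2D))  # prints "Learner2D"; any subclass type-checks
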
@@ -20,6 +20,7 @@
 from adaptive.learner import (
     AverageLearner,
     BalancingLearner,
+    BaseLearner,
     DataSaver,
     IntegratorLearner,
     Learner1D,
@@ -92,28 +93,15 @@ def uniform(a: Union[int, float], b: int) -> Callable:
 learner_function_combos = collections.defaultdict(list)


-def learn_with(
-    learner_type: Union[
-        Type[Learner2D],
-        Type[SequenceLearner],
-        Type[AverageLearner],
-        Type[Learner1D],
-        Type[LearnerND],
-    ],
-    **init_kwargs,
-) -> Callable:
+def learn_with(learner_type: Type[BaseLearner], **init_kwargs,) -> Callable:
     def _(f):
         learner_function_combos[learner_type].append((f, init_kwargs))
         return f

     return _


-def xfail(
-    learner: Union[Type[Learner2D], Type[LearnerND]]
-) -> Union[
-    Tuple[MarkDecorator, Type[Learner2D]], Tuple[MarkDecorator, Type[LearnerND]]
-]:
+def xfail(learner: Type[BaseLearner]) -> Tuple[MarkDecorator, Type[BaseLearner]]:
     return pytest.mark.xfail, learner


@@ -141,14 +129,7 @@ def linear_with_peak(x: Union[int, float], d: uniform(-1, 1)) -> float:
 @learn_with(Learner2D, bounds=((-1, 1), (-1, 1)))
 @learn_with(SequenceLearner, sequence=np.random.rand(1000, 2))
 def ring_of_fire(
-    xy: Union[
-        Tuple[float, float],
-        np.ndarray,
-        Tuple[int, int],
-        Tuple[float, float],
-        Tuple[float, float],
-    ],
-    d: uniform(0.2, 1),
+    xy: Union[np.ndarray, Tuple[float, float]], d: uniform(0.2, 1),
 ) -> float:
     a = 0.2
     x, y = xy
@@ -158,8 +139,7 @@ def ring_of_fire(
 @learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)))
 @learn_with(SequenceLearner, sequence=np.random.rand(1000, 3))
 def sphere_of_fire(
-    xyz: Union[Tuple[float, float, float], Tuple[int, int, int], np.ndarray],
-    d: uniform(0.2, 1),
+    xyz: Union[Tuple[float, float, float], np.ndarray], d: uniform(0.2, 1),
 ) -> float:
     a = 0.2
     x, y, z = xyz
@@ -177,16 +157,7 @@ def gaussian(n: int) -> float:

 # Create a sequence of learner parameters by adding all
 # possible loss functions to an existing parameter set.
-def add_loss_to_params(
-    learner_type: Union[
-        Type[Learner2D],
-        Type[SequenceLearner],
-        Type[AverageLearner],
-        Type[Learner1D],
-        Type[LearnerND],
-    ],
-    existing_params: Dict[str, Any],
-) -> Any:
+def add_loss_to_params(learner_type, existing_params: Dict[str, Any],) -> Any:
     if learner_type not in LOSS_FUNCTIONS:
         return [existing_params]
     loss_param, loss_functions = LOSS_FUNCTIONS[learner_type]
@@ -216,12 +187,7 @@ def ask_randomly(
     learner: Union[Learner1D, LearnerND, Learner2D],
     rounds: Tuple[int, int],
     points: Tuple[int, int],
-) -> Union[
-    Tuple[List[Union[Tuple[float, float, float], Tuple[int, int, int]]], List[float]],
-    Tuple[List[Union[Tuple[float, float], Tuple[int, int]]], List[float]],
-    Tuple[List[float], List[float]],
-    Tuple[List[Union[Tuple[int, int], Tuple[float, float]]], List[float]],
-]:
+):
     n_rounds = random.randrange(*rounds)
     n_points = [random.randrange(*points) for _ in range(n_rounds)]

@@ -240,7 +206,7 @@ def ask_randomly(

 @run_with(Learner1D)
 def test_uniform_sampling1D(
-    learner_type: Type[Learner1D],
+    learner_type,
     f: Callable,
     learner_kwargs: Dict[str, Union[Tuple[int, int], Callable]],
 ) -> None:
@@ -262,7 +228,7 @@ def test_uniform_sampling1D(
 @pytest.mark.xfail
 @run_with(Learner2D, LearnerND)
 def test_uniform_sampling2D(
-    learner_type: Union[Type[Learner2D], Type[LearnerND]],
+    learner_type,
     f: Callable,
     learner_kwargs: Dict[
         str,
@@ -304,8 +270,7 @@ def test_uniform_sampling2D(
     ],
 )
 def test_learner_accepts_lists(
-    learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
-    bounds: Union[Tuple[int, int], List[Tuple[int, int]]],
+    learner_type, bounds: Union[Tuple[int, int], List[Tuple[int, int]]],
 ) -> None:
     def f(x):
         return [0, 1]
@@ -316,11 +281,7 @@ def f(x):

 @run_with(Learner1D, Learner2D, LearnerND, SequenceLearner)
 def test_adding_existing_data_is_idempotent(
-    learner_type: Union[
-        Type[SequenceLearner], Type[LearnerND], Type[Learner1D], Type[Learner2D]
-    ],
-    f: Callable,
-    learner_kwargs: Dict[str, Any],
+    learner_type, f: Callable, learner_kwargs: Dict[str, Any],
 ) -> None:
     """Adding already existing data is an idempotent operation.

@@ -369,15 +330,7 @@ def test_adding_existing_data_is_idempotent(
 # but we xfail it now, as Learner2D will be deprecated anyway
 @run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, SequenceLearner)
 def test_adding_non_chosen_data(
-    learner_type: Union[
-        Type[Learner2D],
-        Type[SequenceLearner],
-        Type[AverageLearner],
-        Type[Learner1D],
-        Type[LearnerND],
-    ],
-    f: Callable,
-    learner_kwargs: Dict[str, Any],
+    learner_type, f: Callable, learner_kwargs: Dict[str, Any],
 ) -> None:
     """Adding data for a point that was not returned by 'ask'."""
     # XXX: learner, control and bounds are not defined
@@ -421,9 +374,7 @@ def test_adding_non_chosen_data(

 @run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner)
 def test_point_adding_order_is_irrelevant(
-    learner_type: Union[
-        Type[AverageLearner], Type[LearnerND], Type[Learner1D], Type[Learner2D]
-    ],
+    learner_type,
     f: Callable,
     learner_kwargs: Dict[
         str,
@@ -478,9 +429,7 @@ def test_point_adding_order_is_irrelevant(
 # see https://github.com/python-adaptive/adaptive/issues/55
 @run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner)
 def test_expected_loss_improvement_is_less_than_total_loss(
-    learner_type: Union[
-        Type[AverageLearner], Type[LearnerND], Type[Learner1D], Type[Learner2D]
-    ],
+    learner_type,
     f: Callable,
     learner_kwargs: Dict[
         str,
@@ -519,7 +468,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
 # but we xfail it now, as Learner2D will be deprecated anyway
 @run_with(Learner1D, xfail(Learner2D), LearnerND)
 def test_learner_performance_is_invariant_under_scaling(
-    learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
+    learner_type,
     f: Callable,
     learner_kwargs: Dict[
         str,
@@ -583,15 +532,7 @@ def test_learner_performance_is_invariant_under_scaling(
     with_all_loss_functions=False,
 )
 def test_balancing_learner(
-    learner_type: Union[
-        Type[Learner2D],
-        Type[SequenceLearner],
-        Type[AverageLearner],
-        Type[Learner1D],
-        Type[LearnerND],
-    ],
-    f: Callable,
-    learner_kwargs: Dict[str, Any],
+    learner_type, f: Callable, learner_kwargs: Dict[str, Any],
 ) -> None:
     """Test if the BalancingLearner works with the different types of learners."""
     learners = [
@@ -638,17 +579,7 @@ def test_balancing_learner(
     SequenceLearner,
     with_all_loss_functions=False,
 )
-def test_saving(
-    learner_type: Union[
-        Type[Learner2D],
-        Type[SequenceLearner],
-        Type[AverageLearner],
-        Type[Learner1D],
-        Type[LearnerND],
-    ],
-    f: Callable,
-    learner_kwargs: Dict[str, Any],
-) -> None:
+def test_saving(learner_type, f: Callable, learner_kwargs: Dict[str, Any],) -> None:
     f = generate_random_parametrization(f)
     learner = learner_type(f, **learner_kwargs)
     control = learner_type(f, **learner_kwargs)
@@ -680,15 +611,7 @@ def test_saving(
     with_all_loss_functions=False,
 )
 def test_saving_of_balancing_learner(
-    learner_type: Union[
-        Type[Learner2D],
-        Type[SequenceLearner],
-        Type[AverageLearner],
-        Type[Learner1D],
-        Type[LearnerND],
-    ],
-    f: Callable,
-    learner_kwargs: Dict[str, Any],
+    learner_type, f: Callable, learner_kwargs: Dict[str, Any],
 ) -> None:
     f = generate_random_parametrization(f)
     learner = BalancingLearner([learner_type(f, **learner_kwargs)])
@@ -727,9 +650,7 @@ def fname(learner):
     with_all_loss_functions=False,
 )
 def test_saving_with_datasaver(
-    learner_type: Union[
-        Type[Learner2D], Type[AverageLearner], Type[LearnerND], Type[Learner1D]
-    ],
+    learner_type,
     f: Callable,
     learner_kwargs: Dict[
         str,
@@ -770,7 +691,7 @@ def test_saving_with_datasaver(
 @pytest.mark.xfail
 @run_with(Learner1D, Learner2D, LearnerND)
 def test_convergence_for_arbitrary_ordering(
-    learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
+    learner_type,
     f: Callable,
     learner_kwargs: Dict[
         str,
@@ -794,17 +715,7 @@ def test_convergence_for_arbitrary_ordering(
 @pytest.mark.xfail
 @run_with(Learner1D, Learner2D, LearnerND)
 def test_learner_subdomain(
-    learner_type: Union[Type[Learner2D], Type[LearnerND], Type[Learner1D]],
-    f: Callable,
-    learner_kwargs: Dict[
-        str,
-        Union[
-            Tuple[Tuple[int, int], Tuple[int, int]],
-            Callable,
-            Tuple[int, int],
-            Tuple[Tuple[int, int], Tuple[int, int], Tuple[int, int]],
-        ],
-    ],
+    learner_type, f: Callable, learner_kwargs,
 ):
     """Learners that never receive data outside of a subdomain should
     perform 'similarly' to learners defined on that subdomain only."""
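
Several test signatures above drop the `learner_type` annotation entirely rather than switching to `Type[BaseLearner]`; one plausible reading is that the classes are injected by the `run_with` parametrization at runtime, so the annotation adds little checking value there. A hedged sketch of that injection pattern, using a hypothetical FakeLearner and plain pytest.mark.parametrize instead of the repository's run_with helper:

# Hypothetical sketch: pytest passes the class object in at call time, so the
# test parameter behaves the same with or without a type annotation.
import pytest


class FakeLearner:
    def __init__(self, function, bounds):
        self.function = function
        self.bounds = bounds


@pytest.mark.parametrize("learner_type", [FakeLearner])
def test_learner_construction(learner_type):
    learner = learner_type(lambda x: x, bounds=(-1, 1))
    assert isinstance(learner, FakeLearner)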
