2020from adaptive .learner import (
2121AverageLearner ,
2222BalancingLearner ,
23+ BaseLearner ,
2324DataSaver ,
2425IntegratorLearner ,
2526Learner1D ,
@@ -92,28 +93,15 @@ def uniform(a: Union[int, float], b: int) -> Callable:
# Registry mapping each learner class to the (function, init_kwargs) pairs
# it should be exercised with.
learner_function_combos = collections.defaultdict(list)


def learn_with(learner_type: Type[BaseLearner], **init_kwargs) -> Callable:
    """Return a decorator that registers a function/learner combination.

    The decorated function is appended, together with ``init_kwargs``, to the
    module-level ``learner_function_combos`` registry under ``learner_type``,
    and is returned unmodified so it can still be called directly.
    """

    def _register(func):
        learner_function_combos[learner_type].append((func, init_kwargs))
        return func

    return _register
110102
111103
def xfail(learner: Type[BaseLearner]) -> Tuple[MarkDecorator, Type[BaseLearner]]:
    """Pair ``learner`` with an ``xfail`` mark for use inside ``run_with``."""
    mark = pytest.mark.xfail
    return mark, learner
118106
119107
@@ -141,14 +129,7 @@ def linear_with_peak(x: Union[int, float], d: uniform(-1, 1)) -> float:
141129@learn_with (Learner2D ,bounds = ((- 1 ,1 ), (- 1 ,1 )))
142130@learn_with (SequenceLearner ,sequence = np .random .rand (1000 ,2 ))
143131def ring_of_fire (
144- xy :Union [
145- Tuple [float ,float ],
146- np .ndarray ,
147- Tuple [int ,int ],
148- Tuple [float ,float ],
149- Tuple [float ,float ],
150- ],
151- d :uniform (0.2 ,1 ),
132+ xy :Union [np .ndarray ,Tuple [float ,float ]],d :uniform (0.2 ,1 ),
152133)-> float :
153134a = 0.2
154135x ,y = xy
@@ -158,8 +139,7 @@ def ring_of_fire(
158139@learn_with (LearnerND ,bounds = ((- 1 ,1 ), (- 1 ,1 ), (- 1 ,1 )))
159140@learn_with (SequenceLearner ,sequence = np .random .rand (1000 ,3 ))
160141def sphere_of_fire (
161- xyz :Union [Tuple [float ,float ,float ],Tuple [int ,int ,int ],np .ndarray ],
162- d :uniform (0.2 ,1 ),
142+ xyz :Union [Tuple [float ,float ,float ],np .ndarray ],d :uniform (0.2 ,1 ),
163143)-> float :
164144a = 0.2
165145x ,y ,z = xyz
@@ -177,16 +157,7 @@ def gaussian(n: int) -> float:
177157
178158# Create a sequence of learner parameters by adding all
179159# possible loss functions to an existing parameter set.
180- def add_loss_to_params (
181- learner_type :Union [
182- Type [Learner2D ],
183- Type [SequenceLearner ],
184- Type [AverageLearner ],
185- Type [Learner1D ],
186- Type [LearnerND ],
187- ],
188- existing_params :Dict [str ,Any ],
189- )-> Any :
160+ def add_loss_to_params (learner_type ,existing_params :Dict [str ,Any ],)-> Any :
190161if learner_type not in LOSS_FUNCTIONS :
191162return [existing_params ]
192163loss_param ,loss_functions = LOSS_FUNCTIONS [learner_type ]
@@ -216,12 +187,7 @@ def ask_randomly(
216187learner :Union [Learner1D ,LearnerND ,Learner2D ],
217188rounds :Tuple [int ,int ],
218189points :Tuple [int ,int ],
219- )-> Union [
220- Tuple [List [Union [Tuple [float ,float ,float ],Tuple [int ,int ,int ]]],List [float ]],
221- Tuple [List [Union [Tuple [float ,float ],Tuple [int ,int ]]],List [float ]],
222- Tuple [List [float ],List [float ]],
223- Tuple [List [Union [Tuple [int ,int ],Tuple [float ,float ]]],List [float ]],
224- ]:
190+ ):
225191n_rounds = random .randrange (* rounds )
226192n_points = [random .randrange (* points )for _ in range (n_rounds )]
227193
@@ -240,7 +206,7 @@ def ask_randomly(
240206
241207@run_with (Learner1D )
242208def test_uniform_sampling1D (
243- learner_type : Type [ Learner1D ] ,
209+ learner_type ,
244210f :Callable ,
245211learner_kwargs :Dict [str ,Union [Tuple [int ,int ],Callable ]],
246212)-> None :
@@ -262,7 +228,7 @@ def test_uniform_sampling1D(
262228@pytest .mark .xfail
263229@run_with (Learner2D ,LearnerND )
264230def test_uniform_sampling2D (
265- learner_type : Union [ Type [ Learner2D ], Type [ LearnerND ]] ,
231+ learner_type ,
266232f :Callable ,
267233learner_kwargs :Dict [
268234str ,
@@ -304,8 +270,7 @@ def test_uniform_sampling2D(
304270 ],
305271)
306272def test_learner_accepts_lists (
307- learner_type :Union [Type [Learner2D ],Type [LearnerND ],Type [Learner1D ]],
308- bounds :Union [Tuple [int ,int ],List [Tuple [int ,int ]]],
273+ learner_type ,bounds :Union [Tuple [int ,int ],List [Tuple [int ,int ]]],
309274)-> None :
310275def f (x ):
311276return [0 ,1 ]
@@ -316,11 +281,7 @@ def f(x):
316281
317282@run_with (Learner1D ,Learner2D ,LearnerND ,SequenceLearner )
318283def test_adding_existing_data_is_idempotent (
319- learner_type :Union [
320- Type [SequenceLearner ],Type [LearnerND ],Type [Learner1D ],Type [Learner2D ]
321- ],
322- f :Callable ,
323- learner_kwargs :Dict [str ,Any ],
284+ learner_type ,f :Callable ,learner_kwargs :Dict [str ,Any ],
324285)-> None :
325286"""Adding already existing data is an idempotent operation.
326287
@@ -369,15 +330,7 @@ def test_adding_existing_data_is_idempotent(
369330# but we xfail it now, as Learner2D will be deprecated anyway
370331@run_with (Learner1D ,xfail (Learner2D ),LearnerND ,AverageLearner ,SequenceLearner )
371332def test_adding_non_chosen_data (
372- learner_type :Union [
373- Type [Learner2D ],
374- Type [SequenceLearner ],
375- Type [AverageLearner ],
376- Type [Learner1D ],
377- Type [LearnerND ],
378- ],
379- f :Callable ,
380- learner_kwargs :Dict [str ,Any ],
333+ learner_type ,f :Callable ,learner_kwargs :Dict [str ,Any ],
381334)-> None :
382335"""Adding data for a point that was not returned by 'ask'."""
383336# XXX: learner, control and bounds are not defined
@@ -421,9 +374,7 @@ def test_adding_non_chosen_data(
421374
422375@run_with (Learner1D ,xfail (Learner2D ),xfail (LearnerND ),AverageLearner )
423376def test_point_adding_order_is_irrelevant (
424- learner_type :Union [
425- Type [AverageLearner ],Type [LearnerND ],Type [Learner1D ],Type [Learner2D ]
426- ],
377+ learner_type ,
427378f :Callable ,
428379learner_kwargs :Dict [
429380str ,
@@ -478,9 +429,7 @@ def test_point_adding_order_is_irrelevant(
478429# see https://github.com/python-adaptive/adaptive/issues/55
479430@run_with (Learner1D ,xfail (Learner2D ),LearnerND ,AverageLearner )
480431def test_expected_loss_improvement_is_less_than_total_loss (
481- learner_type :Union [
482- Type [AverageLearner ],Type [LearnerND ],Type [Learner1D ],Type [Learner2D ]
483- ],
432+ learner_type ,
484433f :Callable ,
485434learner_kwargs :Dict [
486435str ,
@@ -519,7 +468,7 @@ def test_expected_loss_improvement_is_less_than_total_loss(
519468# but we xfail it now, as Learner2D will be deprecated anyway
520469@run_with (Learner1D ,xfail (Learner2D ),LearnerND )
521470def test_learner_performance_is_invariant_under_scaling (
522- learner_type : Union [ Type [ Learner2D ], Type [ LearnerND ], Type [ Learner1D ]] ,
471+ learner_type ,
523472f :Callable ,
524473learner_kwargs :Dict [
525474str ,
@@ -583,15 +532,7 @@ def test_learner_performance_is_invariant_under_scaling(
583532with_all_loss_functions = False ,
584533)
585534def test_balancing_learner (
586- learner_type :Union [
587- Type [Learner2D ],
588- Type [SequenceLearner ],
589- Type [AverageLearner ],
590- Type [Learner1D ],
591- Type [LearnerND ],
592- ],
593- f :Callable ,
594- learner_kwargs :Dict [str ,Any ],
535+ learner_type ,f :Callable ,learner_kwargs :Dict [str ,Any ],
595536)-> None :
596537"""Test if the BalancingLearner works with the different types of learners."""
597538learners = [
@@ -638,17 +579,7 @@ def test_balancing_learner(
638579SequenceLearner ,
639580with_all_loss_functions = False ,
640581)
641- def test_saving (
642- learner_type :Union [
643- Type [Learner2D ],
644- Type [SequenceLearner ],
645- Type [AverageLearner ],
646- Type [Learner1D ],
647- Type [LearnerND ],
648- ],
649- f :Callable ,
650- learner_kwargs :Dict [str ,Any ],
651- )-> None :
582+ def test_saving (learner_type ,f :Callable ,learner_kwargs :Dict [str ,Any ],)-> None :
652583f = generate_random_parametrization (f )
653584learner = learner_type (f ,** learner_kwargs )
654585control = learner_type (f ,** learner_kwargs )
@@ -680,15 +611,7 @@ def test_saving(
680611with_all_loss_functions = False ,
681612)
682613def test_saving_of_balancing_learner (
683- learner_type :Union [
684- Type [Learner2D ],
685- Type [SequenceLearner ],
686- Type [AverageLearner ],
687- Type [Learner1D ],
688- Type [LearnerND ],
689- ],
690- f :Callable ,
691- learner_kwargs :Dict [str ,Any ],
614+ learner_type ,f :Callable ,learner_kwargs :Dict [str ,Any ],
692615)-> None :
693616f = generate_random_parametrization (f )
694617learner = BalancingLearner ([learner_type (f ,** learner_kwargs )])
@@ -727,9 +650,7 @@ def fname(learner):
727650with_all_loss_functions = False ,
728651)
729652def test_saving_with_datasaver (
730- learner_type :Union [
731- Type [Learner2D ],Type [AverageLearner ],Type [LearnerND ],Type [Learner1D ]
732- ],
653+ learner_type ,
733654f :Callable ,
734655learner_kwargs :Dict [
735656str ,
@@ -770,7 +691,7 @@ def test_saving_with_datasaver(
770691@pytest .mark .xfail
771692@run_with (Learner1D ,Learner2D ,LearnerND )
772693def test_convergence_for_arbitrary_ordering (
773- learner_type : Union [ Type [ Learner2D ], Type [ LearnerND ], Type [ Learner1D ]] ,
694+ learner_type ,
774695f :Callable ,
775696learner_kwargs :Dict [
776697str ,
@@ -794,17 +715,7 @@ def test_convergence_for_arbitrary_ordering(
794715@pytest .mark .xfail
795716@run_with (Learner1D ,Learner2D ,LearnerND )
796717def test_learner_subdomain (
797- learner_type :Union [Type [Learner2D ],Type [LearnerND ],Type [Learner1D ]],
798- f :Callable ,
799- learner_kwargs :Dict [
800- str ,
801- Union [
802- Tuple [Tuple [int ,int ],Tuple [int ,int ]],
803- Callable ,
804- Tuple [int ,int ],
805- Tuple [Tuple [int ,int ],Tuple [int ,int ],Tuple [int ,int ]],
806- ],
807- ],
718+ learner_type ,f :Callable ,learner_kwargs ,
808719):
809720"""Learners that never receive data outside of a subdomain should
810721 perform 'similarly' to learners defined on that subdomain only."""