@@ -111,6 +111,7 @@ def strategy(self, strategy):
111111raise ValueError (
112112'Only strategy="loss_improvements", strategy="loss", or'
113113' strategy="npoints" is implemented.' )
114+ self ._points = {}# reset the cache
114115
115116def _ask_and_tell_based_on_loss_improvements (self ,n ):
116117chosen_points = []
@@ -125,12 +126,14 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
125126self ._points [index ]= learner .ask (
126127n = 1 ,tell_pending = False )
127128points ,loss_improvements = self ._points [index ]
128- npoints = npoints_per_learner [index ]+ learner .npoints
129+ npoints = (npoints_per_learner [index ]
130+ + learner .npoints
131+ + len (learner .pending_points ))
129132priority = (loss_improvements [0 ],- npoints )
130133improvements_per_learner .append (priority )
131134points_per_learner .append ((index ,points [0 ]))
132135
133- #Chose the optimal improvement.
136+ #Choose the optimal improvement.
134137 (index ,point ), (loss_improvement ,_ )= max (
135138zip (points_per_learner ,improvements_per_learner ),
136139key = itemgetter (1 ))
@@ -142,15 +145,23 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
142145return chosen_points ,chosen_loss_improvements
143146
144147def _ask_and_tell_based_on_loss (self ,n ):
145- points = []
146- loss_improvements = []
148+ chosen_points = []
149+ chosen_loss_improvements = []
150+ npoints_per_learner = defaultdict (int )
151+
147152for _ in range (n ):
148153losses = self ._losses (real = False )
149- max_ind = np .argmax (losses )
150- xs ,ls = self .learners [max_ind ].ask (1 )
151- points .append ((max_ind ,xs [0 ]))
152- loss_improvements .append (ls [0 ])
153- return points ,loss_improvements
154+ npoints = [- (l .npoints
155+ + npoints_per_learner [i ]
156+ + len (l .pending_points ))
157+ for i ,l in enumerate (self .learners )]
158+ priority = zip (losses ,npoints )
159+ index , (_ ,_ )= max (enumerate (priority ),key = itemgetter (1 ))
160+ npoints_per_learner [index ]+= 1
161+ points ,loss_improvements = self .learners [index ].ask (1 )
162+ chosen_points .append ((index ,points [0 ]))
163+ chosen_loss_improvements .append (loss_improvements [0 ])
164+ return chosen_points ,chosen_loss_improvements
154165
155166def _ask_and_tell_based_on_npoints (self ,n ):
156167points = []