Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 49e73f7

Browse files
committed
fix point distribution for loss strategy
1 parent 53edb77 commit 49e73f7

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def strategy(self, strategy):
111111
raise ValueError(
112112
'Only strategy="loss_improvements", strategy="loss", or'
113113
' strategy="npoints" is implemented.')
114+
self._points = {}  # reset the cache
114115

115116
def _ask_and_tell_based_on_loss_improvements(self, n):
116117
chosen_points = []
@@ -125,12 +126,14 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
125126
self._points[index] = learner.ask(
126127
n=1, tell_pending=False)
127128
points, loss_improvements = self._points[index]
128-
npoints = npoints_per_learner[index] + learner.npoints
129+
npoints = (npoints_per_learner[index]
130+
+ learner.npoints
131+
+ len(learner.pending_points))
129132
priority = (loss_improvements[0], -npoints)
130133
improvements_per_learner.append(priority)
131134
points_per_learner.append((index, points[0]))
132135

133-
# Chose the optimal improvement.
136+
# Choose the optimal improvement.
134137
(index, point), (loss_improvement, _) = max(
135138
zip(points_per_learner, improvements_per_learner),
136139
key=itemgetter(1))
@@ -142,15 +145,23 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
142145
return chosen_points, chosen_loss_improvements
143146

144147
def _ask_and_tell_based_on_loss(self, n):
145-
points = []
146-
loss_improvements = []
148+
chosen_points = []
149+
chosen_loss_improvements = []
150+
npoints_per_learner = defaultdict(int)
151+
147152
for _ in range(n):
148153
losses = self._losses(real=False)
149-
max_ind = np.argmax(losses)
150-
xs, ls = self.learners[max_ind].ask(1)
151-
points.append((max_ind, xs[0]))
152-
loss_improvements.append(ls[0])
153-
return points, loss_improvements
154+
npoints = [-(l.npoints
155+
+ npoints_per_learner[i]
156+
+ len(l.pending_points))
157+
for i, l in enumerate(self.learners)]
158+
priority = zip(losses, npoints)
159+
index, (_, _) = max(enumerate(priority), key=itemgetter(1))
160+
npoints_per_learner[index] += 1
161+
points, loss_improvements = self.learners[index].ask(1)
162+
chosen_points.append((index, points[0]))
163+
chosen_loss_improvements.append(loss_improvements[0])
164+
return chosen_points, chosen_loss_improvements
154165

155166
def _ask_and_tell_based_on_npoints(self, n):
156167
points = []

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp