Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commitd9ceaf7

Browse files
committed
make the ask functions more similar
1 parent1401a9a commitd9ceaf7

File tree

1 file changed

+9
-15
lines changed

1 file changed

+9
-15
lines changed

‎adaptive/learner/balancing_learner.py‎

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -115,8 +115,8 @@ def strategy(self, strategy):
115115
def_ask_and_tell_based_on_loss_improvements(self,n):
116116
chosen_points= []
117117
chosen_loss_improvements= []
118-
npoints_per_learner=defaultdict(int)
119-
118+
npoints=[l.npoints+len(l.pending_points)
119+
forlinself.learners]
120120
for_inrange(n):
121121
improvements_per_learner= []
122122
points_per_learner= []
@@ -126,18 +126,16 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
126126
self._points[index]=learner.ask(
127127
n=1,tell_pending=False)
128128
points,loss_improvements=self._points[index]
129-
npoints= (npoints_per_learner[index]
130-
+learner.npoints
131-
+len(learner.pending_points))
132-
priority= (loss_improvements[0],-npoints)
129+
130+
priority= (loss_improvements[0],-npoints[index])
133131
improvements_per_learner.append(priority)
134132
points_per_learner.append((index,points[0]))
135133

136134
# Choose the optimal improvement.
137135
(index,point), (loss_improvement,_)=max(
138136
zip(points_per_learner,improvements_per_learner),
139137
key=itemgetter(1))
140-
npoints_per_learner[index]+=1
138+
npoints[index]+=1
141139
chosen_points.append((index,point))
142140
chosen_loss_improvements.append(loss_improvement)
143141
self.tell_pending((index,point))
@@ -147,17 +145,13 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
147145
def_ask_and_tell_based_on_loss(self,n):
148146
chosen_points= []
149147
chosen_loss_improvements= []
150-
npoints_per_learner=defaultdict(int)
151-
148+
npoints=[l.npoints+len(l.pending_points)
149+
forlinself.learners]
152150
for_inrange(n):
153151
losses=self._losses(real=False)
154-
npoints= [-(l.npoints
155-
+npoints_per_learner[i]
156-
+len(l.pending_points))
157-
fori,linenumerate(self.learners)]
158-
priority=zip(losses,npoints)
152+
priority=zip(losses, (-nforninnpoints))
159153
index=max(enumerate(priority),key=itemgetter(1))[0]
160-
npoints_per_learner[index]+=1
154+
npoints[index]+=1
161155

162156
# Take the points from the cache
163157
ifindexnotinself._points:

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp