1
1
package trafficgen
2
2
3
3
import (
4
+ "bytes"
4
5
"context"
5
6
"encoding/json"
6
7
"io"
@@ -12,6 +13,7 @@ import (
12
13
13
14
"cdr.dev/slog"
14
15
"cdr.dev/slog/sloggers/sloghuman"
16
+
15
17
"github.com/coder/coder/coderd/tracing"
16
18
"github.com/coder/coder/codersdk"
17
19
"github.com/coder/coder/cryptorand"
@@ -72,14 +74,14 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
72
74
_ = conn .Close ()
73
75
}()
74
76
75
- // Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
76
- crw := countReadWriter {ReadWriter :conn }
77
-
78
77
// Set a deadline for stopping the text.
79
78
start := time .Now ()
80
79
deadlineCtx ,cancel := context .WithDeadline (ctx ,start .Add (r .cfg .Duration ))
81
80
defer cancel ()
82
81
82
+ // Wrap the conn in a countReadWriter so we can monitor bytes sent/rcvd.
83
+ crw := countReadWriter {ReadWriter :conn ,ctx :deadlineCtx }
84
+
83
85
// Create a ticker for sending data to the PTY.
84
86
tick := time .NewTicker (time .Duration (tickInterval ))
85
87
defer tick .Stop ()
@@ -88,10 +90,15 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
88
90
rch := make (chan error )
89
91
wch := make (chan error )
90
92
93
+ go func () {
94
+ <- deadlineCtx .Done ()
95
+ logger .Debug (ctx ,"context deadline reached" ,slog .F ("duration" ,time .Since (start )))
96
+ }()
97
+
91
98
// Read forever in the background.
92
99
go func () {
93
100
logger .Debug (ctx ,"reading from agent" ,slog .F ("agent_id" ,agentID ))
94
- rch <- readContext (deadlineCtx ,& crw ,bytesPerTick * 2 )
101
+ rch <- drainContext (deadlineCtx ,& crw ,bytesPerTick * 2 )
95
102
logger .Debug (ctx ,"done reading from agent" ,slog .F ("agent_id" ,agentID ))
96
103
conn .Close ()
97
104
close (rch )
@@ -109,14 +116,17 @@ func (r *Runner) Run(ctx context.Context, _ string, logs io.Writer) error {
109
116
if wErr := <- wch ;wErr != nil {
110
117
return xerrors .Errorf ("write to pty: %w" ,wErr )
111
118
}
119
+ drainStart := time .Now ()
112
120
if rErr := <- rch ;rErr != nil {
113
121
return xerrors .Errorf ("read from pty: %w" ,rErr )
114
122
}
115
123
116
124
duration := time .Since (start )
125
+ drainDuration := time .Since (drainStart )
117
126
118
- logger .Info (ctx ,"trafficgen result " ,
127
+ logger .Info (ctx ,"results " ,
119
128
slog .F ("duration" ,duration ),
129
+ slog .F ("drain" ,drainDuration ),
120
130
slog .F ("sent" ,crw .BytesWritten ()),
121
131
slog .F ("rcvd" ,crw .BytesRead ()),
122
132
)
@@ -129,14 +139,34 @@ func (*Runner) Cleanup(context.Context, string) error {
129
139
return nil
130
140
}
131
141
132
- func readContext (ctx context.Context ,src io.Reader ,bufSize int64 )error {
133
- buf := make ([]byte ,bufSize )
142
+ // drainContext drains from src until it returns io.EOF or ctx times out.
143
+ func drainContext (ctx context.Context ,src io.Reader ,bufSize int64 )error {
144
+ errCh := make (chan error )
145
+ done := make (chan struct {})
146
+ go func () {
147
+ tmp := make ([]byte ,bufSize )
148
+ buf := bytes .NewBuffer (tmp )
149
+ for {
150
+ select {
151
+ case <- done :
152
+ return
153
+ default :
154
+ _ ,err := io .CopyN (buf ,src ,1 )
155
+ // _, err := src.Read(tmp)
156
+ if err != nil {
157
+ errCh <- err
158
+ close (errCh )
159
+ return
160
+ }
161
+ }
162
+ }
163
+ }()
134
164
for {
135
165
select {
136
166
case <- ctx .Done ():
167
+ close (done )
137
168
return nil
138
- default :
139
- _ ,err := src .Read (buf )
169
+ case err := <- errCh :
140
170
if err != nil {
141
171
if xerrors .Is (err ,io .EOF ) {
142
172
return nil
@@ -175,31 +205,37 @@ func copyContext(ctx context.Context, dst io.Writer, src []byte) (int, error) {
175
205
case <- ctx .Done ():
176
206
return count ,nil
177
207
default :
178
- n ,err := dst .Write (src )
179
- if err != nil {
180
- if xerrors .Is (err ,io .EOF ) {
181
- // On an EOF, assume that all of src was consumed.
182
- return len (src ),nil
208
+ for idx := range src {
209
+ n ,err := dst .Write (src [idx :idx + 1 ])
210
+ if err != nil {
211
+ if xerrors .Is (err ,io .EOF ) {
212
+ return count ,nil
213
+ }
214
+ if xerrors .Is (err ,context .DeadlineExceeded ) {
215
+ // It's OK if we reach the deadline before writing the full payload.
216
+ return count ,nil
217
+ }
218
+ return count ,err
183
219
}
184
- return count ,err
185
- }
186
- count += n
187
- if n == len (src ) {
188
- return count ,nil
220
+ count += n
189
221
}
190
- // Not all of src was consumed. Update src and retry.
191
- src = src [n :]
222
+ return count ,nil
192
223
}
193
224
}
194
225
}
195
226
227
+ // countReadWriter wraps an io.ReadWriter and counts the number of bytes read and written.
196
228
type countReadWriter struct {
229
+ ctx context.Context
197
230
io.ReadWriter
198
231
bytesRead atomic.Int64
199
232
bytesWritten atomic.Int64
200
233
}
201
234
202
235
func (w * countReadWriter )Read (p []byte ) (int ,error ) {
236
+ if err := w .ctx .Err ();err != nil {
237
+ return 0 ,err
238
+ }
203
239
n ,err := w .ReadWriter .Read (p )
204
240
if err == nil {
205
241
w .bytesRead .Add (int64 (n ))
@@ -208,6 +244,9 @@ func (w *countReadWriter) Read(p []byte) (int, error) {
208
244
}
209
245
210
246
func (w * countReadWriter )Write (p []byte ) (int ,error ) {
247
+ if err := w .ctx .Err ();err != nil {
248
+ return 0 ,err
249
+ }
211
250
n ,err := w .ReadWriter .Write (p )
212
251
if err == nil {
213
252
w .bytesWritten .Add (int64 (n ))