@@ -8,14 +8,19 @@ import (
8
8
9
9
"github.com/go-chi/chi/v5"
10
10
"github.com/prometheus/client_golang/prometheus"
11
+ cm"github.com/prometheus/client_model/go"
12
+ "github.com/stretchr/testify/assert"
11
13
"github.com/stretchr/testify/require"
12
14
13
15
"github.com/coder/coder/v2/coderd/httpmw"
14
16
"github.com/coder/coder/v2/coderd/tracing"
17
+ "github.com/coder/coder/v2/testutil"
18
+ "github.com/coder/websocket"
15
19
)
16
20
17
21
func TestPrometheus (t * testing.T ) {
18
22
t .Parallel ()
23
+
19
24
t .Run ("All" ,func (t * testing.T ) {
20
25
t .Parallel ()
21
26
req := httptest .NewRequest ("GET" ,"/" ,nil )
@@ -29,4 +34,90 @@ func TestPrometheus(t *testing.T) {
29
34
require .NoError (t ,err )
30
35
require .Greater (t ,len (metrics ),0 )
31
36
})
37
+
38
+ t .Run ("Concurrent" ,func (t * testing.T ) {
39
+ t .Parallel ()
40
+ ctx ,cancel := context .WithTimeout (context .Background (),testutil .WaitShort )
41
+ defer cancel ()
42
+
43
+ reg := prometheus .NewRegistry ()
44
+ promMW := httpmw .Prometheus (reg )
45
+
46
+ // Create a test handler to simulate a WebSocket connection
47
+ testHandler := http .HandlerFunc (func (rw http.ResponseWriter ,r * http.Request ) {
48
+ conn ,err := websocket .Accept (rw ,r ,nil )
49
+ if ! assert .NoError (t ,err ,"failed to accept websocket" ) {
50
+ return
51
+ }
52
+ defer conn .Close (websocket .StatusGoingAway ,"" )
53
+ })
54
+
55
+ wrappedHandler := promMW (testHandler )
56
+
57
+ r := chi .NewRouter ()
58
+ r .Use (tracing .StatusWriterMiddleware ,promMW )
59
+ r .Get ("/api/v2/build/{build}/logs" ,func (rw http.ResponseWriter ,r * http.Request ) {
60
+ wrappedHandler .ServeHTTP (rw ,r )
61
+ })
62
+
63
+ srv := httptest .NewServer (r )
64
+ defer srv .Close ()
65
+ // nolint: bodyclose
66
+ conn ,_ ,err := websocket .Dial (ctx ,srv .URL + "/api/v2/build/1/logs" ,nil )
67
+ require .NoError (t ,err ,"failed to dial WebSocket" )
68
+ defer conn .Close (websocket .StatusNormalClosure ,"" )
69
+
70
+ metrics ,err := reg .Gather ()
71
+ require .NoError (t ,err )
72
+ require .Greater (t ,len (metrics ),0 )
73
+ metricLabels := getMetricLabels (metrics )
74
+
75
+ concurrentWebsockets ,ok := metricLabels ["coderd_api_concurrent_websockets" ]
76
+ require .True (t ,ok ,"coderd_api_concurrent_websockets metric not found" )
77
+ require .Equal (t ,"/api/v2/build/{build}/logs" ,concurrentWebsockets ["path" ])
78
+ })
79
+
80
+ t .Run ("UserRoute" ,func (t * testing.T ) {
81
+ t .Parallel ()
82
+ reg := prometheus .NewRegistry ()
83
+ promMW := httpmw .Prometheus (reg )
84
+
85
+ r := chi .NewRouter ()
86
+ r .With (promMW ).Get ("/api/v2/users/{user}" ,func (w http.ResponseWriter ,r * http.Request ) {})
87
+
88
+ req := httptest .NewRequest ("GET" ,"/api/v2/users/john" ,nil )
89
+
90
+ sw := & tracing.StatusWriter {ResponseWriter :httptest .NewRecorder ()}
91
+
92
+ r .ServeHTTP (sw ,req )
93
+
94
+ metrics ,err := reg .Gather ()
95
+ require .NoError (t ,err )
96
+ require .Greater (t ,len (metrics ),0 )
97
+ metricLabels := getMetricLabels (metrics )
98
+
99
+ reqProcessed ,ok := metricLabels ["coderd_api_requests_processed_total" ]
100
+ require .True (t ,ok ,"coderd_api_requests_processed_total metric not found" )
101
+ require .Equal (t ,"/api/v2/users/{user}" ,reqProcessed ["path" ])
102
+ require .Equal (t ,"GET" ,reqProcessed ["method" ])
103
+
104
+ concurrentRequests ,ok := metricLabels ["coderd_api_concurrent_requests" ]
105
+ require .True (t ,ok ,"coderd_api_concurrent_requests metric not found" )
106
+ require .Equal (t ,"/api/v2/users/{user}" ,concurrentRequests ["path" ])
107
+ require .Equal (t ,"GET" ,concurrentRequests ["method" ])
108
+ })
109
+ }
110
+
111
+ func getMetricLabels (metrics []* cm.MetricFamily )map [string ]map [string ]string {
112
+ metricLabels := map [string ]map [string ]string {}
113
+ for _ ,metricFamily := range metrics {
114
+ metricName := metricFamily .GetName ()
115
+ metricLabels [metricName ]= map [string ]string {}
116
+ for _ ,metric := range metricFamily .GetMetric () {
117
+ for _ ,labelPair := range metric .GetLabel () {
118
+ metricLabels [metricName ][labelPair .GetName ()]= labelPair .GetValue ()
119
+ }
120
+ }
121
+ }
122
+ return metricLabels
32
123
}