@@ -66,31 +66,44 @@ func TestBlockedRequestPhase1_GeoIPBlocking(t *testing.T) {
6666geoIPBlock ,err := geoIPHandler .LoadGeoIPDatabase (geoIPdata )
6767assert .NoError (t ,err )
6868
69+ state := & WAFState {}
70+
6971middleware := & Middleware {
7072logger :logger ,
73+ ipBlacklist :iptrie .NewTrie (),
7174geoIPHandler :geoIPHandler ,
7275CountryBlock :CountryAccessFilter {
7376Enabled :true ,
74- CountryList : []string {"US" },
77+ CountryList : []string {"US" , "RU" },
7578GeoIPDBPath :geoIPdata ,// Path to a test GeoIP database
7679geoIP :geoIPBlock ,
7780},
7881CustomResponses :customResponse ,
7982}
8083
81- // Simulate a request from a blocked country (US)
82- req := httptest .NewRequest ("GET" ,testURL ,nil )
83- req .RemoteAddr = googleUSIP
8484w := httptest .NewRecorder ()
85- state := & WAFState {}
8685
87- // Process the request in Phase 1
88- middleware .handlePhase (w ,req ,1 ,state )
86+ t .Run ("Allow unblocked CN by GeoIP" ,func (t * testing.T ) {
87+ req := httptest .NewRequest ("GET" ,testURL ,nil )
88+ req .RemoteAddr = aliCNIP
8989
90- // Verify that the request was blocked
91- assert .True (t ,state .Blocked ,"Request should be blocked" )
92- assert .Equal (t ,http .StatusForbidden ,w .Code ,"Expected status code 403" )
93- assert .Contains (t ,w .Body .String (),"Access Denied" ,"Response body should contain 'Access Denied'" )
90+ // Process the request in Phase 1
91+ middleware .handlePhase (w ,req ,1 ,state )
92+ assert .False (t ,state .Blocked ,"Request should be allowed" )
93+ })
94+
95+ t .Run ("Block US IP by GeoIP" ,func (t * testing.T ) {
96+ req := httptest .NewRequest ("GET" ,testURL ,nil )
97+ req .RemoteAddr = googleUSIP
98+
99+ // Process the request in Phase 1
100+ middleware .handlePhase (w ,req ,1 ,state )
101+
102+ // Verify that the request was blocked
103+ assert .True (t ,state .Blocked ,"Request should be blocked" )
104+ assert .Equal (t ,http .StatusForbidden ,w .Code ,"Expected status code 403" )
105+ assert .Contains (t ,w .Body .String (),"Access Denied" ,"Response body should contain 'Access Denied'" )
106+ })
94107}
95108
96109func TestBlockedRequestPhase1_IPBlocking (t * testing.T ) {
@@ -133,11 +146,11 @@ func TestBlockedRequestPhase1_IPBlocking(t *testing.T) {
133146CustomResponses :customResponse ,
134147}
135148
136- req0 := httptest .NewRequest ("GET" ,testURL ,nil )
137- req0 .RemoteAddr = "192.168.1.1"
149+ req := httptest .NewRequest ("GET" ,testURL ,nil )
150+ req .RemoteAddr = "192.168.1.1"
138151
139152// Process the request in Phase 1
140- middleware .handlePhase (w ,req0 ,1 ,state )
153+ middleware .handlePhase (w ,req ,1 ,state )
141154
142155// Verify that the request was blocked
143156assert .True (t ,state .Blocked ,"Request should be blocked" )