@@ -1187,3 +1187,152 @@ func Test_CreatePullRequestReview(t *testing.T) {
1187
1187
})
1188
1188
}
1189
1189
}
1190
+
1191
+ func Test_CreatePullRequest (t * testing.T ) {
1192
+ // Verify tool definition once
1193
+ mockClient := github .NewClient (nil )
1194
+ tool ,_ := createPullRequest (mockClient ,translations .NullTranslationHelper )
1195
+
1196
+ assert .Equal (t ,"create_pull_request" ,tool .Name )
1197
+ assert .NotEmpty (t ,tool .Description )
1198
+ assert .Contains (t ,tool .InputSchema .Properties ,"owner" )
1199
+ assert .Contains (t ,tool .InputSchema .Properties ,"repo" )
1200
+ assert .Contains (t ,tool .InputSchema .Properties ,"title" )
1201
+ assert .Contains (t ,tool .InputSchema .Properties ,"body" )
1202
+ assert .Contains (t ,tool .InputSchema .Properties ,"head" )
1203
+ assert .Contains (t ,tool .InputSchema .Properties ,"base" )
1204
+ assert .Contains (t ,tool .InputSchema .Properties ,"draft" )
1205
+ assert .Contains (t ,tool .InputSchema .Properties ,"maintainer_can_modify" )
1206
+ assert .ElementsMatch (t ,tool .InputSchema .Required , []string {"owner" ,"repo" ,"title" ,"head" ,"base" })
1207
+
1208
+ // Setup mock PR for success case
1209
+ mockPR := & github.PullRequest {
1210
+ Number :github .Ptr (42 ),
1211
+ Title :github .Ptr ("Test PR" ),
1212
+ State :github .Ptr ("open" ),
1213
+ HTMLURL :github .Ptr ("https://github.com/owner/repo/pull/42" ),
1214
+ Head :& github.PullRequestBranch {
1215
+ SHA :github .Ptr ("abcd1234" ),
1216
+ Ref :github .Ptr ("feature-branch" ),
1217
+ },
1218
+ Base :& github.PullRequestBranch {
1219
+ SHA :github .Ptr ("efgh5678" ),
1220
+ Ref :github .Ptr ("main" ),
1221
+ },
1222
+ Body :github .Ptr ("This is a test PR" ),
1223
+ Draft :github .Ptr (false ),
1224
+ MaintainerCanModify :github .Ptr (true ),
1225
+ User :& github.User {
1226
+ Login :github .Ptr ("testuser" ),
1227
+ },
1228
+ }
1229
+
1230
+ tests := []struct {
1231
+ name string
1232
+ mockedClient * http.Client
1233
+ requestArgs map [string ]interface {}
1234
+ expectError bool
1235
+ expectedPR * github.PullRequest
1236
+ expectedErrMsg string
1237
+ }{
1238
+ {
1239
+ name :"successful PR creation" ,
1240
+ mockedClient :mock .NewMockedHTTPClient (
1241
+ mock .WithRequestMatchHandler (
1242
+ mock .PostReposPullsByOwnerByRepo ,
1243
+ mockResponse (t ,http .StatusCreated ,mockPR ),
1244
+ ),
1245
+ ),
1246
+
1247
+ requestArgs :map [string ]interface {}{
1248
+ "owner" :"owner" ,
1249
+ "repo" :"repo" ,
1250
+ "title" :"Test PR" ,
1251
+ "body" :"This is a test PR" ,
1252
+ "head" :"feature-branch" ,
1253
+ "base" :"main" ,
1254
+ "draft" :false ,
1255
+ "maintainer_can_modify" :true ,
1256
+ },
1257
+ expectError :false ,
1258
+ expectedPR :mockPR ,
1259
+ },
1260
+ {
1261
+ name :"missing required parameter" ,
1262
+ mockedClient :mock .NewMockedHTTPClient (),
1263
+ requestArgs :map [string ]interface {}{
1264
+ "owner" :"owner" ,
1265
+ "repo" :"repo" ,
1266
+ // missing title, head, base
1267
+ },
1268
+ expectError :true ,
1269
+ expectedErrMsg :"missing required parameter: title" ,
1270
+ },
1271
+ {
1272
+ name :"PR creation fails" ,
1273
+ mockedClient :mock .NewMockedHTTPClient (
1274
+ mock .WithRequestMatchHandler (
1275
+ mock .PostReposPullsByOwnerByRepo ,
1276
+ http .HandlerFunc (func (w http.ResponseWriter ,_ * http.Request ) {
1277
+ w .WriteHeader (http .StatusUnprocessableEntity )
1278
+ _ ,_ = w .Write ([]byte (`{"message":"Validation failed","errors":[{"resource":"PullRequest","code":"invalid"}]}` ))
1279
+ }),
1280
+ ),
1281
+ ),
1282
+ requestArgs :map [string ]interface {}{
1283
+ "owner" :"owner" ,
1284
+ "repo" :"repo" ,
1285
+ "title" :"Test PR" ,
1286
+ "head" :"feature-branch" ,
1287
+ "base" :"main" ,
1288
+ },
1289
+ expectError :true ,
1290
+ expectedErrMsg :"failed to create pull request" ,
1291
+ },
1292
+ }
1293
+
1294
+ for _ ,tc := range tests {
1295
+ t .Run (tc .name ,func (t * testing.T ) {
1296
+ // Setup client with mock
1297
+ client := github .NewClient (tc .mockedClient )
1298
+ _ ,handler := createPullRequest (client ,translations .NullTranslationHelper )
1299
+
1300
+ // Create call request
1301
+ request := createMCPRequest (tc .requestArgs )
1302
+
1303
+ // Call handler
1304
+ result ,err := handler (context .Background (),request )
1305
+
1306
+ // Verify results
1307
+ if tc .expectError {
1308
+ if err != nil {
1309
+ assert .Contains (t ,err .Error (),tc .expectedErrMsg )
1310
+ return
1311
+ }
1312
+
1313
+ // If no error returned but in the result
1314
+ textContent := getTextResult (t ,result )
1315
+ assert .Contains (t ,textContent .Text ,tc .expectedErrMsg )
1316
+ return
1317
+ }
1318
+
1319
+ require .NoError (t ,err )
1320
+
1321
+ // Parse the result and get the text content if no error
1322
+ textContent := getTextResult (t ,result )
1323
+
1324
+ // Unmarshal and verify the result
1325
+ var returnedPR github.PullRequest
1326
+ err = json .Unmarshal ([]byte (textContent .Text ),& returnedPR )
1327
+ require .NoError (t ,err )
1328
+ assert .Equal (t ,* tc .expectedPR .Number ,* returnedPR .Number )
1329
+ assert .Equal (t ,* tc .expectedPR .Title ,* returnedPR .Title )
1330
+ assert .Equal (t ,* tc .expectedPR .State ,* returnedPR .State )
1331
+ assert .Equal (t ,* tc .expectedPR .HTMLURL ,* returnedPR .HTMLURL )
1332
+ assert .Equal (t ,* tc .expectedPR .Head .SHA ,* returnedPR .Head .SHA )
1333
+ assert .Equal (t ,* tc .expectedPR .Base .Ref ,* returnedPR .Base .Ref )
1334
+ assert .Equal (t ,* tc .expectedPR .Body ,* returnedPR .Body )
1335
+ assert .Equal (t ,* tc .expectedPR .User .Login ,* returnedPR .User .Login )
1336
+ })
1337
+ }
1338
+ }