@@ -105,7 +105,7 @@ static PGresult *storeQueryResult(volatile storeInfo *sinfo, PGconn *conn, const
105
105
static void storeRow (volatile storeInfo * sinfo ,PGresult * res ,bool first );
106
106
static remoteConn * getConnectionByName (const char * name );
107
107
static HTAB * createConnHash (void );
108
- static void createNewConnection (const char * name , remoteConn * rconn );
108
+ static remoteConn * createNewConnection (const char * name );
109
109
static void deleteConnection (const char * name );
110
110
static char * * get_pkey_attnames (Relation rel ,int16 * indnkeyatts );
111
111
static char * * get_text_array_contents (ArrayType * array ,int * numitems );
@@ -119,7 +119,8 @@ static Relation get_rel_from_relname(text *relname_text, LOCKMODE lockmode, AclM
119
119
static char * generate_relation_name (Relation rel );
120
120
static void dblink_connstr_check (const char * connstr );
121
121
static bool dblink_connstr_has_pw (const char * connstr );
122
- static void dblink_security_check (PGconn * conn ,remoteConn * rconn ,const char * connstr );
122
+ static void dblink_security_check (PGconn * conn ,const char * connname ,
123
+ const char * connstr );
123
124
static void dblink_res_error (PGconn * conn ,const char * conname ,PGresult * res ,
124
125
bool fail ,const char * fmt ,...)pg_attribute_printf (5 ,6 );
125
126
static char * get_connect_string (const char * servername );
@@ -147,16 +148,22 @@ static uint32 dblink_we_get_conn = 0;
147
148
static uint32 dblink_we_get_result = 0 ;
148
149
149
150
/*
150
- *Following islist that holds multiple remote connections.
151
+ *Following ishash that holds multiple remote connections.
151
152
*Calling convention of each dblink function changes to accept
152
- *connection name as the first parameter. The connectionlist is
153
+ *connection name as the first parameter. The connectionhash is
153
154
*much like ecpg e.g. a mapping between a name and a PGconn object.
155
+ *
156
+ *To avoid potentially leaking a PGconn object in case of out-of-memory
157
+ *errors, we first create the hash entry, then open the PGconn.
158
+ *Hence, a hash entry whose rconn.conn pointer is NULL must be
159
+ *understood as a leftover from a failed create; it should be ignored
160
+ *by lookup operations, and silently replaced by create operations.
154
161
*/
155
162
156
163
typedef struct remoteConnHashEnt
157
164
{
158
165
char name [NAMEDATALEN ];
159
- remoteConn * rconn ;
166
+ remoteConn rconn ;
160
167
}remoteConnHashEnt ;
161
168
162
169
/* initial number of connection hashes */
@@ -233,7 +240,7 @@ dblink_get_conn(char *conname_or_str,
233
240
errmsg ("could not establish connection" ),
234
241
errdetail_internal ("%s" ,msg )));
235
242
}
236
- dblink_security_check (conn ,rconn ,connstr );
243
+ dblink_security_check (conn ,NULL ,connstr );
237
244
if (PQclientEncoding (conn )!= GetDatabaseEncoding ())
238
245
PQsetClientEncoding (conn ,GetDatabaseEncodingName ());
239
246
freeconn = true;
@@ -296,15 +303,6 @@ dblink_connect(PG_FUNCTION_ARGS)
296
303
else if (PG_NARGS ()== 1 )
297
304
conname_or_str = text_to_cstring (PG_GETARG_TEXT_PP (0 ));
298
305
299
- if (connname )
300
- {
301
- rconn = (remoteConn * )MemoryContextAlloc (TopMemoryContext ,
302
- sizeof (remoteConn ));
303
- rconn -> conn = NULL ;
304
- rconn -> openCursorCount = 0 ;
305
- rconn -> newXactForCursor = false;
306
- }
307
-
308
306
/* first check for valid foreign data server */
309
307
connstr = get_connect_string (conname_or_str );
310
308
if (connstr == NULL )
@@ -317,15 +315,22 @@ dblink_connect(PG_FUNCTION_ARGS)
317
315
if (dblink_we_connect == 0 )
318
316
dblink_we_connect = WaitEventExtensionNew ("DblinkConnect" );
319
317
318
+ /* if we need a hashtable entry, make that first, since it might fail */
319
+ if (connname )
320
+ {
321
+ rconn = createNewConnection (connname );
322
+ Assert (rconn -> conn == NULL );
323
+ }
324
+
320
325
/* OK to make connection */
321
326
conn = libpqsrv_connect (connstr ,dblink_we_connect );
322
327
323
328
if (PQstatus (conn )== CONNECTION_BAD )
324
329
{
325
330
msg = pchomp (PQerrorMessage (conn ));
326
331
libpqsrv_disconnect (conn );
327
- if (rconn )
328
- pfree ( rconn );
332
+ if (connname )
333
+ deleteConnection ( connname );
329
334
330
335
ereport (ERROR ,
331
336
(errcode (ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION ),
@@ -334,16 +339,16 @@ dblink_connect(PG_FUNCTION_ARGS)
334
339
}
335
340
336
341
/* check password actually used if not superuser */
337
- dblink_security_check (conn ,rconn ,connstr );
342
+ dblink_security_check (conn ,connname ,connstr );
338
343
339
344
/* attempt to set client encoding to match server encoding, if needed */
340
345
if (PQclientEncoding (conn )!= GetDatabaseEncoding ())
341
346
PQsetClientEncoding (conn ,GetDatabaseEncodingName ());
342
347
348
+ /* all OK, save away the conn */
343
349
if (connname )
344
350
{
345
351
rconn -> conn = conn ;
346
- createNewConnection (connname ,rconn );
347
352
}
348
353
else
349
354
{
@@ -383,10 +388,7 @@ dblink_disconnect(PG_FUNCTION_ARGS)
383
388
384
389
libpqsrv_disconnect (conn );
385
390
if (rconn )
386
- {
387
391
deleteConnection (conname );
388
- pfree (rconn );
389
- }
390
392
else
391
393
pconn -> conn = NULL ;
392
394
@@ -1304,6 +1306,9 @@ dblink_get_connections(PG_FUNCTION_ARGS)
1304
1306
hash_seq_init (& status ,remoteConnHash );
1305
1307
while ((hentry = (remoteConnHashEnt * )hash_seq_search (& status ))!= NULL )
1306
1308
{
1309
+ /* ignore it if it's not an open connection */
1310
+ if (hentry -> rconn .conn == NULL )
1311
+ continue ;
1307
1312
/* stash away current value */
1308
1313
astate = accumArrayResult (astate ,
1309
1314
CStringGetTextDatum (hentry -> name ),
@@ -2539,8 +2544,8 @@ getConnectionByName(const char *name)
2539
2544
hentry = (remoteConnHashEnt * )hash_search (remoteConnHash ,
2540
2545
key ,HASH_FIND ,NULL );
2541
2546
2542
- if (hentry )
2543
- return hentry -> rconn ;
2547
+ if (hentry && hentry -> rconn . conn != NULL )
2548
+ return & hentry -> rconn ;
2544
2549
2545
2550
return NULL ;
2546
2551
}
@@ -2557,8 +2562,8 @@ createConnHash(void)
2557
2562
HASH_ELEM |HASH_STRINGS );
2558
2563
}
2559
2564
2560
- static void
2561
- createNewConnection (const char * name , remoteConn * rconn )
2565
+ static remoteConn *
2566
+ createNewConnection (const char * name )
2562
2567
{
2563
2568
remoteConnHashEnt * hentry ;
2564
2569
bool found ;
@@ -2572,17 +2577,15 @@ createNewConnection(const char *name, remoteConn *rconn)
2572
2577
hentry = (remoteConnHashEnt * )hash_search (remoteConnHash ,key ,
2573
2578
HASH_ENTER ,& found );
2574
2579
2575
- if (found )
2576
- {
2577
- libpqsrv_disconnect (rconn -> conn );
2578
- pfree (rconn );
2579
-
2580
+ if (found && hentry -> rconn .conn != NULL )
2580
2581
ereport (ERROR ,
2581
2582
(errcode (ERRCODE_DUPLICATE_OBJECT ),
2582
2583
errmsg ("duplicate connection name" )));
2583
- }
2584
2584
2585
- hentry -> rconn = rconn ;
2585
+ /* New, or reusable, so initialize the rconn struct to zeroes */
2586
+ memset (& hentry -> rconn ,0 ,sizeof (remoteConn ));
2587
+
2588
+ return & hentry -> rconn ;
2586
2589
}
2587
2590
2588
2591
static void
@@ -2671,9 +2674,12 @@ dblink_connstr_has_required_scram_options(const char *connstr)
2671
2674
* We need to make sure that the connection made used credentials
2672
2675
* which were provided by the user, so check what credentials were
2673
2676
* used to connect and then make sure that they came from the user.
2677
+ *
2678
+ * On failure, we close "conn" and also delete the hashtable entry
2679
+ * identified by "connname" (if that's not NULL).
2674
2680
*/
2675
2681
static void
2676
- dblink_security_check (PGconn * conn ,remoteConn * rconn ,const char * connstr )
2682
+ dblink_security_check (PGconn * conn ,const char * connname ,const char * connstr )
2677
2683
{
2678
2684
/* Superuser bypasses security check */
2679
2685
if (superuser ())
@@ -2703,8 +2709,8 @@ dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
2703
2709
2704
2710
/* Otherwise, fail out */
2705
2711
libpqsrv_disconnect (conn );
2706
- if (rconn )
2707
- pfree ( rconn );
2712
+ if (connname )
2713
+ deleteConnection ( connname );
2708
2714
2709
2715
ereport (ERROR ,
2710
2716
(errcode (ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED ),