@@ -100,7 +100,7 @@ static PGresult *storeQueryResult(volatile storeInfo *sinfo, PGconn *conn, const
100
100
static void storeRow (volatile storeInfo * sinfo ,PGresult * res ,bool first );
101
101
static remoteConn * getConnectionByName (const char * name );
102
102
static HTAB * createConnHash (void );
103
- static void createNewConnection (const char * name , remoteConn * rconn );
103
+ static remoteConn * createNewConnection (const char * name );
104
104
static void deleteConnection (const char * name );
105
105
static char * * get_pkey_attnames (Relation rel ,int16 * indnkeyatts );
106
106
static char * * get_text_array_contents (ArrayType * array ,int * numitems );
@@ -114,7 +114,8 @@ static Relation get_rel_from_relname(text *relname_text, LOCKMODE lockmode, AclM
114
114
static char * generate_relation_name (Relation rel );
115
115
static void dblink_connstr_check (const char * connstr );
116
116
static bool dblink_connstr_has_pw (const char * connstr );
117
- static void dblink_security_check (PGconn * conn ,remoteConn * rconn ,const char * connstr );
117
+ static void dblink_security_check (PGconn * conn ,const char * connname ,
118
+ const char * connstr );
118
119
static void dblink_res_error (PGconn * conn ,const char * conname ,PGresult * res ,
119
120
bool fail ,const char * fmt ,...)pg_attribute_printf (5 ,6 );
120
121
static char * get_connect_string (const char * servername );
@@ -137,16 +138,22 @@ static uint32 dblink_we_get_conn = 0;
137
138
static uint32 dblink_we_get_result = 0 ;
138
139
139
140
/*
140
- *Following islist that holds multiple remote connections.
141
+ *Following ishash that holds multiple remote connections.
141
142
*Calling convention of each dblink function changes to accept
142
- *connection name as the first parameter. The connectionlist is
143
+ *connection name as the first parameter. The connectionhash is
143
144
*much like ecpg e.g. a mapping between a name and a PGconn object.
145
+ *
146
+ *To avoid potentially leaking a PGconn object in case of out-of-memory
147
+ *errors, we first create the hash entry, then open the PGconn.
148
+ *Hence, a hash entry whose rconn.conn pointer is NULL must be
149
+ *understood as a leftover from a failed create; it should be ignored
150
+ *by lookup operations, and silently replaced by create operations.
144
151
*/
145
152
146
153
typedef struct remoteConnHashEnt
147
154
{
148
155
char name [NAMEDATALEN ];
149
- remoteConn * rconn ;
156
+ remoteConn rconn ;
150
157
}remoteConnHashEnt ;
151
158
152
159
/* initial number of connection hashes */
@@ -225,7 +232,7 @@ dblink_get_conn(char *conname_or_str,
225
232
errmsg ("could not establish connection" ),
226
233
errdetail_internal ("%s" ,msg )));
227
234
}
228
- dblink_security_check (conn ,rconn ,connstr );
235
+ dblink_security_check (conn ,NULL ,connstr );
229
236
if (PQclientEncoding (conn )!= GetDatabaseEncoding ())
230
237
PQsetClientEncoding (conn ,GetDatabaseEncodingName ());
231
238
freeconn = true;
@@ -288,15 +295,6 @@ dblink_connect(PG_FUNCTION_ARGS)
288
295
else if (PG_NARGS ()== 1 )
289
296
conname_or_str = text_to_cstring (PG_GETARG_TEXT_PP (0 ));
290
297
291
- if (connname )
292
- {
293
- rconn = (remoteConn * )MemoryContextAlloc (TopMemoryContext ,
294
- sizeof (remoteConn ));
295
- rconn -> conn = NULL ;
296
- rconn -> openCursorCount = 0 ;
297
- rconn -> newXactForCursor = false;
298
- }
299
-
300
298
/* first check for valid foreign data server */
301
299
connstr = get_connect_string (conname_or_str );
302
300
if (connstr == NULL )
@@ -309,15 +307,22 @@ dblink_connect(PG_FUNCTION_ARGS)
309
307
if (dblink_we_connect == 0 )
310
308
dblink_we_connect = WaitEventExtensionNew ("DblinkConnect" );
311
309
310
+ /* if we need a hashtable entry, make that first, since it might fail */
311
+ if (connname )
312
+ {
313
+ rconn = createNewConnection (connname );
314
+ Assert (rconn -> conn == NULL );
315
+ }
316
+
312
317
/* OK to make connection */
313
318
conn = libpqsrv_connect (connstr ,dblink_we_connect );
314
319
315
320
if (PQstatus (conn )== CONNECTION_BAD )
316
321
{
317
322
msg = pchomp (PQerrorMessage (conn ));
318
323
libpqsrv_disconnect (conn );
319
- if (rconn )
320
- pfree ( rconn );
324
+ if (connname )
325
+ deleteConnection ( connname );
321
326
322
327
ereport (ERROR ,
323
328
(errcode (ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION ),
@@ -326,16 +331,16 @@ dblink_connect(PG_FUNCTION_ARGS)
326
331
}
327
332
328
333
/* check password actually used if not superuser */
329
- dblink_security_check (conn ,rconn ,connstr );
334
+ dblink_security_check (conn ,connname ,connstr );
330
335
331
336
/* attempt to set client encoding to match server encoding, if needed */
332
337
if (PQclientEncoding (conn )!= GetDatabaseEncoding ())
333
338
PQsetClientEncoding (conn ,GetDatabaseEncodingName ());
334
339
340
+ /* all OK, save away the conn */
335
341
if (connname )
336
342
{
337
343
rconn -> conn = conn ;
338
- createNewConnection (connname ,rconn );
339
344
}
340
345
else
341
346
{
@@ -375,10 +380,7 @@ dblink_disconnect(PG_FUNCTION_ARGS)
375
380
376
381
libpqsrv_disconnect (conn );
377
382
if (rconn )
378
- {
379
383
deleteConnection (conname );
380
- pfree (rconn );
381
- }
382
384
else
383
385
pconn -> conn = NULL ;
384
386
@@ -1296,6 +1298,9 @@ dblink_get_connections(PG_FUNCTION_ARGS)
1296
1298
hash_seq_init (& status ,remoteConnHash );
1297
1299
while ((hentry = (remoteConnHashEnt * )hash_seq_search (& status ))!= NULL )
1298
1300
{
1301
+ /* ignore it if it's not an open connection */
1302
+ if (hentry -> rconn .conn == NULL )
1303
+ continue ;
1299
1304
/* stash away current value */
1300
1305
astate = accumArrayResult (astate ,
1301
1306
CStringGetTextDatum (hentry -> name ),
@@ -2533,8 +2538,8 @@ getConnectionByName(const char *name)
2533
2538
hentry = (remoteConnHashEnt * )hash_search (remoteConnHash ,
2534
2539
key ,HASH_FIND ,NULL );
2535
2540
2536
- if (hentry )
2537
- return hentry -> rconn ;
2541
+ if (hentry && hentry -> rconn . conn != NULL )
2542
+ return & hentry -> rconn ;
2538
2543
2539
2544
return NULL ;
2540
2545
}
@@ -2551,8 +2556,8 @@ createConnHash(void)
2551
2556
HASH_ELEM |HASH_STRINGS );
2552
2557
}
2553
2558
2554
- static void
2555
- createNewConnection (const char * name , remoteConn * rconn )
2559
+ static remoteConn *
2560
+ createNewConnection (const char * name )
2556
2561
{
2557
2562
remoteConnHashEnt * hentry ;
2558
2563
bool found ;
@@ -2566,17 +2571,15 @@ createNewConnection(const char *name, remoteConn *rconn)
2566
2571
hentry = (remoteConnHashEnt * )hash_search (remoteConnHash ,key ,
2567
2572
HASH_ENTER ,& found );
2568
2573
2569
- if (found )
2570
- {
2571
- libpqsrv_disconnect (rconn -> conn );
2572
- pfree (rconn );
2573
-
2574
+ if (found && hentry -> rconn .conn != NULL )
2574
2575
ereport (ERROR ,
2575
2576
(errcode (ERRCODE_DUPLICATE_OBJECT ),
2576
2577
errmsg ("duplicate connection name" )));
2577
- }
2578
2578
2579
- hentry -> rconn = rconn ;
2579
+ /* New, or reusable, so initialize the rconn struct to zeroes */
2580
+ memset (& hentry -> rconn ,0 ,sizeof (remoteConn ));
2581
+
2582
+ return & hentry -> rconn ;
2580
2583
}
2581
2584
2582
2585
static void
@@ -2604,9 +2607,12 @@ deleteConnection(const char *name)
2604
2607
* We need to make sure that the connection made used credentials
2605
2608
* which were provided by the user, so check what credentials were
2606
2609
* used to connect and then make sure that they came from the user.
2610
+ *
2611
+ * On failure, we close "conn" and also delete the hashtable entry
2612
+ * identified by "connname" (if that's not NULL).
2607
2613
*/
2608
2614
static void
2609
- dblink_security_check (PGconn * conn ,remoteConn * rconn ,const char * connstr )
2615
+ dblink_security_check (PGconn * conn ,const char * connname ,const char * connstr )
2610
2616
{
2611
2617
/* Superuser bypasses security check */
2612
2618
if (superuser ())
@@ -2624,8 +2630,8 @@ dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
2624
2630
2625
2631
/* Otherwise, fail out */
2626
2632
libpqsrv_disconnect (conn );
2627
- if (rconn )
2628
- pfree ( rconn );
2633
+ if (connname )
2634
+ deleteConnection ( connname );
2629
2635
2630
2636
ereport (ERROR ,
2631
2637
(errcode (ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED ),