@@ -100,7 +100,7 @@ static PGresult *storeQueryResult(volatile storeInfo *sinfo, PGconn *conn, const
100
100
static void storeRow (volatile storeInfo * sinfo ,PGresult * res ,bool first );
101
101
static remoteConn * getConnectionByName (const char * name );
102
102
static HTAB * createConnHash (void );
103
- static void createNewConnection (const char * name , remoteConn * rconn );
103
+ static remoteConn * createNewConnection (const char * name );
104
104
static void deleteConnection (const char * name );
105
105
static char * * get_pkey_attnames (Relation rel ,int16 * indnkeyatts );
106
106
static char * * get_text_array_contents (ArrayType * array ,int * numitems );
@@ -114,7 +114,8 @@ static Relation get_rel_from_relname(text *relname_text, LOCKMODE lockmode, AclM
114
114
static char * generate_relation_name (Relation rel );
115
115
static void dblink_connstr_check (const char * connstr );
116
116
static bool dblink_connstr_has_pw (const char * connstr );
117
- static void dblink_security_check (PGconn * conn ,remoteConn * rconn ,const char * connstr );
117
+ static void dblink_security_check (PGconn * conn ,const char * connname ,
118
+ const char * connstr );
118
119
static void dblink_res_error (PGconn * conn ,const char * conname ,PGresult * res ,
119
120
bool fail ,const char * fmt ,...)pg_attribute_printf (5 ,6 );
120
121
static char * get_connect_string (const char * servername );
@@ -132,16 +133,22 @@ static remoteConn *pconn = NULL;
132
133
static HTAB * remoteConnHash = NULL ;
133
134
134
135
/*
135
- *Following islist that holds multiple remote connections.
136
+ *Following ishash that holds multiple remote connections.
136
137
*Calling convention of each dblink function changes to accept
137
- *connection name as the first parameter. The connectionlist is
138
+ *connection name as the first parameter. The connectionhash is
138
139
*much like ecpg e.g. a mapping between a name and a PGconn object.
140
+ *
141
+ *To avoid potentially leaking a PGconn object in case of out-of-memory
142
+ *errors, we first create the hash entry, then open the PGconn.
143
+ *Hence, a hash entry whose rconn.conn pointer is NULL must be
144
+ *understood as a leftover from a failed create; it should be ignored
145
+ *by lookup operations, and silently replaced by create operations.
139
146
*/
140
147
141
148
typedef struct remoteConnHashEnt
142
149
{
143
150
char name [NAMEDATALEN ];
144
- remoteConn * rconn ;
151
+ remoteConn rconn ;
145
152
}remoteConnHashEnt ;
146
153
147
154
/* initial number of connection hashes */
@@ -216,7 +223,7 @@ dblink_get_conn(char *conname_or_str,
216
223
errmsg ("could not establish connection" ),
217
224
errdetail_internal ("%s" ,msg )));
218
225
}
219
- dblink_security_check (conn ,rconn ,connstr );
226
+ dblink_security_check (conn ,NULL ,connstr );
220
227
if (PQclientEncoding (conn )!= GetDatabaseEncoding ())
221
228
PQsetClientEncoding (conn ,GetDatabaseEncodingName ());
222
229
freeconn = true;
@@ -276,15 +283,6 @@ dblink_connect(PG_FUNCTION_ARGS)
276
283
else if (PG_NARGS ()== 1 )
277
284
conname_or_str = text_to_cstring (PG_GETARG_TEXT_PP (0 ));
278
285
279
- if (connname )
280
- {
281
- rconn = (remoteConn * )MemoryContextAlloc (TopMemoryContext ,
282
- sizeof (remoteConn ));
283
- rconn -> conn = NULL ;
284
- rconn -> openCursorCount = 0 ;
285
- rconn -> newXactForCursor = false;
286
- }
287
-
288
286
/* first check for valid foreign data server */
289
287
connstr = get_connect_string (conname_or_str );
290
288
if (connstr == NULL )
@@ -293,15 +291,22 @@ dblink_connect(PG_FUNCTION_ARGS)
293
291
/* check password in connection string if not superuser */
294
292
dblink_connstr_check (connstr );
295
293
294
+ /* if we need a hashtable entry, make that first, since it might fail */
295
+ if (connname )
296
+ {
297
+ rconn = createNewConnection (connname );
298
+ Assert (rconn -> conn == NULL );
299
+ }
300
+
296
301
/* OK to make connection */
297
302
conn = libpqsrv_connect (connstr ,PG_WAIT_EXTENSION );
298
303
299
304
if (PQstatus (conn )== CONNECTION_BAD )
300
305
{
301
306
msg = pchomp (PQerrorMessage (conn ));
302
307
libpqsrv_disconnect (conn );
303
- if (rconn )
304
- pfree ( rconn );
308
+ if (connname )
309
+ deleteConnection ( connname );
305
310
306
311
ereport (ERROR ,
307
312
(errcode (ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION ),
@@ -310,16 +315,16 @@ dblink_connect(PG_FUNCTION_ARGS)
310
315
}
311
316
312
317
/* check password actually used if not superuser */
313
- dblink_security_check (conn ,rconn ,connstr );
318
+ dblink_security_check (conn ,connname ,connstr );
314
319
315
320
/* attempt to set client encoding to match server encoding, if needed */
316
321
if (PQclientEncoding (conn )!= GetDatabaseEncoding ())
317
322
PQsetClientEncoding (conn ,GetDatabaseEncodingName ());
318
323
324
+ /* all OK, save away the conn */
319
325
if (connname )
320
326
{
321
327
rconn -> conn = conn ;
322
- createNewConnection (connname ,rconn );
323
328
}
324
329
else
325
330
{
@@ -359,10 +364,7 @@ dblink_disconnect(PG_FUNCTION_ARGS)
359
364
360
365
libpqsrv_disconnect (conn );
361
366
if (rconn )
362
- {
363
367
deleteConnection (conname );
364
- pfree (rconn );
365
- }
366
368
else
367
369
pconn -> conn = NULL ;
368
370
@@ -1280,6 +1282,9 @@ dblink_get_connections(PG_FUNCTION_ARGS)
1280
1282
hash_seq_init (& status ,remoteConnHash );
1281
1283
while ((hentry = (remoteConnHashEnt * )hash_seq_search (& status ))!= NULL )
1282
1284
{
1285
+ /* ignore it if it's not an open connection */
1286
+ if (hentry -> rconn .conn == NULL )
1287
+ continue ;
1283
1288
/* stash away current value */
1284
1289
astate = accumArrayResult (astate ,
1285
1290
CStringGetTextDatum (hentry -> name ),
@@ -2520,8 +2525,8 @@ getConnectionByName(const char *name)
2520
2525
hentry = (remoteConnHashEnt * )hash_search (remoteConnHash ,
2521
2526
key ,HASH_FIND ,NULL );
2522
2527
2523
- if (hentry )
2524
- return hentry -> rconn ;
2528
+ if (hentry && hentry -> rconn . conn != NULL )
2529
+ return & hentry -> rconn ;
2525
2530
2526
2531
return NULL ;
2527
2532
}
@@ -2538,8 +2543,8 @@ createConnHash(void)
2538
2543
HASH_ELEM |HASH_STRINGS );
2539
2544
}
2540
2545
2541
- static void
2542
- createNewConnection (const char * name , remoteConn * rconn )
2546
+ static remoteConn *
2547
+ createNewConnection (const char * name )
2543
2548
{
2544
2549
remoteConnHashEnt * hentry ;
2545
2550
bool found ;
@@ -2553,18 +2558,15 @@ createNewConnection(const char *name, remoteConn *rconn)
2553
2558
hentry = (remoteConnHashEnt * )hash_search (remoteConnHash ,key ,
2554
2559
HASH_ENTER ,& found );
2555
2560
2556
- if (found )
2557
- {
2558
- libpqsrv_disconnect (rconn -> conn );
2559
- pfree (rconn );
2560
-
2561
+ if (found && hentry -> rconn .conn != NULL )
2561
2562
ereport (ERROR ,
2562
2563
(errcode (ERRCODE_DUPLICATE_OBJECT ),
2563
2564
errmsg ("duplicate connection name" )));
2564
- }
2565
2565
2566
- hentry -> rconn = rconn ;
2567
- strlcpy (hentry -> name ,name ,sizeof (hentry -> name ));
2566
+ /* New, or reusable, so initialize the rconn struct to zeroes */
2567
+ memset (& hentry -> rconn ,0 ,sizeof (remoteConn ));
2568
+
2569
+ return & hentry -> rconn ;
2568
2570
}
2569
2571
2570
2572
static void
@@ -2592,9 +2594,12 @@ deleteConnection(const char *name)
2592
2594
* We need to make sure that the connection made used credentials
2593
2595
* which were provided by the user, so check what credentials were
2594
2596
* used to connect and then make sure that they came from the user.
2597
+ *
2598
+ * On failure, we close "conn" and also delete the hashtable entry
2599
+ * identified by "connname" (if that's not NULL).
2595
2600
*/
2596
2601
static void
2597
- dblink_security_check (PGconn * conn ,remoteConn * rconn ,const char * connstr )
2602
+ dblink_security_check (PGconn * conn ,const char * connname ,const char * connstr )
2598
2603
{
2599
2604
/* Superuser bypasses security check */
2600
2605
if (superuser ())
@@ -2612,8 +2617,8 @@ dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
2612
2617
2613
2618
/* Otherwise, fail out */
2614
2619
libpqsrv_disconnect (conn );
2615
- if (rconn )
2616
- pfree ( rconn );
2620
+ if (connname )
2621
+ deleteConnection ( connname );
2617
2622
2618
2623
ereport (ERROR ,
2619
2624
(errcode (ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED ),