@@ -50,10 +50,10 @@ def __init__(self, conn_params: ConnectionParams):
50
50
self .ssh_key = conn_params .ssh_key
51
51
self .port = conn_params .port
52
52
self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
53
- if self .ssh_key :
54
- self .ssh_cmd += ["-i" ,self .ssh_key ]
55
53
if self .port :
56
54
self .ssh_cmd += ["-p" ,self .port ]
55
+ if self .ssh_key :
56
+ self .ssh_cmd += ["-i" ,self .ssh_key ]
57
57
self .remote = True
58
58
self .username = conn_params .username or self .get_user ()
59
59
self .tunnel_process = None
@@ -285,6 +285,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
285
285
mode = "r+b" if binary else "r+"
286
286
287
287
with tempfile .NamedTemporaryFile (mode = mode ,delete = False )as tmp_file :
288
+ # Because in scp we set up port using -P option instead -p
288
289
scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
289
290
290
291
if not truncate :
@@ -304,12 +305,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
304
305
tmp_file .write (data )
305
306
306
307
tmp_file .flush ()
307
- # Because in scp we set up port using -P option
308
308
scp_cmd = ['scp' ]+ scp_ssh_cmd + [tmp_file .name ,f"{ self .username } @{ self .host } :{ filename } " ]
309
309
subprocess .run (scp_cmd ,check = True )
310
-
311
310
remote_directory = os .path .dirname (filename )
312
- mkdir_cmd = ['ssh' ]+ scp_ssh_cmd + [f"{ self .username } @{ self .host } " ,f"mkdir -p{ remote_directory } " ]
311
+
312
+ mkdir_cmd = ['ssh' ]+ self .ssh_cmd + [f"{ self .username } @{ self .host } " ,f'mkdir -p{ remote_directory } ' ]
313
313
subprocess .run (mkdir_cmd ,check = True )
314
314
315
315
os .remove (tmp_file .name )
@@ -387,9 +387,10 @@ def get_process_children(self, pid):
387
387
# Database control
388
388
def db_connect (self ,dbname ,user ,password = None ,host = "localhost" ,port = 5432 ):
389
389
"""
390
- Established SSH tunnel andConnects to a PostgreSQL
390
+ Establish SSH tunnel andconnect to a PostgreSQL database.
391
391
"""
392
- self .establish_ssh_tunnel (local_port = reserve_port (),remote_port = 5432 )
392
+ self .establish_ssh_tunnel (local_port = port ,remote_port = self .conn_params .port )
393
+
393
394
try :
394
395
conn = pglib .connect (
395
396
host = host ,
@@ -398,6 +399,11 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
398
399
user = user ,
399
400
password = password ,
400
401
)
402
+ print ("Database connection established successfully." )
401
403
return conn
402
404
except Exception as e :
403
- raise Exception (f"Could not connect to the database. Error:{ e } " )
405
+ print (f"Error connecting to the database:{ str (e )} " )
406
+ if self .tunnel_process :
407
+ self .tunnel_process .terminate ()
408
+ print ("SSH tunnel closed due to connection failure." )
409
+ raise