33import os
44import subprocess
55import tempfile
6- import time
76
8- import sshtunnel
7+ # we support both pg8000 and psycopg2
8+ try :
9+ import psycopg2 as pglib
10+ except ImportError :
11+ try :
12+ import pg8000 as pglib
13+ except ImportError :
14+ raise ImportError ("You must have psycopg2 or pg8000 modules installed" )
915
1016from ..exceptions import ExecUtilException
1117
1218from .os_ops import OsOperations ,ConnectionParams
13- from .os_ops import pglib
14-
15- sshtunnel .SSH_TIMEOUT = 5.0
16- sshtunnel .TUNNEL_TIMEOUT = 5.0
1719
1820ConsoleEncoding = locale .getdefaultlocale ()[1 ]
1921if not ConsoleEncoding :
@@ -50,21 +52,28 @@ def __init__(self, conn_params: ConnectionParams):
5052self .remote = True
5153self .username = conn_params .username or self .get_user ()
5254self .add_known_host (self .host )
55+ self .tunnel_process = None
5356
5457def __enter__ (self ):
5558return self
5659
5760def __exit__ (self ,exc_type ,exc_val ,exc_tb ):
58- self .close_tunnel ()
61+ self .close_ssh_tunnel ()
5962
60- def close_tunnel (self ):
61- if getattr (self ,'tunnel' ,None ):
62- self .tunnel .stop (force = True )
63- start_time = time .time ()
64- while self .tunnel .is_active :
65- if time .time ()- start_time > sshtunnel .TUNNEL_TIMEOUT :
66- break
67- time .sleep (0.5 )
63+ def establish_ssh_tunnel (self ,local_port ,remote_port ):
64+ """
65+ Establish an SSH tunnel from a local port to a remote PostgreSQL port.
66+ """
67+ ssh_cmd = ['-N' ,'-L' ,f"{ local_port } :localhost:{ remote_port } " ]
68+ self .tunnel_process = self .exec_command (ssh_cmd ,get_process = True ,timeout = 300 )
69+
70+ def close_ssh_tunnel (self ):
71+ if hasattr (self ,'tunnel_process' ):
72+ self .tunnel_process .terminate ()
73+ self .tunnel_process .wait ()
74+ del self .tunnel_process
75+ else :
76+ print ("No active tunnel to close." )
6877
6978def add_known_host (self ,host ):
7079cmd = 'ssh-keyscan -H %s >> /home/%s/.ssh/known_hosts' % (host ,os .getlogin ())
@@ -78,21 +87,29 @@ def add_known_host(self, host):
7887raise ExecUtilException (message = "Failed to add %s to known_hosts. Error: %s" % (host ,str (e )),command = cmd ,
7988exit_code = e .returncode ,out = e .stderr )
8089
81- def exec_command (self ,cmd : str ,wait_exit = False ,verbose = False ,expect_error = False ,
90+ def exec_command (self ,cmd ,wait_exit = False ,verbose = False ,expect_error = False ,
8291encoding = None ,shell = True ,text = False ,input = None ,stdin = None ,stdout = None ,
83- stderr = None ,proc = None ):
92+ stderr = None ,get_process = None , timeout = None ):
8493"""
8594 Execute a command in the SSH session.
8695 Args:
8796 - cmd (str): The command to be executed.
8897 """
98+ ssh_cmd = []
8999if isinstance (cmd ,str ):
90100ssh_cmd = ['ssh' ,f"{ self .username } @{ self .host } " ,'-i' ,self .ssh_key ,cmd ]
91101elif isinstance (cmd ,list ):
92102ssh_cmd = ['ssh' ,f"{ self .username } @{ self .host } " ,'-i' ,self .ssh_key ]+ cmd
93103process = subprocess .Popen (ssh_cmd ,stdin = subprocess .PIPE ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE )
104+ if get_process :
105+ return process
106+
107+ try :
108+ result ,error = process .communicate (input ,timeout = timeout )
109+ except subprocess .TimeoutExpired :
110+ process .kill ()
111+ raise ExecUtilException ("Command timed out after {} seconds." .format (timeout ))
94112
95- result ,error = process .communicate (input )
96113exit_status = process .returncode
97114
98115if encoding :
@@ -372,41 +389,19 @@ def get_process_children(self, pid):
372389raise ExecUtilException (f"Error in getting process children. Error:{ result .stderr } " )
373390
374391# Database control
375- def db_connect (self ,dbname ,user ,password = None ,host = "127.0.0.1 " ,port = 5432 , ssh_key = None ):
392+ def db_connect (self ,dbname ,user ,password = None ,host = "localhost " ,port = 5432 ):
376393"""
377- Connects to a PostgreSQL database on the remote system.
378- Args:
379- - dbname (str): The name of the database to connect to.
380- - user (str): The username for the database connection.
381- - password (str, optional): The password for the database connection. Defaults to None.
382- - host (str, optional): The IP address of the remote system. Defaults to "localhost".
383- - port (int, optional): The port number of the PostgreSQL service. Defaults to 5432.
384-
385- This function establishes a connection to a PostgreSQL database on the remote system using the specified
386- parameters. It returns a connection object that can be used to interact with the database.
394+ Established SSH tunnel and Connects to a PostgreSQL
387395 """
388- self .close_tunnel ()
389- self .tunnel = sshtunnel .open_tunnel (
390- (self .host ,22 ),# Remote server IP and SSH port
391- ssh_username = self .username ,
392- ssh_pkey = self .ssh_key ,
393- remote_bind_address = (self .host ,port ),# PostgreSQL server IP and PostgreSQL port
394- local_bind_address = ('localhost' ,0 )
395- # Local machine IP and available port (0 means it will pick any available port)
396- )
397- self .tunnel .start ()
398-
396+ self .establish_ssh_tunnel (local_port = port ,remote_port = 5432 )
399397try :
400- # Use localhost and self.tunnel.local_bind_port to connect
401398conn = pglib .connect (
402- host = 'localhost' , # Connect to localhost
403- port = self . tunnel . local_bind_port , # use the local bind port set up by the tunnel
399+ host = host ,
400+ port = port ,
404401database = dbname ,
405- user = user or self . username ,
406- password = password
402+ user = user ,
403+ password = password ,
407404 )
408-
409405return conn
410406except Exception as e :
411- self .tunnel .stop ()
412- raise ExecUtilException ("Could not create db tunnel. {}" .format (e ))
407+ raise Exception (f"Could not connect to the database. Error:{ e } " )