2020error_markers = [b'error' ,b'Permission denied' ]
2121
2222
23+ class PsUtilProcessProxy :
24+ def __init__ (self ,ssh ,pid ):
25+ self .ssh = ssh
26+ self .pid = pid
27+
28+ def kill (self ):
29+ command = f"kill{ self .pid } "
30+ self .ssh .exec_command (command )
31+
32+ def cmdline (self ):
33+ command = f"ps -p{ self .pid } -o cmd --no-headers"
34+ stdin ,stdout ,stderr = self .ssh .exec_command (command )
35+ cmdline = stdout .read ().decode ('utf-8' ).strip ()
36+ return cmdline .split ()
37+
38+
2339class RemoteOperations (OsOperations ):
2440def __init__ (self ,host = "127.0.0.1" ,hostname = 'localhost' ,port = None ,ssh_key = None ,username = None ):
2541super ().__init__ (username )
@@ -71,7 +87,7 @@ def exec_command(self, cmd: str, wait_exit=False, verbose=False, expect_error=Fa
7187self .ssh = self .ssh_connect ()
7288
7389if isinstance (cmd ,list ):
74- cmd = " " .join (cmd )
90+ cmd = ' ' .join (item . decode ( 'utf-8' ) if isinstance ( item , bytes ) else item for item in cmd )
7591if input :
7692stdin ,stdout ,stderr = self .ssh .exec_command (cmd )
7793stdin .write (input )
@@ -140,17 +156,6 @@ def is_executable(self, file):
140156is_exec = self .exec_command (f"test -x{ file } && echo OK" )
141157return is_exec == b"OK\n "
142158
143- def add_to_path (self ,new_path ):
144- pathsep = self .pathsep
145- # Check if the directory is already in PATH
146- path = self .environ ("PATH" )
147- if new_path not in path .split (pathsep ):
148- if self .remote :
149- self .exec_command (f"export PATH={ new_path } { pathsep } { path } " )
150- else :
151- os .environ ["PATH" ]= f"{ new_path } { pathsep } { path } "
152- return pathsep
153-
154159def set_env (self ,var_name :str ,var_val :str ):
155160"""
156161 Set the value of an environment variable.
@@ -243,9 +248,17 @@ def mkdtemp(self, prefix=None):
243248raise ExecUtilException ("Could not create temporary directory." )
244249
245250def mkstemp (self ,prefix = None ):
246- cmd = f"mktemp{ prefix } XXXXXX"
247- filename = self .exec_command (cmd ).strip ()
248- return filename
251+ if prefix :
252+ temp_dir = self .exec_command (f"mktemp{ prefix } XXXXX" ,encoding = 'utf-8' )
253+ else :
254+ temp_dir = self .exec_command ("mktemp" ,encoding = 'utf-8' )
255+
256+ if temp_dir :
257+ if not os .path .isabs (temp_dir ):
258+ temp_dir = os .path .join ('/home' ,self .username ,temp_dir .strip ())
259+ return temp_dir
260+ else :
261+ raise ExecUtilException ("Could not create temporary directory." )
249262
250263def copytree (self ,src ,dst ):
251264if not os .path .isabs (dst ):
@@ -291,7 +304,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
291304data = data .encode (encoding )
292305if isinstance (data ,list ):
293306# ensure each line ends with a newline
294- data = [s if s . endswith ( ' \n ' )else s + '\n ' for s in data ]
307+ data = [( s if isinstance ( s , str ) else s . decode ( 'utf-8' )). rstrip ( ' \n ' )+ '\n ' for s in data ]
295308tmp_file .writelines (data )
296309else :
297310tmp_file .write (data )
@@ -351,8 +364,8 @@ def isfile(self, remote_file):
351364
352365def isdir (self ,dirname ):
353366cmd = f"if [ -d{ dirname } ]; then echo True; else echo False; fi"
354- response = self .exec_command (cmd , encoding = 'utf-8' )
355- return response .strip ()== "True"
367+ response = self .exec_command (cmd )
368+ return response .strip ()== b "True"
356369
357370def remove_file (self ,filename ):
358371cmd = f"rm{ filename } "
@@ -366,16 +379,16 @@ def kill(self, pid, signal):
366379
367380def get_pid (self ):
368381# Get current process id
369- return self .exec_command ("echo $$" )
382+ return int ( self .exec_command ("echo $$" , encoding = 'utf-8' ) )
370383
371384def get_remote_children (self ,pid ):
372385command = f"pgrep -P{ pid } "
373386stdin ,stdout ,stderr = self .ssh .exec_command (command )
374387children = stdout .readlines ()
375- return [int (child_pid .strip ())for child_pid in children ]
388+ return [PsUtilProcessProxy ( self . ssh , int (child_pid .strip () ))for child_pid in children ]
376389
377390# Database control
378- def db_connect (self ,dbname ,user ,password = None ,host = "127.0.0.1" ,port = 5432 ):
391+ def db_connect (self ,dbname ,user ,password = None ,host = "127.0.0.1" ,port = 5432 , ssh_key = None ):
379392"""
380393 Connects to a PostgreSQL database on the remote system.
381394 Args:
@@ -389,19 +402,26 @@ def db_connect(self, dbname, user, password=None, host="127.0.0.1", port=5432):
389402 This function establishes a connection to a PostgreSQL database on the remote system using the specified
390403 parameters. It returns a connection object that can be used to interact with the database.
391404 """
392- with sshtunnel .open_tunnel (
393- (host ,22 ),# Remote server IP and SSH port
394- ssh_username = self .username ,
395- ssh_pkey = self .ssh_key ,
396- remote_bind_address = (host ,port ),# PostgreSQL server IP and PostgreSQL port
397- local_bind_address = ('localhost' ,port ),# Local machine IP and available port
398- ):
405+ tunnel = sshtunnel .open_tunnel (
406+ (host ,22 ),# Remote server IP and SSH port
407+ ssh_username = user or self .username ,
408+ ssh_pkey = ssh_key or self .ssh_key ,
409+ remote_bind_address = (host ,port ),# PostgreSQL server IP and PostgreSQL port
410+ local_bind_address = ('localhost' ,port )# Local machine IP and available port
411+ )
412+
413+ tunnel .start ()
414+
415+ try :
399416conn = pglib .connect (
400- host = host ,
401- port = port ,
417+ host = host ,# change to 'localhost' because we're connecting through a local ssh tunnel
418+ port = tunnel . local_bind_port , # use the local bind port set up by the tunnel
402419dbname = dbname ,
403- user = user ,
420+ user = user or self . username ,
404421password = password
405422 )
406423
407- return conn
424+ return conn
425+ except Exception as e :
426+ tunnel .stop ()
427+ raise e