1
- import logging
2
1
import os
2
+ import socket
3
3
import subprocess
4
4
import tempfile
5
5
import platform
@@ -45,47 +45,44 @@ def __init__(self, conn_params: ConnectionParams):
45
45
self .conn_params = conn_params
46
46
self .host = conn_params .host
47
47
self .ssh_key = conn_params .ssh_key
48
+ self .port = conn_params .port
49
+ self .ssh_args = []
48
50
if self .ssh_key :
49
- self .ssh_cmd = ["-i" ,self .ssh_key ]
50
- else :
51
- self .ssh_cmd = []
51
+ self .ssh_args + = ["-i" ,self .ssh_key ]
52
+ if self . port :
53
+ self .ssh_args + = ["-p" , self . port ]
52
54
self .remote = True
53
55
self .username = conn_params .username
54
56
self .ssh_dest = f"{ self .username } @{ self .host } " if self .username else self .host
55
57
self .add_known_host (self .host )
56
58
self .tunnel_process = None
59
+ self .tunnel_port = None
57
60
58
61
def __enter__ (self ):
59
62
return self
60
63
61
64
def __exit__ (self ,exc_type ,exc_val ,exc_tb ):
62
65
self .close_ssh_tunnel ()
63
66
64
- def establish_ssh_tunnel (self ,local_port ,remote_port ):
65
- """
66
- Establish an SSH tunnel from a local port to a remote PostgreSQL port.
67
- """
68
- ssh_cmd = ['-N' ,'-L' ,f"{ local_port } :localhost:{ remote_port } " ]
69
- self .tunnel_process = self .exec_command (ssh_cmd ,get_process = True ,timeout = 300 )
67
+ @staticmethod
68
+ def is_port_open (host ,port ):
69
+ with socket .socket (socket .AF_INET ,socket .SOCK_STREAM )as sock :
70
+ sock .settimeout (1 )# Таймаут для попытки соединения
71
+ try :
72
+ sock .connect ((host ,port ))
73
+ return True
74
+ except socket .error :
75
+ return False
70
76
71
77
def close_ssh_tunnel (self ):
72
- if hasattr ( self , ' tunnel_process' ) :
78
+ if self . tunnel_process :
73
79
self .tunnel_process .terminate ()
74
80
self .tunnel_process .wait ()
81
+ print ("SSH tunnel closed." )
75
82
del self .tunnel_process
76
83
else :
77
84
print ("No active tunnel to close." )
78
85
79
- def add_known_host (self ,host ):
80
- known_hosts_path = os .path .expanduser ("~/.ssh/known_hosts" )
81
- cmd = 'ssh-keyscan -H %s >> %s' % (host ,known_hosts_path )
82
-
83
- try :
84
- subprocess .check_call (cmd ,shell = True )
85
- logging .info ("Successfully added %s to known_hosts." % host )
86
- except subprocess .CalledProcessError as e :
87
- raise Exception ("Failed to add %s to known_hosts. Error: %s" % (host ,str (e )))
88
-
89
86
def exec_command (self ,cmd ,wait_exit = False ,verbose = False ,expect_error = False ,
90
87
encoding = None ,shell = True ,text = False ,input = None ,stdin = None ,stdout = None ,
91
88
stderr = None ,get_process = None ,timeout = None ):
@@ -96,9 +93,9 @@ def exec_command(self, cmd, wait_exit=False, verbose=False, expect_error=False,
96
93
"""
97
94
ssh_cmd = []
98
95
if isinstance (cmd ,str ):
99
- ssh_cmd = ['ssh' , self . ssh_dest ]+ self .ssh_cmd + [cmd ]
96
+ ssh_cmd = ['ssh' ]+ self .ssh_args + [self . ssh_dest , cmd ]
100
97
elif isinstance (cmd ,list ):
101
- ssh_cmd = ['ssh' , self .ssh_dest ] + self .ssh_cmd + cmd
98
+ ssh_cmd = ['ssh' ] + self .ssh_args + [ self .ssh_dest ] + cmd
102
99
process = subprocess .Popen (ssh_cmd ,stdin = subprocess .PIPE ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE )
103
100
if get_process :
104
101
return process
@@ -243,9 +240,9 @@ def mkdtemp(self, prefix=None):
243
240
- prefix (str): The prefix of the temporary directory name.
244
241
"""
245
242
if prefix :
246
- command = ["ssh" ]+ self .ssh_cmd + [self .ssh_dest ,f"mktemp -d{ prefix } XXXXX" ]
243
+ command = ["ssh" ]+ self .ssh_args + [self .ssh_dest ,f"mktemp -d{ prefix } XXXXX" ]
247
244
else :
248
- command = ["ssh" ]+ self .ssh_cmd + [self .ssh_dest ,"mktemp -d" ]
245
+ command = ["ssh" ]+ self .ssh_args + [self .ssh_dest ,"mktemp -d" ]
249
246
250
247
result = subprocess .run (command ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,text = True )
251
248
@@ -288,8 +285,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
288
285
mode = "r+b" if binary else "r+"
289
286
290
287
with tempfile .NamedTemporaryFile (mode = mode ,delete = False )as tmp_file :
288
+ # Because in scp we set up port using -P option
289
+ scp_args = ['-P' if x == '-p' else x for x in self .ssh_args ]
290
+
291
291
if not truncate :
292
- scp_cmd = ['scp' ]+ self . ssh_cmd + [f"{ self .ssh_dest } :{ filename } " ,tmp_file .name ]
292
+ scp_cmd = ['scp' ]+ scp_args + [f"{ self .ssh_dest } :{ filename } " ,tmp_file .name ]
293
293
subprocess .run (scp_cmd ,check = False )# The file might not exist yet
294
294
tmp_file .seek (0 ,os .SEEK_END )
295
295
@@ -305,11 +305,11 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
305
305
tmp_file .write (data )
306
306
307
307
tmp_file .flush ()
308
- scp_cmd = ['scp' ]+ self . ssh_cmd + [tmp_file .name ,f"{ self .ssh_dest } :{ filename } " ]
308
+ scp_cmd = ['scp' ]+ scp_args + [tmp_file .name ,f"{ self .ssh_dest } :{ filename } " ]
309
309
subprocess .run (scp_cmd ,check = True )
310
310
311
311
remote_directory = os .path .dirname (filename )
312
- mkdir_cmd = ['ssh' ]+ self .ssh_cmd + [self .ssh_dest ,f"mkdir -p{ remote_directory } " ]
312
+ mkdir_cmd = ['ssh' ]+ self .ssh_args + [self .ssh_dest ,f"mkdir -p{ remote_directory } " ]
313
313
subprocess .run (mkdir_cmd ,check = True )
314
314
315
315
os .remove (tmp_file .name )
@@ -374,7 +374,7 @@ def get_pid(self):
374
374
return int (self .exec_command ("echo $$" ,encoding = get_default_encoding ()))
375
375
376
376
def get_process_children (self ,pid ):
377
- command = ["ssh" ]+ self .ssh_cmd + [self .ssh_dest ,f"pgrep -P{ pid } " ]
377
+ command = ["ssh" ]+ self .ssh_args + [self .ssh_dest ,f"pgrep -P{ pid } " ]
378
378
379
379
result = subprocess .run (command ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,text = True )
380
380
@@ -386,18 +386,11 @@ def get_process_children(self, pid):
386
386
387
387
# Database control
388
388
def db_connect (self ,dbname ,user ,password = None ,host = "localhost" ,port = 5432 ):
389
- """
390
- Established SSH tunnel and Connects to a PostgreSQL
391
- """
392
- self .establish_ssh_tunnel (local_port = port ,remote_port = 5432 )
393
- try :
394
- conn = pglib .connect (
395
- host = host ,
396
- port = port ,
397
- database = dbname ,
398
- user = user ,
399
- password = password ,
400
- )
401
- return conn
402
- except Exception as e :
403
- raise Exception (f"Could not connect to the database. Error:{ e } " )
389
+ conn = pglib .connect (
390
+ host = host ,
391
+ port = port ,
392
+ database = dbname ,
393
+ user = user ,
394
+ password = password ,
395
+ )
396
+ return conn