1
- import logging
2
1
import os
2
+ import socket
3
3
import subprocess
4
4
import tempfile
5
5
import platform
6
+ import time
6
7
7
8
from ..utils import reserve_port
8
9
@@ -50,10 +51,10 @@ def __init__(self, conn_params: ConnectionParams):
50
51
self .ssh_key = conn_params .ssh_key
51
52
self .port = conn_params .port
52
53
self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
53
- if self .port :
54
- self .ssh_cmd += ["-p" ,self .port ]
55
54
if self .ssh_key :
56
55
self .ssh_cmd += ["-i" ,self .ssh_key ]
56
+ if self .port :
57
+ self .ssh_cmd += ["-p" ,self .port ]
57
58
self .remote = True
58
59
self .username = conn_params .username or self .get_user ()
59
60
self .tunnel_process = None
@@ -64,17 +65,36 @@ def __enter__(self):
64
65
def __exit__ (self ,exc_type ,exc_val ,exc_tb ):
65
66
self .close_ssh_tunnel ()
66
67
68
+ @staticmethod
69
+ def is_port_open (host ,port ):
70
+ with socket .socket (socket .AF_INET ,socket .SOCK_STREAM )as sock :
71
+ sock .settimeout (1 )# Таймаут для попытки соединения
72
+ try :
73
+ sock .connect ((host ,port ))
74
+ return True
75
+ except socket .error :
76
+ return False
77
+
67
78
def establish_ssh_tunnel (self ,local_port ,remote_port ):
68
79
"""
69
80
Establish an SSH tunnel from a local port to a remote PostgreSQL port.
70
81
"""
71
82
ssh_cmd = ['-N' ,'-L' ,f"{ local_port } :localhost:{ remote_port } " ]
72
83
self .tunnel_process = self .exec_command (ssh_cmd ,get_process = True ,timeout = 300 )
84
+ timeout = 10
85
+ start_time = time .time ()
86
+ while time .time ()- start_time < timeout :
87
+ if self .is_port_open ('localhost' ,local_port ):
88
+ print ("SSH tunnel established." )
89
+ return
90
+ time .sleep (0.5 )
91
+ raise Exception ("Failed to establish SSH tunnel within the timeout period." )
73
92
74
93
def close_ssh_tunnel (self ):
75
- if hasattr ( self , ' tunnel_process' ) :
94
+ if self . tunnel_process :
76
95
self .tunnel_process .terminate ()
77
96
self .tunnel_process .wait ()
97
+ print ("SSH tunnel closed." )
78
98
del self .tunnel_process
79
99
else :
80
100
print ("No active tunnel to close." )
@@ -240,9 +260,9 @@ def mkdtemp(self, prefix=None):
240
260
- prefix (str): The prefix of the temporary directory name.
241
261
"""
242
262
if prefix :
243
- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"mktemp -d{ prefix } XXXXX" ]
263
+ command = ["ssh" + f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"mktemp -d{ prefix } XXXXX" ]
244
264
else :
245
- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , "mktemp -d" ]
265
+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ "mktemp -d" ]
246
266
247
267
result = subprocess .run (command ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,text = True )
248
268
@@ -285,7 +305,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
285
305
mode = "r+b" if binary else "r+"
286
306
287
307
with tempfile .NamedTemporaryFile (mode = mode ,delete = False )as tmp_file :
288
- # Because in scp we set up port using -P option instead -p
308
+ # Because in scp we set up port using -P option
289
309
scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
290
310
291
311
if not truncate :
@@ -307,9 +327,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
307
327
tmp_file .flush ()
308
328
scp_cmd = ['scp' ]+ scp_ssh_cmd + [tmp_file .name ,f"{ self .username } @{ self .host } :{ filename } " ]
309
329
subprocess .run (scp_cmd ,check = True )
310
- remote_directory = os .path .dirname (filename )
311
330
312
- mkdir_cmd = ['ssh' ]+ self .ssh_cmd + [f"{ self .username } @{ self .host } " ,f'mkdir -p{ remote_directory } ' ]
331
+ remote_directory = os .path .dirname (filename )
332
+ mkdir_cmd = ['ssh' ,f"{ self .username } @{ self .host } " ]+ self .ssh_cmd + [f"mkdir -p{ remote_directory } " ]
313
333
subprocess .run (mkdir_cmd ,check = True )
314
334
315
335
os .remove (tmp_file .name )
@@ -374,7 +394,7 @@ def get_pid(self):
374
394
return int (self .exec_command ("echo $$" ,encoding = get_default_encoding ()))
375
395
376
396
def get_process_children (self ,pid ):
377
- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"pgrep -P{ pid } " ]
397
+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"pgrep -P{ pid } " ]
378
398
379
399
result = subprocess .run (command ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,text = True )
380
400
@@ -389,15 +409,16 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
389
409
"""
390
410
Establish SSH tunnel and connect to a PostgreSQL database.
391
411
"""
392
- self . establish_ssh_tunnel ( local_port = port , remote_port = self . conn_params . port )
393
-
412
+ local_port = reserve_port ( )
413
+ self . establish_ssh_tunnel ( local_port = local_port , remote_port = port )
394
414
try :
395
415
conn = pglib .connect (
396
416
host = host ,
397
- port = port ,
417
+ port = local_port ,
398
418
database = dbname ,
399
419
user = user ,
400
420
password = password ,
421
+ timeout = 10
401
422
)
402
423
print ("Database connection established successfully." )
403
424
return conn