1- import logging
21import os
2+ import socket
33import subprocess
44import tempfile
55import platform
6+ import time
67
78from ..utils import reserve_port
89
@@ -50,10 +51,10 @@ def __init__(self, conn_params: ConnectionParams):
5051self .ssh_key = conn_params .ssh_key
5152self .port = conn_params .port
5253self .ssh_cmd = ["-o StrictHostKeyChecking=no" ]
53- if self .port :
54- self .ssh_cmd += ["-p" ,self .port ]
5554if self .ssh_key :
5655self .ssh_cmd += ["-i" ,self .ssh_key ]
56+ if self .port :
57+ self .ssh_cmd += ["-p" ,self .port ]
5758self .remote = True
5859self .username = conn_params .username or self .get_user ()
5960self .tunnel_process = None
@@ -64,17 +65,36 @@ def __enter__(self):
6465def __exit__ (self ,exc_type ,exc_val ,exc_tb ):
6566self .close_ssh_tunnel ()
6667
68+ @staticmethod
69+ def is_port_open (host ,port ):
70+ with socket .socket (socket .AF_INET ,socket .SOCK_STREAM )as sock :
71+ sock .settimeout (1 )# Таймаут для попытки соединения
72+ try :
73+ sock .connect ((host ,port ))
74+ return True
75+ except socket .error :
76+ return False
77+
6778def establish_ssh_tunnel (self ,local_port ,remote_port ):
6879"""
6980 Establish an SSH tunnel from a local port to a remote PostgreSQL port.
7081 """
7182ssh_cmd = ['-N' ,'-L' ,f"{ local_port } :localhost:{ remote_port } " ]
7283self .tunnel_process = self .exec_command (ssh_cmd ,get_process = True ,timeout = 300 )
84+ timeout = 10
85+ start_time = time .time ()
86+ while time .time ()- start_time < timeout :
87+ if self .is_port_open ('localhost' ,local_port ):
88+ print ("SSH tunnel established." )
89+ return
90+ time .sleep (0.5 )
91+ raise Exception ("Failed to establish SSH tunnel within the timeout period." )
7392
7493def close_ssh_tunnel (self ):
75- if hasattr ( self , ' tunnel_process' ) :
94+ if self . tunnel_process :
7695self .tunnel_process .terminate ()
7796self .tunnel_process .wait ()
97+ print ("SSH tunnel closed." )
7898del self .tunnel_process
7999else :
80100print ("No active tunnel to close." )
@@ -240,9 +260,9 @@ def mkdtemp(self, prefix=None):
240260 - prefix (str): The prefix of the temporary directory name.
241261 """
242262if prefix :
243- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"mktemp -d{ prefix } XXXXX" ]
263+ command = ["ssh" + f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"mktemp -d{ prefix } XXXXX" ]
244264else :
245- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , "mktemp -d" ]
265+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ "mktemp -d" ]
246266
247267result = subprocess .run (command ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,text = True )
248268
@@ -285,7 +305,7 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
285305mode = "r+b" if binary else "r+"
286306
287307with tempfile .NamedTemporaryFile (mode = mode ,delete = False )as tmp_file :
288- # Because in scp we set up port using -P option instead -p
308+ # Because in scp we set up port using -P option
289309scp_ssh_cmd = ['-P' if x == '-p' else x for x in self .ssh_cmd ]
290310
291311if not truncate :
@@ -307,9 +327,9 @@ def write(self, filename, data, truncate=False, binary=False, read_and_write=Fal
307327tmp_file .flush ()
308328scp_cmd = ['scp' ]+ scp_ssh_cmd + [tmp_file .name ,f"{ self .username } @{ self .host } :{ filename } " ]
309329subprocess .run (scp_cmd ,check = True )
310- remote_directory = os .path .dirname (filename )
311330
312- mkdir_cmd = ['ssh' ]+ self .ssh_cmd + [f"{ self .username } @{ self .host } " ,f'mkdir -p{ remote_directory } ' ]
331+ remote_directory = os .path .dirname (filename )
332+ mkdir_cmd = ['ssh' ,f"{ self .username } @{ self .host } " ]+ self .ssh_cmd + [f"mkdir -p{ remote_directory } " ]
313333subprocess .run (mkdir_cmd ,check = True )
314334
315335os .remove (tmp_file .name )
@@ -374,7 +394,7 @@ def get_pid(self):
374394return int (self .exec_command ("echo $$" ,encoding = get_default_encoding ()))
375395
376396def get_process_children (self ,pid ):
377- command = ["ssh" ] + self . ssh_cmd + [ f"{ self .username } @{ self .host } " , f"pgrep -P{ pid } " ]
397+ command = ["ssh" , f"{ self .username } @{ self .host } " ] + self . ssh_cmd + [ f"pgrep -P{ pid } " ]
378398
379399result = subprocess .run (command ,stdout = subprocess .PIPE ,stderr = subprocess .PIPE ,text = True )
380400
@@ -389,15 +409,16 @@ def db_connect(self, dbname, user, password=None, host="localhost", port=5432):
389409"""
390410 Establish SSH tunnel and connect to a PostgreSQL database.
391411 """
392- self . establish_ssh_tunnel ( local_port = port , remote_port = self . conn_params . port )
393-
412+ local_port = reserve_port ( )
413+ self . establish_ssh_tunnel ( local_port = local_port , remote_port = port )
394414try :
395415conn = pglib .connect (
396416host = host ,
397- port = port ,
417+ port = local_port ,
398418database = dbname ,
399419user = user ,
400420password = password ,
421+ timeout = 10
401422 )
402423print ("Database connection established successfully." )
403424return conn