# Wrapper module for _ssl, providing some additional facilities
# implemented in Python.  Written by Bill Janssen.
"""This module provides some more Pythonic support for SSL.

Object types:

  SSLSocket -- subtype of socket.socket which does SSL over the socket

Exceptions:

  SSLError -- exception raised for I/O errors

Functions:

  cert_time_to_seconds -- convert time string used for certificate
                          notBefore and notAfter functions to integer
                          seconds past the Epoch (the time values
                          returned from time.time())

  fetch_server_certificate (HOST, PORT) -- fetch the certificate provided
                          by the server running on HOST at port PORT.  No
                          validation of the certificate is performed.
Integer constants:

SSL_ERROR_ZERO_RETURN
SSL_ERROR_WANT_READ
SSL_ERROR_WANT_WRITE
SSL_ERROR_WANT_X509_LOOKUP
SSL_ERROR_SYSCALL
SSL_ERROR_SSL
SSL_ERROR_WANT_CONNECT

SSL_ERROR_EOF
SSL_ERROR_INVALID_ERROR_CODE
The following group defines certificate requirements that one side is
allowing/requiring from the other side:

CERT_NONE - no certificates from the other side are required (or will
            be looked at if provided)
CERT_OPTIONAL - certificates are not required, but if provided will be
                validated, and if validation fails, the connection will
                also fail
CERT_REQUIRED - certificates are required, and will be validated, and
                if validation fails, the connection will also fail
The following constants identify various SSL protocol variants:

PROTOCOL_SSLv2
PROTOCOL_SSLv3
PROTOCOL_SSLv23
PROTOCOL_TLS
PROTOCOL_TLS_CLIENT
PROTOCOL_TLS_SERVER
PROTOCOL_TLSv1
PROTOCOL_TLSv1_1
PROTOCOL_TLSv1_2
The following constants identify various SSL alert message descriptions as per
http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-6

ALERT_DESCRIPTION_CLOSE_NOTIFY
ALERT_DESCRIPTION_UNEXPECTED_MESSAGE
ALERT_DESCRIPTION_BAD_RECORD_MAC
ALERT_DESCRIPTION_RECORD_OVERFLOW
ALERT_DESCRIPTION_DECOMPRESSION_FAILURE
ALERT_DESCRIPTION_HANDSHAKE_FAILURE
ALERT_DESCRIPTION_BAD_CERTIFICATE
ALERT_DESCRIPTION_UNSUPPORTED_CERTIFICATE
ALERT_DESCRIPTION_CERTIFICATE_REVOKED
ALERT_DESCRIPTION_CERTIFICATE_EXPIRED
ALERT_DESCRIPTION_CERTIFICATE_UNKNOWN
ALERT_DESCRIPTION_ILLEGAL_PARAMETER
ALERT_DESCRIPTION_UNKNOWN_CA
ALERT_DESCRIPTION_ACCESS_DENIED
ALERT_DESCRIPTION_DECODE_ERROR
ALERT_DESCRIPTION_DECRYPT_ERROR
ALERT_DESCRIPTION_PROTOCOL_VERSION
ALERT_DESCRIPTION_INSUFFICIENT_SECURITY
ALERT_DESCRIPTION_INTERNAL_ERROR
ALERT_DESCRIPTION_USER_CANCELLED
ALERT_DESCRIPTION_NO_RENEGOTIATION
ALERT_DESCRIPTION_UNSUPPORTED_EXTENSION
ALERT_DESCRIPTION_CERTIFICATE_UNOBTAINABLE
ALERT_DESCRIPTION_UNRECOGNIZED_NAME
ALERT_DESCRIPTION_BAD_CERTIFICATE_STATUS_RESPONSE
ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
"""
import sys | |
import os | |
from collectionsimport namedtuple | |
from enumimport Enumas _Enum, IntEnumas _IntEnum, IntFlagas _IntFlag | |
import _ssl# if we can't import it, let the error propagate | |
from _sslimportOPENSSL_VERSION_NUMBER,OPENSSL_VERSION_INFO,OPENSSL_VERSION | |
from _sslimport _SSLContext, MemoryBIO, SSLSession | |
from _sslimport ( | |
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError, | |
SSLSyscallError, SSLEOFError, SSLCertVerificationError | |
) | |
from _sslimport txt2objas _txt2obj, nid2objas _nid2obj | |
from _sslimportRAND_status,RAND_add,RAND_bytes,RAND_pseudo_bytes | |
try: | |
from _sslimportRAND_egd | |
exceptImportError: | |
# LibreSSL does not provide RAND_egd | |
pass | |
from _sslimport ( | |
HAS_SNI,HAS_ECDH,HAS_NPN,HAS_ALPN,HAS_SSLv2,HAS_SSLv3,HAS_TLSv1, | |
HAS_TLSv1_1,HAS_TLSv1_2,HAS_TLSv1_3 | |
) | |
from _sslimport_DEFAULT_CIPHERS,_OPENSSL_API_VERSION | |
# Lift groups of integer constants exported by _ssl into enum classes.
# _convert_ also publishes every member as a module-level name; the rest of
# this module relies on that (e.g. CERT_NONE, SSL_ERROR_EOF, PROTOCOL_TLS are
# used unqualified).  The table is processed in order, one enum per row.
for _kind, _enum_name, _accepts in (
    (_IntEnum, '_SSLMethod',
     # PROTOCOL_SSLv23 is an alias of PROTOCOL_TLS and is attached below.
     lambda name: name.startswith('PROTOCOL_') and name != 'PROTOCOL_SSLv23'),
    (_IntFlag, 'Options',
     lambda name: name.startswith('OP_')),
    (_IntEnum, 'AlertDescription',
     lambda name: name.startswith('ALERT_DESCRIPTION_')),
    (_IntEnum, 'SSLErrorNumber',
     lambda name: name.startswith('SSL_ERROR_')),
    (_IntFlag, 'VerifyFlags',
     lambda name: name.startswith('VERIFY_')),
    (_IntEnum, 'VerifyMode',
     lambda name: name.startswith('CERT_')),
):
    _kind._convert_(_enum_name, __name__, _accepts, source=_ssl)
del _kind, _enum_name, _accepts

# Keep the legacy name pointing at the PROTOCOL_TLS member, both as a module
# constant and as an attribute of the enum class.
PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_SSLv23 = _SSLMethod.PROTOCOL_TLS

# Reverse lookup: protocol member -> member name.
_PROTOCOL_NAMES = {
    member: member_name
    for member_name, member in _SSLMethod.__members__.items()
}

# None when the underlying library does not provide SSLv2.
_SSLv2_IF_EXISTS = getattr(_SSLMethod, 'PROTOCOL_SSLv2', None)
class TLSVersion(_IntEnum):
    """Protocol versions usable as SSLContext minimum/maximum bounds.

    Every value is taken verbatim from the matching ``_ssl.PROTO_*``
    constant.  MINIMUM_SUPPORTED and MAXIMUM_SUPPORTED are the two sentinel
    values _ssl exports as PROTO_MINIMUM_SUPPORTED / PROTO_MAXIMUM_SUPPORTED.
    """
    # Lower sentinel first, then concrete versions oldest to newest, then the
    # upper sentinel -- iteration order follows this layout.
    MINIMUM_SUPPORTED = _ssl.PROTO_MINIMUM_SUPPORTED
    SSLv3 = _ssl.PROTO_SSLv3
    TLSv1 = _ssl.PROTO_TLSv1
    TLSv1_1 = _ssl.PROTO_TLSv1_1
    TLSv1_2 = _ssl.PROTO_TLSv1_2
    TLSv1_3 = _ssl.PROTO_TLSv1_3
    MAXIMUM_SUPPORTED = _ssl.PROTO_MAXIMUM_SUPPORTED
if sys.platform=="win32": | |
from _sslimport enum_certificates, enum_crls | |
from socketimport socket,AF_INET,SOCK_STREAM, create_connection | |
from socketimportSOL_SOCKET,SO_TYPE | |
import socketas _socket | |
import base64# for DER-to-PEM translation | |
import errno | |
import warnings | |
# Public alias retained for backward compatibility; it is simply OSError.
socket_error = OSError  # keep that public name in module namespace

# Channel-binding types advertised by this module (get_channel_binding()
# defaults to 'tls-unique').
CHANNEL_BINDING_TYPES = ['tls-unique']

# True when _ssl exports HOSTFLAG_NEVER_CHECK_SUBJECT, which backs
# SSLContext.hostname_checks_common_name.
HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')

# Server-side cipher string; currently identical to the library default.
_RESTRICTED_SERVER_CIPHERS = _DEFAULT_CIPHERS

# Alias: hostname/certificate verification failures use this exception type.
CertificateError = SSLCertVerificationError
def_dnsname_match(dn,hostname): | |
"""Matching according to RFC 6125, section 6.4.3 | |
- Hostnames are compared lower case. | |
- For IDNA, both dn and hostname must be encoded as IDN A-label (ACE). | |
- Partial wildcards like 'www*.example.org', multiple wildcards, sole | |
wildcard or wildcards in labels other then the left-most label are not | |
supported and a CertificateError is raised. | |
- A wildcard must match at least one character. | |
""" | |
ifnot dn: | |
returnFalse | |
wildcards= dn.count('*') | |
# speed up common case w/o wildcards | |
ifnot wildcards: | |
return dn.lower()== hostname.lower() | |
if wildcards>1: | |
raise CertificateError( | |
"too many wildcards in certificate DNS name:{!r}.".format(dn)) | |
dn_leftmost, sep, dn_remainder= dn.partition('.') | |
if'*'in dn_remainder: | |
# Only match wildcard in leftmost segment. | |
raise CertificateError( | |
"wildcard can only be present in the leftmost label:" | |
"{!r}.".format(dn)) | |
ifnot sep: | |
# no right side | |
raise CertificateError( | |
"sole wildcard without additional labels are not support:" | |
"{!r}.".format(dn)) | |
if dn_leftmost!='*': | |
# no partial wildcard matching | |
raise CertificateError( | |
"partial wildcards in leftmost label are not supported:" | |
"{!r}.".format(dn)) | |
hostname_leftmost, sep, hostname_remainder= hostname.partition('.') | |
ifnot hostname_leftmostornot sep: | |
# wildcard must match at least one char | |
returnFalse | |
return dn_remainder.lower()== hostname_remainder.lower() | |
def _inet_paton(ipname):
    """Try to convert an IP address to packed binary form.

    Supports IPv4 addresses on all platforms and IPv6 on platforms with IPv6
    support.

    Returns the packed address bytes.  Raises ValueError when *ipname* is
    not an IP address in canonical textual form.
    """
    # inet_aton() also accepts strings like '1' or '127.1', and some
    # implementations accept trailing data like '127.0.0.1 whatever'.
    # Only accept input that round-trips to exactly the same text.
    try:
        addr = _socket.inet_aton(ipname)
    except OSError:
        # not an IPv4 address
        pass
    else:
        if _socket.inet_ntoa(addr) == ipname:
            # only accept injective ipnames
            return addr
        # refuse short IPv4 notation and additional trailing data
        raise ValueError(
            "{!r} is not a quad-dotted IPv4 address.".format(ipname))

    try:
        return _socket.inet_pton(_socket.AF_INET6, ipname)
    except OSError:
        raise ValueError("{!r} is neither an IPv4 nor an IPv6 "
                         "address.".format(ipname))
    except AttributeError:
        # AF_INET6 not available
        pass

    raise ValueError("{!r} is not an IPv4 address.".format(ipname))
def _ipaddress_match(ipname, host_ip):
    """Compare a certificate's IP address entry with a packed host address.

    *ipname* is the textual value of a subjectAltName ``IP Address`` entry;
    *host_ip* is the packed form that match_hostname() obtained from
    _inet_paton().  RFC 6125 explicitly doesn't define an algorithm for
    this (section 1.7.2 - "Out of Scope"), so the comparison is exact.
    """
    # rstrip(): OpenSSL may add a trailing newline to a subjectAltName's IP
    # address, which would otherwise make the text unparseable.
    return _inet_paton(ipname.rstrip()) == host_ip
def match_hostname(cert, hostname):
    """Verify that *cert* (in decoded format as returned by
    SSLSocket.getpeercert()) matches the *hostname*.  RFC 2818 and RFC 6125
    rules are followed.

    The function matches IP addresses rather than dNSNames if hostname is a
    valid ipaddress string. IPv4 addresses are supported on all platforms.
    IPv6 addresses are supported on platforms with IPv6 support (AF_INET6
    and inet_pton).

    Raises ValueError if *cert* is empty or None.  CertificateError is
    raised on failure.  On success, the function returns nothing.
    """
    if not cert:
        raise ValueError("empty or no certificate, match_hostname needs a "
                         "SSL socket or SSL context with either "
                         "CERT_OPTIONAL or CERT_REQUIRED")
    try:
        host_ip = _inet_paton(hostname)
    except ValueError:
        # Not an IP address (common case)
        host_ip = None

    # Every candidate name seen is collected for the error message.
    dnsnames = []
    san = cert.get('subjectAltName', ())
    for key, value in san:
        if key == 'DNS':
            if host_ip is None and _dnsname_match(value, hostname):
                return
            dnsnames.append(value)
        elif key == 'IP Address':
            if host_ip is not None and _ipaddress_match(value, host_ip):
                return
            dnsnames.append(value)

    if not dnsnames:
        # The subject is only checked when subjectAltName has no DNS or
        # IP Address entries.
        for sub in cert.get('subject', ()):
            for key, value in sub:
                # XXX according to RFC 2818, the most specific Common Name
                # must be used.
                if key == 'commonName':
                    if _dnsname_match(value, hostname):
                        return
                    dnsnames.append(value)

    if len(dnsnames) > 1:
        raise CertificateError("hostname %r "
                               "doesn't match either of %s"
                               % (hostname, ', '.join(map(repr, dnsnames))))
    elif len(dnsnames) == 1:
        raise CertificateError("hostname %r "
                               "doesn't match %r"
                               % (hostname, dnsnames[0]))
    else:
        raise CertificateError("no appropriate commonName or "
                               "subjectAltName fields were found")
# Six fields: the two resolved paths, then the four raw values reported by
# _ssl.get_default_verify_paths().  Note the trailing space before the
# implicit string concatenation -- without it the last two field names fuse.
DefaultVerifyPaths = namedtuple("DefaultVerifyPaths",
    "cafile capath openssl_cafile_env openssl_cafile openssl_capath_env "
    "openssl_capath")


def get_default_verify_paths():
    """Return paths to default cafile and capath.

    The environment variables named by OpenSSL override the compiled-in
    defaults.  ``cafile``/``capath`` are None unless the resulting path is an
    existing file/directory; the remaining four fields are the raw
    (env var name, path, env var name, path) values from
    _ssl.get_default_verify_paths().
    """
    parts = _ssl.get_default_verify_paths()

    # environment vars shadow paths
    cafile = os.environ.get(parts[0], parts[1])
    capath = os.environ.get(parts[2], parts[3])

    return DefaultVerifyPaths(cafile if os.path.isfile(cafile) else None,
                              capath if os.path.isdir(capath) else None,
                              *parts)
class _ASN1Object(namedtuple("_ASN1Object", "nid shortname longname oid")):
    """ASN.1 object identifier lookup.

    Each instance is a (nid, shortname, longname, oid) tuple whose fields are
    filled in by the _ssl lookup helpers.
    """
    __slots__ = ()

    def __new__(cls, oid):
        # name=False, unlike fromname() below -- presumably restricts the
        # lookup to OID text rather than short/long names.
        fields = _txt2obj(oid, name=False)
        return super().__new__(cls, *fields)

    @classmethod
    def fromnid(cls, nid):
        """Build an _ASN1Object from an OpenSSL numeric identifier (NID)."""
        fields = _nid2obj(nid)
        return super().__new__(cls, *fields)

    @classmethod
    def fromname(cls, name):
        """Build an _ASN1Object from a short name, long name or OID."""
        fields = _txt2obj(name, name=True)
        return super().__new__(cls, *fields)
class Purpose(_ASN1Object, _Enum):
    """Purposes accepted by SSLContext.load_default_certs() and friends.

    Members are _ASN1Object tuples built from the dotted OIDs below
    (X509v3 Extended Key Usage objects).
    """
    # Extended Key Usage OID for server authentication.
    SERVER_AUTH = '1.3.6.1.5.5.7.3.1'
    # Extended Key Usage OID for client authentication.
    CLIENT_AUTH = '1.3.6.1.5.5.7.3.2'
class SSLContext(_SSLContext):
    """An SSLContext holds various SSL-related configuration options and
    data, such as certificates and possibly a private key."""
    # Windows system stores consulted by load_default_certs() on win32.
    _windows_cert_stores = ("CA", "ROOT")

    sslsocket_class = None  # SSLSocket is assigned later.
    sslobject_class = None  # SSLObject is assigned later.

    def __new__(cls, protocol=PROTOCOL_TLS, *args, **kwargs):
        # Only *protocol* reaches the C implementation; any extra positional
        # or keyword arguments are accepted and ignored.
        self = _SSLContext.__new__(cls, protocol)
        return self

    def _encode_hostname(self, hostname):
        """Return *hostname* as an ASCII str.

        None passes through; a str is IDNA-encoded; anything else is decoded
        as ASCII (so bytes input must already be ASCII).
        """
        if hostname is None:
            return None
        elif isinstance(hostname, str):
            return hostname.encode('idna').decode('ascii')
        else:
            return hostname.decode('ascii')

    def wrap_socket(self, sock, server_side=False,
                    do_handshake_on_connect=True,
                    suppress_ragged_eofs=True,
                    server_hostname=None, session=None):
        """Wrap *sock* in an instance of ``self.sslsocket_class``."""
        # SSLSocket class handles server_hostname encoding before it calls
        # ctx._wrap_socket()
        return self.sslsocket_class._create(
            sock=sock,
            server_side=server_side,
            do_handshake_on_connect=do_handshake_on_connect,
            suppress_ragged_eofs=suppress_ragged_eofs,
            server_hostname=server_hostname,
            context=self,
            session=session
        )

    def wrap_bio(self, incoming, outgoing, server_side=False,
                 server_hostname=None, session=None):
        """Create an ``self.sslobject_class`` instance over two BIOs."""
        # Need to encode server_hostname here because _wrap_bio() can only
        # handle ASCII str.
        return self.sslobject_class._create(
            incoming, outgoing, server_side=server_side,
            server_hostname=self._encode_hostname(server_hostname),
            session=session, context=self,
        )

    @staticmethod
    def _encode_protocol_list(protocols, label):
        """Serialize protocol names into length-prefixed wire format.

        Each name is ASCII-encoded and emitted as one length byte followed by
        the name's bytes.  Raises SSLError if an encoded name is empty or
        longer than 255 bytes; *label* (e.g. 'NPN', 'ALPN') prefixes the
        error message.
        """
        protos = bytearray()
        for protocol in protocols:
            b = bytes(protocol, 'ascii')
            if len(b) == 0 or len(b) > 255:
                raise SSLError(f'{label} protocols must be 1 to 255 in length')
            protos.append(len(b))
            protos.extend(b)
        return protos

    def set_npn_protocols(self, npn_protocols):
        """Set the protocols offered via NPN, in order of preference."""
        self._set_npn_protocols(
            self._encode_protocol_list(npn_protocols, 'NPN'))

    def set_servername_callback(self, server_name_callback):
        """Install *server_name_callback* as the SNI callback (None clears it).

        The callback receives the servername after it has been normalized by
        _encode_hostname().  Raises TypeError for a non-callable argument.
        """
        if server_name_callback is None:
            self.sni_callback = None
        else:
            if not callable(server_name_callback):
                raise TypeError("not a callable object")

            def shim_cb(sslobj, servername, sslctx):
                servername = self._encode_hostname(servername)
                return server_name_callback(sslobj, servername, sslctx)

            self.sni_callback = shim_cb

    def set_alpn_protocols(self, alpn_protocols):
        """Set the protocols offered via ALPN, in order of preference."""
        self._set_alpn_protocols(
            self._encode_protocol_list(alpn_protocols, 'ALPN'))

    def _load_windows_store_certs(self, storename, purpose):
        """Load certificates from Windows store *storename* for *purpose*.

        A certificate is used when its trust value is True or contains
        ``purpose.oid``.  A PermissionError while enumerating the store is
        downgraded to a warning.  Returns the concatenated certificate bytes
        that were loaded (possibly empty).
        """
        certs = bytearray()
        try:
            for cert, encoding, trust in enum_certificates(storename):
                # CA certs are never PKCS#7 encoded
                if encoding == "x509_asn":
                    if trust is True or purpose.oid in trust:
                        certs.extend(cert)
        except PermissionError:
            warnings.warn("unable to enumerate Windows certificate store")
        if certs:
            self.load_verify_locations(cadata=certs)
        return certs

    def load_default_certs(self, purpose=Purpose.SERVER_AUTH):
        """Load the platform's default CA certificates for *purpose*.

        Raises TypeError if *purpose* is not an _ASN1Object.
        """
        if not isinstance(purpose, _ASN1Object):
            raise TypeError(purpose)
        if sys.platform == "win32":
            for storename in self._windows_cert_stores:
                self._load_windows_store_certs(storename, purpose)
        self.set_default_verify_paths()

    # The properties below wrap the C-level values in enum types.  A super()
    # proxy cannot be assigned through, so setters invoke the base class's
    # descriptor __set__ explicitly.
    if hasattr(_SSLContext, 'minimum_version'):
        @property
        def minimum_version(self):
            return TLSVersion(super().minimum_version)

        @minimum_version.setter
        def minimum_version(self, value):
            if value == TLSVersion.SSLv3:
                # Allowing SSLv3 as the floor also requires lifting the
                # OP_NO_SSLv3 option.
                self.options &= ~Options.OP_NO_SSLv3
            super(SSLContext, SSLContext).minimum_version.__set__(self, value)

        @property
        def maximum_version(self):
            return TLSVersion(super().maximum_version)

        @maximum_version.setter
        def maximum_version(self, value):
            super(SSLContext, SSLContext).maximum_version.__set__(self, value)

    @property
    def options(self):
        return Options(super().options)

    @options.setter
    def options(self, value):
        super(SSLContext, SSLContext).options.__set__(self, value)

    if hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT'):
        @property
        def hostname_checks_common_name(self):
            ncs = self._host_flags & _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            return ncs != _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT

        @hostname_checks_common_name.setter
        def hostname_checks_common_name(self, value):
            if value:
                self._host_flags &= ~_ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
            else:
                self._host_flags |= _ssl.HOSTFLAG_NEVER_CHECK_SUBJECT
    else:
        @property
        def hostname_checks_common_name(self):
            return True

    @property
    def protocol(self):
        return _SSLMethod(super().protocol)

    @property
    def verify_flags(self):
        return VerifyFlags(super().verify_flags)

    @verify_flags.setter
    def verify_flags(self, value):
        super(SSLContext, SSLContext).verify_flags.__set__(self, value)

    @property
    def verify_mode(self):
        value = super().verify_mode
        try:
            return VerifyMode(value)
        except ValueError:
            # Unknown values are returned as the raw integer.
            return value

    @verify_mode.setter
    def verify_mode(self, value):
        super(SSLContext, SSLContext).verify_mode.__set__(self, value)
def create_default_context(purpose=Purpose.SERVER_AUTH, *, cafile=None,
                           capath=None, cadata=None):
    """Build an SSLContext configured with this module's default settings.

    For Purpose.SERVER_AUTH (the default) the context requires and verifies
    the peer certificate and checks the host name.  CA material comes from
    *cafile*/*capath*/*cadata* when any is given; otherwise, if verification
    is enabled, the system defaults for *purpose* are loaded.  Raises
    TypeError if *purpose* is not an _ASN1Object.

    NOTE: The protocol and settings may change anytime without prior
    deprecation. The values represent a fair balance between maximum
    compatibility and security.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
    # by default.
    context = SSLContext(PROTOCOL_TLS)

    if purpose == Purpose.SERVER_AUTH:
        # Client side: insist on a valid certificate matching the host name.
        context.verify_mode = CERT_REQUIRED
        context.check_hostname = True

    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # Verification is on but no CA material was supplied: fall back to
        # the system root certificates for this purpose (may fail silently).
        context.load_default_certs(purpose)
    return context
def _create_unverified_context(protocol=PROTOCOL_TLS, *, cert_reqs=CERT_NONE,
                               check_hostname=False, purpose=Purpose.SERVER_AUTH,
                               certfile=None, keyfile=None,
                               cafile=None, capath=None, cadata=None):
    """Create a SSLContext object for Python stdlib modules.

    All Python stdlib modules shall use this function to create SSLContext
    objects in order to keep common settings in one place.  The
    configuration is less restrictive than create_default_context()'s to
    increase backward compatibility.

    Raises TypeError if *purpose* is not an _ASN1Object, and ValueError if
    *keyfile* is given without *certfile*.
    """
    if not isinstance(purpose, _ASN1Object):
        raise TypeError(purpose)

    # SSLContext sets OP_NO_SSLv2, OP_NO_SSLv3, OP_NO_COMPRESSION,
    # OP_CIPHER_SERVER_PREFERENCE, OP_SINGLE_DH_USE and OP_SINGLE_ECDH_USE
    # by default.
    context = SSLContext(protocol)

    # Keep this ordering: hostname checking is disabled before verify_mode
    # changes and enabled only afterwards (presumably because the context
    # rejects CERT_NONE while check_hostname is on -- confirm).
    if not check_hostname:
        context.check_hostname = False
    if cert_reqs is not None:
        context.verify_mode = cert_reqs
    if check_hostname:
        context.check_hostname = True

    if keyfile and not certfile:
        raise ValueError("certfile must be specified")
    if certfile or keyfile:
        context.load_cert_chain(certfile, keyfile)

    # load CA root certs
    explicit_ca = cafile or capath or cadata
    if explicit_ca:
        context.load_verify_locations(cafile, capath, cadata)
    elif context.verify_mode != CERT_NONE:
        # Verification requested without explicit CA material: try the
        # system root certificates for *purpose* (may fail silently).
        context.load_default_certs(purpose)

    return context
# Factory http.client falls back to when no context is explicitly passed.
_create_default_https_context = create_default_context

# Historical private name, kept for backwards compatibility.
_create_stdlib_context = _create_unverified_context
class SSLObject:
    """This class implements an interface on top of a low-level SSL object as
    implemented by OpenSSL. This object captures the state of an SSL connection
    but does not provide any network IO itself. IO needs to be performed
    through separate "BIO" objects which are OpenSSL's IO abstraction layer.

    This class does not have a public constructor. Instances are returned by
    ``SSLContext.wrap_bio``. This class is typically used by framework authors
    that want to implement asynchronous IO for SSL through memory buffers.

    When compared to ``SSLSocket``, this object lacks the following features:

     * Any form of network IO, including methods such as ``recv`` and ``send``.
     * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
    """
    def __init__(self, *args, **kwargs):
        # _create() builds instances through cls.__new__ and never calls
        # this, so direct construction always fails.
        raise TypeError(
            f"{self.__class__.__name__} does not have a public "
            f"constructor. Instances are returned by SSLContext.wrap_bio()."
        )

    @classmethod
    def _create(cls, incoming, outgoing, server_side=False,
                server_hostname=None, session=None, context=None):
        """Create an instance backed by ``context._wrap_bio(...)``.

        *server_hostname* is passed through unchanged; SSLContext.wrap_bio()
        encodes it before calling here.
        """
        self = cls.__new__(cls)
        sslobj = context._wrap_bio(
            incoming, outgoing, server_side=server_side,
            server_hostname=server_hostname,
            owner=self, session=session
        )
        self._sslobj = sslobj
        return self

    @property
    def context(self):
        """The SSLContext that is currently in use."""
        return self._sslobj.context

    @context.setter
    def context(self, ctx):
        self._sslobj.context = ctx

    @property
    def session(self):
        """The SSLSession for client socket."""
        return self._sslobj.session

    @session.setter
    def session(self, session):
        self._sslobj.session = session

    @property
    def session_reused(self):
        """Was the client session reused during handshake"""
        return self._sslobj.session_reused

    @property
    def server_side(self):
        """Whether this is a server-side socket."""
        return self._sslobj.server_side

    @property
    def server_hostname(self):
        """The currently set server hostname (for SNI), or ``None`` if no
        server hostname is set."""
        return self._sslobj.server_hostname

    def read(self, len=1024, buffer=None):
        """Read up to 'len' bytes from the SSL object and return them.

        If 'buffer' is provided, read into this buffer and return the number of
        bytes read.
        """
        if buffer is not None:
            return self._sslobj.read(len, buffer)
        return self._sslobj.read(len)

    def write(self, data):
        """Write 'data' to the SSL object and return the number of bytes
        written.

        The 'data' argument must support the buffer interface.
        """
        return self._sslobj.write(data)

    def getpeercert(self, binary_form=False):
        """Returns a formatted version of the data in the certificate provided
        by the other end of the SSL channel.

        Return None if no certificate was provided, {} if a certificate was
        provided, but not validated.
        """
        return self._sslobj.getpeercert(binary_form)

    def selected_npn_protocol(self):
        """Return the currently selected NPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if NPN is not supported by one
        of the peers."""
        if _ssl.HAS_NPN:
            return self._sslobj.selected_npn_protocol()
        return None

    def selected_alpn_protocol(self):
        """Return the currently selected ALPN protocol as a string, or ``None``
        if a next protocol was not negotiated or if ALPN is not supported by one
        of the peers."""
        if _ssl.HAS_ALPN:
            return self._sslobj.selected_alpn_protocol()
        return None

    def cipher(self):
        """Return the currently selected cipher as a 3-tuple ``(name,
        ssl_version, secret_bits)``."""
        return self._sslobj.cipher()

    def shared_ciphers(self):
        """Return a list of ciphers shared by the client during the handshake or
        None if this is not a valid server connection.
        """
        return self._sslobj.shared_ciphers()

    def compression(self):
        """Return the current compression algorithm in use, or ``None`` if
        compression was not negotiated or not supported by one of the peers."""
        return self._sslobj.compression()

    def pending(self):
        """Return the number of bytes that can be read immediately."""
        return self._sslobj.pending()

    def do_handshake(self):
        """Start the SSL/TLS handshake."""
        self._sslobj.do_handshake()

    def unwrap(self):
        """Start the SSL shutdown handshake."""
        return self._sslobj.shutdown()

    def get_channel_binding(self, cb_type="tls-unique"):
        """Get channel binding data for current connection.  Raise ValueError
        if the requested `cb_type` is not supported.  Return bytes of the data
        or None if the data is not available (e.g. before the handshake)."""
        return self._sslobj.get_channel_binding(cb_type)

    def version(self):
        """Return a string identifying the protocol version used by the
        current SSL channel."""
        return self._sslobj.version()

    def verify_client_post_handshake(self):
        """Delegate to the underlying SSL object's
        verify_client_post_handshake() and return its result."""
        return self._sslobj.verify_client_post_handshake()
classSSLSocket(socket): | |
"""This class implements a subtype of socket.socket that wraps | |
the underlying OS socket in an SSL context when necessary, and | |
provides read and write methods over that channel.""" | |
def __init__(self, *args, **kwargs):
    """Always raise TypeError: SSLSocket has no public constructor.

    Instances come from SSLContext.wrap_socket(), which goes through the
    private ``_create`` classmethod instead.
    """
    raise TypeError(
        f"{self.__class__.__name__} does not have a public "
        f"constructor. Instances are returned by "
        f"SSLContext.wrap_socket()."
    )
@classmethod
def _create(cls, sock, server_side=False, do_handshake_on_connect=True,
            suppress_ragged_eofs=True, server_hostname=None,
            context=None, session=None):
    """Build an SSL socket that takes over *sock*'s file descriptor.

    *sock* must be a SOCK_STREAM socket; it is detached, so the returned
    object owns the descriptor afterwards.  If the socket is already
    connected, the SSL object is created right away and, when
    *do_handshake_on_connect* is true, the handshake is performed.

    Raises NotImplementedError for non-stream sockets and ValueError for
    invalid argument combinations.  If creating the SSL object or the
    handshake fails with OSError/ValueError, the new socket is closed before
    the exception propagates.
    """
    if sock.getsockopt(SOL_SOCKET, SO_TYPE) != SOCK_STREAM:
        raise NotImplementedError("only stream sockets are supported")
    if server_side:
        if server_hostname:
            raise ValueError("server_hostname can only be specified "
                             "in client mode")
        if session is not None:
            raise ValueError("session can only be specified in "
                             "client mode")
    if context.check_hostname and not server_hostname:
        raise ValueError("check_hostname requires server_hostname")

    kwargs = dict(
        family=sock.family, type=sock.type, proto=sock.proto,
        fileno=sock.fileno()
    )
    self = cls.__new__(cls, **kwargs)
    # SSLSocket.__init__ always raises, so initialize the plain socket base.
    super(SSLSocket, self).__init__(**kwargs)
    self.settimeout(sock.gettimeout())
    # The new object now refers to sock's descriptor; detach sock so closing
    # it does not close the descriptor.
    sock.detach()

    self._context = context
    self._session = session
    self._closed = False
    self._sslobj = None
    self.server_side = server_side
    self.server_hostname = context._encode_hostname(server_hostname)
    self.do_handshake_on_connect = do_handshake_on_connect
    self.suppress_ragged_eofs = suppress_ragged_eofs

    # See if we are connected
    try:
        self.getpeername()
    except OSError as e:
        if e.errno != errno.ENOTCONN:
            # self owns the descriptor now; don't leak it on failure.
            self.close()
            raise
        connected = False
    else:
        connected = True

    self._connected = connected
    if connected:
        # create the SSL object
        try:
            self._sslobj = self._context._wrap_socket(
                self, server_side, self.server_hostname,
                owner=self, session=self._session,
            )
            if do_handshake_on_connect:
                timeout = self.gettimeout()
                if timeout == 0.0:
                    # non-blocking
                    raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
                self.do_handshake()
        except (OSError, ValueError):
            self.close()
            raise
    return self
@property
def context(self):
    """The SSLContext this socket uses."""
    return self._context

@context.setter
def context(self, ctx):
    self._context = ctx
    # Before the SSL object exists (e.g. not yet connected) there is nothing
    # to update; _create() builds the SSL object from self._context.
    if self._sslobj is not None:
        self._sslobj.context = ctx
@property
def session(self):
    """Client-side SSLSession; None while no SSL object exists."""
    sslobj = self._sslobj
    return None if sslobj is None else sslobj.session

@session.setter
def session(self, session):
    # Always remember the requested session; forward it to the SSL object
    # only if one has already been created.
    self._session = session
    sslobj = self._sslobj
    if sslobj is not None:
        sslobj.session = session
@property
def session_reused(self):
    """Whether the client session was reused during the handshake.

    Evaluates to None while there is no SSL object.
    """
    sslobj = self._sslobj
    return None if sslobj is None else sslobj.session_reused
def dup(self):
    """Refuse to duplicate this socket; dup() is not supported here.

    Raises:
        NotImplementedError: always.
    """
    # ``NotImplemented`` is a comparison sentinel, not an exception class.
    # Calling it raised TypeError instead of the intended error.
    raise NotImplementedError("Can't dup() %s instances" %
                              self.__class__.__name__)
def_checkClosed(self,msg=None): | |
# raise an exception here if you wish to check for spurious closes | |
pass | |
def_check_connected(self): | |
ifnotself._connected: | |
# getpeername() will raise ENOTCONN if the socket is really | |
# not connected; note that we can be connected even without | |
# _connected being set, e.g. if connect() first returned | |
# EAGAIN. | |
self.getpeername() | |
def read(self, len=1024, buffer=None):
    """Read up to *len* bytes from the SSL channel.

    With *buffer*, data is read into it and the byte count is returned.
    Otherwise the data is returned as bytes.  A ragged EOF
    (SSL_ERROR_EOF) is reported as end-of-stream (0 or b'') when
    ``suppress_ragged_eofs`` is set, and re-raised otherwise.

    Raises:
        ValueError: if the socket is closed or unwrapped.
    """
    self._checkClosed()
    if self._sslobj is None:
        raise ValueError("Read on closed or unwrapped SSL socket.")
    read_args = (len,) if buffer is None else (len, buffer)
    try:
        return self._sslobj.read(*read_args)
    except SSLError as x:
        if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
            return b'' if buffer is None else 0
        raise
def write(self, data):
    """Write *data* to the SSL channel and return how many bytes were sent.

    Raises:
        ValueError: if the socket is closed or unwrapped.
    """
    self._checkClosed()
    sslobj = self._sslobj
    if sslobj is None:
        raise ValueError("Write on closed or unwrapped SSL socket.")
    return sslobj.write(data)
def getpeercert(self, binary_form=False):
    """Return the peer's certificate as reported by the SSL object.

    Returns None if the peer sent no certificate, and {} if it sent one
    that was not validated.  With *binary_form*, the raw form is requested
    instead.  The socket is first checked to be open and connected.
    """
    self._checkClosed()
    self._check_connected()
    sslobj = self._sslobj
    return sslobj.getpeercert(binary_form)
def selected_npn_protocol(self):
    """Return the protocol selected via NPN, or None.

    None is returned when the socket is unwrapped or when the _ssl
    module reports no NPN support.
    """
    self._checkClosed()
    if self._sslobj is not None and _ssl.HAS_NPN:
        return self._sslobj.selected_npn_protocol()
    return None
def selected_alpn_protocol(self):
    """Return the protocol selected via ALPN, or None.

    None is returned when the socket is unwrapped or when the _ssl
    module reports no ALPN support.
    """
    self._checkClosed()
    if self._sslobj is not None and _ssl.HAS_ALPN:
        return self._sslobj.selected_alpn_protocol()
    return None
def cipher(self):
    """Return the SSL object's current cipher description, or None if unwrapped."""
    self._checkClosed()
    return None if self._sslobj is None else self._sslobj.cipher()
def shared_ciphers(self):
    """Return the SSL object's shared-cipher list, or None if unwrapped."""
    self._checkClosed()
    return None if self._sslobj is None else self._sslobj.shared_ciphers()
def compression(self):
    """Return the SSL object's compression setting, or None if unwrapped."""
    self._checkClosed()
    return None if self._sslobj is None else self._sslobj.compression()
def send(self, data, flags=0):
    """Send *data* on the socket.

    When wrapped, the data goes through the SSL object's write() and its
    result is returned; non-zero *flags* are rejected on that path.
    Otherwise this falls back to the plain socket send().

    Raises:
        ValueError: if *flags* is non-zero on a wrapped socket.
    """
    self._checkClosed()
    if self._sslobj is not None:
        if flags != 0:
            # Space before %s: the message previously ran "on<class ...>".
            raise ValueError(
                "non-zero flags not allowed in calls to send() on %s" %
                self.__class__)
        return self._sslobj.write(data)
    else:
        return super().send(data, flags)
def sendto(self, data, flags_or_addr, addr=None):
    """Send a datagram; only allowed when the socket is not SSL-wrapped.

    *flags_or_addr* is the address when *addr* is omitted, and the flags
    otherwise, mirroring socket.sendto().

    Raises:
        ValueError: if the socket is wrapped in SSL.
    """
    self._checkClosed()
    if self._sslobj is not None:
        raise ValueError("sendto not allowed on instances of %s" %
                         self.__class__)
    elif addr is None:
        return super().sendto(data, flags_or_addr)
    else:
        return super().sendto(data, flags_or_addr, addr)
def sendmsg(self, *args, **kwargs):
    """Always refuse; see Raises.

    Rejected so that programs cannot accidentally send data unencrypted
    through this method.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("sendmsg not allowed on instances of %s" %
                              self.__class__)
def sendall(self, data, flags=0):
    """Send all of *data*.

    When wrapped, repeatedly calls ``self.send()`` on the unsent tail until
    every byte has been written, and returns None.  Non-zero *flags* are
    rejected on that path.  Otherwise this delegates to the plain socket
    sendall().

    Raises:
        ValueError: if *flags* is non-zero on a wrapped socket.
    """
    self._checkClosed()
    if self._sslobj is not None:
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to sendall() on %s" %
                self.__class__)
        count = 0
        # Cast to unsigned bytes so the length and slicing count bytes even
        # when *data* is a buffer with multi-byte items.
        with memoryview(data) as view, view.cast("B") as byte_view:
            amount = len(byte_view)
            while count < amount:
                v = self.send(byte_view[count:])
                count += v
    else:
        return super().sendall(data, flags)
def sendfile(self, file, offset=0, count=None):
    """Send *file* and return the total number of bytes sent.

    A plain (unwrapped) socket may use os.sendfile() via the base class.
    A wrapped socket has to go through ``_sendfile_use_send``.
    """
    if self._sslobj is None:
        # os.sendfile() works with plain sockets only
        return super().sendfile(file, offset, count)
    return self._sendfile_use_send(file, offset, count)
def recv(self, buflen=1024, flags=0):
    """Receive up to *buflen* bytes.

    When wrapped, this reads through ``self.read()`` and rejects non-zero
    *flags*.  Otherwise it delegates to the plain socket recv().

    Raises:
        ValueError: if *flags* is non-zero on a wrapped socket.
    """
    self._checkClosed()
    if self._sslobj is not None:
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to recv() on %s" %
                self.__class__)
        return self.read(buflen)
    else:
        return super().recv(buflen, flags)
def recv_into(self, buffer, nbytes=None, flags=0):
    """Receive data into *buffer* and return the number of bytes read.

    If *nbytes* is omitted it defaults to ``len(buffer)``, or to 1024 when
    *buffer* is empty (falsy).  When wrapped, this reads through
    ``self.read()`` and rejects non-zero *flags*.  Otherwise it delegates
    to the plain socket recv_into().

    Raises:
        ValueError: if *flags* is non-zero on a wrapped socket.
    """
    self._checkClosed()
    if buffer and (nbytes is None):
        nbytes = len(buffer)
    elif nbytes is None:
        nbytes = 1024
    if self._sslobj is not None:
        if flags != 0:
            raise ValueError(
                "non-zero flags not allowed in calls to recv_into() on %s" %
                self.__class__)
        return self.read(nbytes, buffer)
    else:
        return super().recv_into(buffer, nbytes, flags)
def recvfrom(self, buflen=1024, flags=0):
    """Receive a datagram; only allowed when the socket is not SSL-wrapped.

    Raises:
        ValueError: if the socket is wrapped in SSL.
    """
    self._checkClosed()
    if self._sslobj is not None:
        raise ValueError("recvfrom not allowed on instances of %s" %
                         self.__class__)
    else:
        return super().recvfrom(buflen, flags)
def recvfrom_into(self, buffer, nbytes=None, flags=0):
    """Receive a datagram into *buffer*; only allowed when not SSL-wrapped.

    Raises:
        ValueError: if the socket is wrapped in SSL.
    """
    self._checkClosed()
    if self._sslobj is not None:
        raise ValueError("recvfrom_into not allowed on instances of %s" %
                         self.__class__)
    else:
        return super().recvfrom_into(buffer, nbytes, flags)
def recvmsg(self, *args, **kwargs):
    """Always refuse; recvmsg() is not supported on this socket type.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("recvmsg not allowed on instances of %s" %
                              self.__class__)
def recvmsg_into(self, *args, **kwargs):
    """Always refuse; recvmsg_into() is not supported on this socket type.

    Raises:
        NotImplementedError: always.
    """
    # One literal: the two adjacent literals used to fuse as "of<class ...>".
    raise NotImplementedError("recvmsg_into not allowed on instances of %s" %
                              self.__class__)
def pending(self):
    """Return the SSL object's pending() count, or 0 when unwrapped."""
    self._checkClosed()
    sslobj = self._sslobj
    return 0 if sslobj is None else sslobj.pending()
def shutdown(self, how):
    """Shut down the connection at the socket level.

    The SSL object is dropped before the base-class shutdown, so afterwards
    the socket behaves as unwrapped (e.g. read()/write() raise ValueError).
    """
    self._checkClosed()
    self._sslobj = None
    super().shutdown(how)
def unwrap(self):
    """Shut down the SSL layer and return the SSL object's shutdown() result.

    After a successful call the socket is no longer wrapped.

    Raises:
        ValueError: if there is no SSL wrapper.
    """
    if self._sslobj:
        s = self._sslobj.shutdown()
        self._sslobj = None
        return s
    else:
        # Space after "around": the message previously ran into the repr.
        raise ValueError("No SSL wrapper around " + str(self))
def verify_client_post_handshake(self):
    """Delegate post-handshake client verification to the SSL object.

    Raises:
        ValueError: if there is no SSL wrapper.
    """
    if self._sslobj:
        return self._sslobj.verify_client_post_handshake()
    else:
        raise ValueError("No SSL wrapper around " + str(self))
def _real_close(self):
    """Drop the SSL object, then let the base class close the socket."""
    self._sslobj = None
    super()._real_close()
def do_handshake(self, block=False):
    """Perform a TLS/SSL handshake.

    With *block* set on a non-blocking socket (timeout 0.0), the socket is
    made blocking for the duration of the handshake.  The original timeout
    is always restored afterwards, even if the handshake fails.
    """
    self._check_connected()
    saved_timeout = self.gettimeout()
    try:
        if saved_timeout == 0.0 and block:
            self.settimeout(None)
        self._sslobj.do_handshake()
    finally:
        self.settimeout(saved_timeout)
def_real_connect(self,addr,connect_ex): | |
ifself.server_side: | |
raiseValueError("can't connect in server-side mode") | |
# Here we assume that the socket is client-side, and not | |
# connected at the time of the call. We connect it, then wrap it. | |
ifself._connectedorself._sslobjisnotNone: | |
raiseValueError("attempt to connect already-connected SSLSocket!") | |
self._sslobj=self.context._wrap_socket( | |
self,False,self.server_hostname, | |
owner=self,session=self._session | |
) | |
try: | |
if connect_ex: | |
rc=super().connect_ex(addr) | |
else: | |
rc=None | |
super().connect(addr) | |
ifnot rc: | |
self._connected=True | |
ifself.do_handshake_on_connect: | |
self.do_handshake() | |
return rc | |
except (OSError,ValueError): | |
self._sslobj=None | |
raise | |
def connect(self, addr):
    """Connect to the remote *addr* and wrap the connection in an SSL channel.

    Failures are raised; the return value is always None.
    """
    self._real_connect(addr, False)
def connect_ex(self, addr):
    """Connect to *addr* like connect(), then wrap the connection in SSL.

    Returns the base-class connect_ex() indicator (falsy on success)
    instead of raising for it.
    """
    result = self._real_connect(addr, True)
    return result
def accept(self):
    """Accept a connection and wrap it in a server-side SSL channel.

    Returns a ``(wrapped_socket, remote_address)`` pair.  The new socket
    inherits this socket's handshake-on-connect and ragged-EOF settings.
    """
    raw_sock, peer_addr = super().accept()
    wrapped = self.context.wrap_socket(
        raw_sock,
        do_handshake_on_connect=self.do_handshake_on_connect,
        suppress_ragged_eofs=self.suppress_ragged_eofs,
        server_side=True,
    )
    return wrapped, peer_addr
def get_channel_binding(self, cb_type="tls-unique"):
    """Return channel binding data for the current connection.

    When wrapped, the request is delegated to the SSL object.  Otherwise
    the result is None (e.g. before the handshake), but an unsupported
    *cb_type* still raises.

    Raises:
        ValueError: if *cb_type* is not a supported binding type.
    """
    if self._sslobj is None:
        if cb_type not in CHANNEL_BINDING_TYPES:
            raise ValueError(
                "{0} channel binding type not implemented".format(cb_type)
            )
        return None
    return self._sslobj.get_channel_binding(cb_type)
def version(self):
    """Return the protocol version string of the SSL channel.

    Returns None when there is no established channel.
    """
    sslobj = self._sslobj
    return None if sslobj is None else sslobj.version()
# SSLContext has to name the socket/object classes it creates, but those are
# defined after it; Python has no forward declarations, so attach them now.
SSLContext.sslsocket_class, SSLContext.sslobject_class = SSLSocket, SSLObject
def wrap_socket(sock, keyfile=None, certfile=None,
                server_side=False, cert_reqs=CERT_NONE,
                ssl_version=PROTOCOL_TLS, ca_certs=None,
                do_handshake_on_connect=True,
                suppress_ragged_eofs=True,
                ciphers=None):
    """Wrap *sock* in an SSL socket using a freshly built SSLContext.

    The context is created for *ssl_version* with verify mode *cert_reqs*.
    It then loads *ca_certs* for verification, the *certfile*/*keyfile*
    chain and the *ciphers* list, each only if given.

    NOTE: the default ``cert_reqs=CERT_NONE`` means certificates from the
    other side are neither required nor checked.

    Raises:
        ValueError: if *server_side* is set without *certfile*, or
            *keyfile* is given without *certfile*.
    """
    if server_side and not certfile:
        # Trailing space: the two literals previously fused into
        # "server-sideoperations".
        raise ValueError("certfile must be specified for server-side "
                         "operations")
    if keyfile and not certfile:
        raise ValueError("certfile must be specified")
    context = SSLContext(ssl_version)
    context.verify_mode = cert_reqs
    if ca_certs:
        context.load_verify_locations(ca_certs)
    if certfile:
        context.load_cert_chain(certfile, keyfile)
    if ciphers:
        context.set_ciphers(ciphers)
    return context.wrap_socket(
        sock=sock, server_side=server_side,
        do_handshake_on_connect=do_handshake_on_connect,
        suppress_ragged_eofs=suppress_ragged_eofs
    )
# some utility functions | |
def cert_time_to_seconds(cert_time):
    """Return the time in seconds since the Epoch, given the timestring
    representing the "notBefore" or "notAfter" date from a certificate
    in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale).

    "notBefore" or "notAfter" dates must use UTC (RFC 5280).

    Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
    UTC should be specified as GMT (see ASN1_TIME_print())

    The month is matched against a fixed English table instead of ``%b``,
    whose meaning depends on the current locale.  The result is an int.

    Raises:
        ValueError: if *cert_time* does not match the expected format.
    """
    from time import strptime
    from calendar import timegm

    months = (
        "Jan", "Feb", "Mar", "Apr", "May", "Jun",
        "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"
    )
    # The format must begin with a space.  After the month is sliced off,
    # the rest starts with whitespace ("Jan  5 ..." -> "  5 ...", and
    # "Jan 15 ..." -> " 15 ...").  Whitespace in a strptime format matches
    # one or more whitespace characters.  Without the leading space,
    # two-digit and space-padded days failed to parse.
    time_format = ' %d %H:%M:%S %Y GMT'  # NOTE: no month, fixed GMT
    try:
        month_number = months.index(cert_time[:3].title()) + 1
    except ValueError:
        # The list.index() failure is an implementation detail; don't chain it.
        raise ValueError('time data %r does not match '
                         'format "%%b%s"' % (cert_time, time_format)) from None
    # found valid month
    tt = strptime(cert_time[3:], time_format)
    # return an integer, the previous mktime()-based implementation
    # returned a float (fractional seconds are always zero here).
    return timegm((tt[0], month_number) + tt[2:6])
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"


def DER_cert_to_PEM_cert(der_cert_bytes):
    """Takes a certificate in binary DER format and returns the
    PEM version of it as a string.

    The base64 body is split into lines of at most 64 characters between
    PEM_HEADER and PEM_FOOTER, and the result ends with a newline.
    """
    f = str(base64.standard_b64encode(der_cert_bytes), 'ASCII', 'strict')
    ss = [PEM_HEADER]
    ss += [f[i:i + 64] for i in range(0, len(f), 64)]
    ss.append(PEM_FOOTER + '\n')
    return '\n'.join(ss)


def PEM_cert_to_DER_cert(pem_cert_string):
    """Takes a certificate in ASCII PEM format and returns the
    DER-encoded version of it as a byte sequence.

    The string must start exactly with PEM_HEADER (leading whitespace is
    not accepted).  It must end with PEM_FOOTER once surrounding
    whitespace is stripped.

    Raises:
        ValueError: if the header or footer is missing.
    """
    # Spaces before %s: the messages previously ran into the marker text.
    if not pem_cert_string.startswith(PEM_HEADER):
        raise ValueError("Invalid PEM encoding; must start with %s"
                         % PEM_HEADER)
    if not pem_cert_string.strip().endswith(PEM_FOOTER):
        raise ValueError("Invalid PEM encoding; must end with %s"
                         % PEM_FOOTER)
    d = pem_cert_string.strip()[len(PEM_HEADER):-len(PEM_FOOTER)]
    return base64.decodebytes(d.encode('ASCII', 'strict'))
def get_server_certificate(addr, ssl_version=PROTOCOL_TLS, ca_certs=None):
    """Fetch the certificate presented by the server at *addr* as PEM text.

    *addr* is a (host, port) pair.  If *ca_certs* is given, a server
    certificate is required and validated against it; otherwise no
    certificate is required.  *ssl_version* selects the protocol used for
    the connection attempt.
    """
    # Unpacked only to insist on a (host, port) pair; the values are unused.
    _host, _port = addr
    cert_reqs = CERT_NONE if ca_certs is None else CERT_REQUIRED
    context = _create_stdlib_context(ssl_version,
                                     cert_reqs=cert_reqs,
                                     cafile=ca_certs)
    with create_connection(addr) as sock, context.wrap_socket(sock) as sslsock:
        dercert = sslsock.getpeercert(True)
    return DER_cert_to_PEM_cert(dercert)
def get_protocol_name(protocol_code):
    """Return the constant name for *protocol_code*, or '<unknown>'."""
    fallback = '<unknown>'
    return _PROTOCOL_NAMES.get(protocol_code, fallback)