Skip to content

Commit

Permalink
Merge pull request #34 from jptomoya/feature/support_tls
Browse files Browse the repository at this point in the history
Initial implementation of SSL/TLS connection in Socket
  • Loading branch information
ptr-yudai authored Apr 10, 2024
2 parents 6579a72 + c31232b commit 884ca8b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 1 deletion.
16 changes: 15 additions & 1 deletion ptrlib/connection/sock.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@


class Socket(Tube):
def __init__(self, host: Union[str, bytes], port: Optional[int]=None, timeout: Optional[Union[int, float]]=None):
def __init__(self, host: Union[str, bytes], port: Optional[int]=None,
timeout: Optional[Union[int, float]]=None,
ssl: bool=False, sni: Union[str, bool]=True):
"""Create a socket
Create a new socket and establish a connection to the host.
Expand Down Expand Up @@ -45,6 +47,15 @@ def __init__(self, host: Union[str, bytes], port: Optional[int]=None, timeout: O
self.timeout = timeout
# Create a new socket
self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
if ssl:
import ssl as _ssl
self.context = _ssl.SSLContext(_ssl.PROTOCOL_TLS_CLIENT)
self.context.check_hostname = False
self.context.verify_mode = _ssl.CERT_NONE
if sni is True:
self.sock = self.context.wrap_socket(self.sock)
else:
self.sock = self.context.wrap_socket(self.sock, server_hostname=sni)
# Establish a connection
try:
self.sock.connect((self.host, self.port))
Expand Down Expand Up @@ -84,6 +95,9 @@ def _recv(self, size: int=4096, timeout: Optional[Union[int, float]]=None) -> by
except ConnectionAbortedError as e:
logger.warning("Connection aborted by the host")
raise e from None
except ConnectionResetError as e:
logger.warning("Connection reset by the host")
raise e from None

return data

Expand Down
21 changes: 21 additions & 0 deletions tests/connection/test_sock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import unittest
from socket import gethostbyname
from ptrlib import Socket
from logging import getLogger, FATAL

Expand Down Expand Up @@ -40,3 +41,23 @@ def test_timeout(self):
sock.close()

self.assertEqual(result, True)

def test_tls(self):
host = "www.example.com"

# connect with sni
ip_addr = gethostbyname(host)
sock = Socket(ip_addr, 443, ssl=True, sni=host)
sock.sendline(b'GET / HTTP/1.1\r')
sock.send(b'Host: www.example.com\r\n')
sock.send(b'Connection: close\r\n\r\n')
self.assertTrue(int(sock.recvlineafter('Content-Length: ')) > 0)
sock.close()

# connect without sni
sock = Socket(host, 443, ssl=True)
sock.sendline(b'GET / HTTP/1.1\r')
sock.send(b'Host: www.example.com\r\n')
sock.send(b'Connection: close\r\n\r\n')
self.assertTrue(int(sock.recvlineafter('Content-Length: ')) > 0)
sock.close()

0 comments on commit 884ca8b

Please sign in to comment.