diff --git a/.gitignore b/.gitignore index fb1d634..4087523 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,9 @@ # Configuration file *config.py +# Data folder +data/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/src/client.py b/src/client.py index ff9e4a2..f578e8b 100644 --- a/src/client.py +++ b/src/client.py @@ -4,6 +4,7 @@ import collections import logging # import signal import os +import ssl class Client: @@ -17,6 +18,7 @@ class Client: self._buffer_length_limit = buffer_length_limit # How many chunks in buffer self._file_path = None self._working = False + self._ssl_context = None @property def host(self) -> str: @@ -46,10 +48,18 @@ class Client: def working(self) -> bool: return self._working + @property + def ssl_context(self) -> ssl.SSLContext: + return self._ssl_context + + def set_ssl_context(self, ssl_context: ssl.SSLContext): + self._ssl_context = ssl_context + async def run_sending_client(self, file_path='~/output.txt'): self._file_path = file_path reader, writer = await asyncio.open_connection(host=self.host, - port=self.port) + port=self.port, + ssl=self.ssl_context) writer.write("sender\n".encode('utf-8')) await writer.drain() await reader.readline() # Wait for server start signal @@ -78,7 +88,8 @@ class Client: async def run_receiving_client(self, file_path='~/input.txt'): self._file_path = file_path reader, writer = await asyncio.open_connection(host=self.host, - port=self.port) + port=self.port, + ssl=self.ssl_context) writer.write("receiver\n".encode('utf-8')) await writer.drain() await reader.readline() # Wait for server start signal @@ -227,6 +238,15 @@ if __name__ == '__main__': host=_host, port=_port, ) + try: + from config import certificate + _ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + _ssl_context.check_hostname = False + _ssl_context.load_verify_locations(certificate) + client.set_ssl_context(_ssl_context) + except ImportError: + logging.info("Please consider using SSL.") + certificate, key = None, None logging.info("Starting client...") if _action == 'send': loop.run_until_complete( diff --git a/src/server.py b/src/server.py index 5b627ef..94dad9d 100644 --- a/src/server.py +++ b/src/server.py @@ -2,6 +2,7 @@ import argparse import asyncio import collections import logging +import ssl class Server: @@ -16,6 +17,7 @@ class Server: self._working = False self.at_eof = False self._server = None + self._ssl_context = None @property def host(self) -> str: @@ -45,6 +47,13 @@ class Server: def server(self) -> asyncio.base_events.Server: return self._server + @property + def ssl_context(self) -> ssl.SSLContext: + return self._ssl_context + + def set_ssl_context(self, ssl_context: ssl.SSLContext): + self._ssl_context = ssl_context + async def run_reader(self, reader): while not self.stopping: try: @@ -121,9 +130,10 @@ class Server: async def run_server(self): self._server = await asyncio.start_server( + ssl=self.ssl_context, client_connected_cb=self.connect, host=self.host, - port=self.port + port=self.port, ) async with self.server: try: @@ -197,4 +207,13 @@ if __name__ == '__main__': host=_host, port=_port, ) + try: + from config import certificate, key + _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + _ssl_context.check_hostname = False + _ssl_context.load_cert_chain(certificate, key) + server.set_ssl_context(_ssl_context) + except ImportError: + logging.info("Please consider using SSL.") + certificate, key = None, None server.run()