diff --git a/src/client.py b/src/client.py index 8490f10..ac198df 100644 --- a/src/client.py +++ b/src/client.py @@ -4,14 +4,15 @@ import collections import logging # import signal import os +import random import ssl +import string class Client: def __init__(self, host='localhost', port=3001, buffer_chunk_size=10**4, buffer_length_limit=10**4, - password=None): - self._password = password + password=None, token=None): self._host = host self._port = port self._stopping = False @@ -23,6 +24,8 @@ class Client: self._buffer_length_limit = buffer_length_limit self._file_path = None self._working = False + self._token = token + self._password = password self._ssl_context = None self._encryption_complete = False @@ -61,6 +64,10 @@ class Client: def set_ssl_context(self, ssl_context: ssl.SSLContext): self._ssl_context = ssl_context + @property + def token(self): + return self._token + @property def password(self): """Password for file encryption or decryption.""" @@ -72,12 +79,31 @@ class Client: async def run_sending_client(self, file_path='~/output.txt'): self._file_path = file_path - reader, writer = await asyncio.open_connection(host=self.host, - port=self.port, - ssl=self.ssl_context) - writer.write("sender\n".encode('utf-8')) + file_name = os.path.basename(os.path.abspath(file_path)) + file_size = os.path.getsize(os.path.abspath(file_path)) + try: + reader, writer = await asyncio.open_connection( + host=self.host, + port=self.port, + ssl=self.ssl_context + ) + except ConnectionRefusedError as exception: + logging.error(exception) + return + writer.write( + f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8') + ) await writer.drain() - await reader.readline() # Wait for server start signal + # Wait for server start signal + while 1: + server_hello = await reader.readline() + if not server_hello: + logging.info("Server disconnected.") + return + server_hello = server_hello.decode('utf-8').strip('\n') + if server_hello == 'start!': + break + logging.info(f"Server said: {server_hello}") await self.send(writer=writer) async def encrypt_file(self, input_file, output_file): @@ -142,12 +168,27 @@ class Client: async def run_receiving_client(self, file_path='~/input.txt'): self._file_path = file_path - reader, writer = await asyncio.open_connection(host=self.host, - port=self.port, - ssl=self.ssl_context) - writer.write("receiver\n".encode('utf-8')) + try: + reader, writer = await asyncio.open_connection( + host=self.host, + port=self.port, + ssl=self.ssl_context + ) + except ConnectionRefusedError as exception: + logging.error(exception) + return + writer.write(f"r|{self.token}\n".encode('utf-8')) await writer.drain() - await reader.readline() # Wait for server start signal + # Wait for server start signal + while 1: + server_hello = await reader.readline() + if not server_hello: + logging.info("Server disconnected.") + return + server_hello = server_hello.decode('utf-8').strip('\n') + if server_hello == 'start!': + break + logging.info(f"Server said: {server_hello}") await self.receive(reader=reader) async def receive(self, reader: asyncio.StreamReader): @@ -258,6 +299,11 @@ if __name__ == '__main__': default=None, required=False, help='Password for file encryption or decryption') + cli_parser.add_argument('--token', '--t', '--session_token', type=str, + default=None, + required=False, + help='Session token ' + '(must be the same for both clients)') cli_parser.add_argument('others', metavar='R or S', nargs='*', @@ -268,6 +314,7 @@ if __name__ == '__main__': _action = get_action(args['action']) _file_path = args['path'] _password = args['password'] + _token = args['token'] # If _host and _port are not provided from command-line, try to import them if _host is None: @@ -303,6 +350,11 @@ if __name__ == '__main__': from config import password as _password except ImportError: _password = None + if _token is None: + try: + from config import token as _token + except ImportError: + _token = None # If import fails, prompt user for _host or _port while _host is None: @@ -328,11 +380,29 @@ if __name__ == '__main__': "Your file will be unencoded unless you provide a password in " "config file." ) + if _token is None and _action == 'send': + # Generate a random [6-10] chars-long alphanumerical token + _token = ''.join( + random.SystemRandom().choice( + string.ascii_uppercase + string.digits + ) + for _ in range(random.SystemRandom().randint(6, 10)) + ) + logging.info( + "You have not provided a token for this connection.\n" + f"A token has been generated for you:\t\t{_token}\n" + "Your peer must be informed of this token.\n" + "For future connections, you may provide a custom token writing " + "it in config file." + ) + while _token is None or not (6 <= len(_token) <= 10): + _token = input("Please enter a 6-10 chars token.\t\t\t\t") loop = asyncio.get_event_loop() client = Client( host=_host, port=_port, - password=_password + password=_password, + token=_token ) try: from config import certificate diff --git a/src/server.py b/src/server.py index cda6345..fee5b3d 100644 --- a/src/server.py +++ b/src/server.py @@ -10,9 +10,9 @@ class Server: buffer_chunk_size=10**4, buffer_length_limit=10**4): self._host = host self._port = port - self._stopping = False - # Shared queue of bytes - self.buffer = collections.deque() + self.connections = collections.OrderedDict() + # Dict of queues of bytes + self.buffers = collections.OrderedDict() # How many bytes per chunk self._buffer_chunk_size = buffer_chunk_size # How many chunks in buffer @@ -29,10 +29,6 @@ class Server: def port(self) -> int: return self._port - @property - def stopping(self) -> bool: - return self._stopping - @property def buffer_length_limit(self) -> int: return self._buffer_length_limit @@ -53,28 +49,40 @@ class Server: def ssl_context(self) -> ssl.SSLContext: return self._ssl_context + @property + def buffer_is_full(self): + return ( + sum(len(buffer) + for buffer in self.buffers.values()) + >= self.buffer_length_limit + ) + def set_ssl_context(self, ssl_context: ssl.SSLContext): self._ssl_context = ssl_context - async def run_reader(self, reader): - while not self.stopping: + async def run_reader(self, reader, connection_token): + while 1: try: - # Stop if buffer is full - while len(self.buffer) >= self.buffer_length_limit: + # Wait one second if buffer is full + while self.buffer_is_full: await asyncio.sleep(1) continue input_data = await reader.read(self.buffer_chunk_size) - self.buffer.append(input_data) + if connection_token not in self.buffers: + break + self.buffers[connection_token].append(input_data) except Exception as e: - logging.error(e) + logging.error(e, exc_info=True) - async def run_writer(self, writer): + async def run_writer(self, writer, connection_token): consecutive_interruptions = 0 errors = 0 - while not self.stopping: + while 1: try: try: - input_data = self.buffer.popleft() + if connection_token not in self.buffers: + break + input_data = self.buffers[connection_token].popleft() except IndexError: # Slow down if buffer is short consecutive_interruptions += 1 @@ -89,7 +97,7 @@ class Server: writer.write(input_data) await writer.drain() except Exception as e: - logging.error(e) + logging.error(e, exc_info=True) errors += 1 if errors > 3: break @@ -104,25 +112,70 @@ class Server: Decide whether client is sender or receiver and start transmission. """ client_hello = await reader.readline() - peer_is_sender = client_hello.decode('utf-8') == 'sender\n' + client_hello = client_hello.decode('utf-8').strip('\n').split('|') + peer_is_sender = client_hello[0] == 's' + connection_token = client_hello[1] + if connection_token not in self.connections: + self.connections[connection_token] = dict( + sender=False, + receiver=False + ) if peer_is_sender: - self._working = True + if self.connections[connection_token]['sender']: + writer.write( + "Invalid token! " + "A sender client is already connected!\n".encode('utf-8') + ) + await writer.drain() + writer.close() + return + self.connections[connection_token]['sender'] = True + self.buffers[connection_token] = collections.deque() logging.info("Sender is connecting...") - # Send start signal to client - writer.write("Start!\n".encode('utf-8')) - await writer.drain() - await self.run_reader(reader=reader) - logging.info("Incoming transmission ended") - else: - logging.info("Receiver is connecting...") - while len(self.buffer) == 0: + index, step = 0, 1 + while not self.connections[connection_token]['receiver']: + index += 1 + if index >= step: + writer.write("Waiting for receiver...\n".encode('utf-8')) + await writer.drain() + step += 1 + index = 0 await asyncio.sleep(.5) # Send start signal to client - writer.write("Start!\n".encode('utf-8')) + writer.write("start!\n".encode('utf-8')) await writer.drain() - await self.run_writer(writer=writer) + logging.info("Incoming transmission starting...") + await self.run_reader(reader=reader, + connection_token=connection_token) + logging.info("Incoming transmission ended") + else: + if self.connections[connection_token]['receiver']: + writer.write( + "Invalid token! " + "A receiver client is already connected!\n".encode('utf-8') + ) + await writer.drain() + writer.close() + return + self.connections[connection_token]['receiver'] = True + logging.info("Receiver is connecting...") + index, step = 0, 1 + while not self.connections[connection_token]['sender']: + index += 1 + if index >= step: + writer.write("Waiting for sender...\n".encode('utf-8')) + await writer.drain() + step += 1 + index = 0 + await asyncio.sleep(.5) + # Send start signal to client + writer.write("start!\n".encode('utf-8')) + await writer.drain() + await self.run_writer(writer=writer, + connection_token=connection_token) logging.info("Outgoing transmission ended") - self._working = False + del self.buffers[connection_token] + del self.connections[connection_token] return def run(self): @@ -149,23 +202,11 @@ class Server: port=self.port, ) async with self.server: - try: - await self.server.serve_forever() - except KeyboardInterrupt: - logging.info("Stopping...") - self.server.close() - await self.server.wait_closed() + await self.server.serve_forever() return - def stop(self, *_): - if self.working and not self.stopping: - logging.info("Received interruption signal, stopping...") - self._stopping = True - else: - raise KeyboardInterrupt("Not working yet...") - -if __name__ == '__main__': +def main(): # noinspection SpellCheckingInspection log_formatter = logging.Formatter( "%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s", @@ -221,12 +262,16 @@ if __name__ == '__main__': port=_port, ) try: + # noinspection PyUnresolvedReferences from config import certificate, key _ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) _ssl_context.check_hostname = False _ssl_context.load_cert_chain(certificate, key) server.set_ssl_context(_ssl_context) except ImportError: - logging.info("Please consider using SSL.") - certificate, key = None, None + logging.warning("Please consider using SSL.") server.run() + + +if __name__ == '__main__': + main()