diff --git a/src/client.py b/src/client.py index ac198df..ac44bff 100644 --- a/src/client.py +++ b/src/client.py @@ -2,7 +2,6 @@ import argparse import asyncio import collections import logging -# import signal import os import random import ssl @@ -28,6 +27,8 @@ class Client: self._password = password self._ssl_context = None self._encryption_complete = False + self._file_name = None + self._file_size = None @property def host(self) -> str: @@ -77,6 +78,14 @@ class Client: def encryption_complete(self): return self._encryption_complete + @property + def file_name(self): + return self._file_name + + @property + def file_size(self): + return self._file_size + async def run_sending_client(self, file_path='~/output.txt'): self._file_path = file_path file_name = os.path.basename(os.path.abspath(file_path)) @@ -186,7 +195,11 @@ class Client: logging.info("Server disconnected.") return server_hello = server_hello.decode('utf-8').strip('\n') - if server_hello == 'start!': + if server_hello.startswith('info'): + _, file_name, file_size = server_hello.split('|') + self.set_file_information(file_name=file_name, + file_size=file_size) + elif server_hello == 'start!': break logging.info(f"Server said: {server_hello}") await self.receive(reader=reader) @@ -233,6 +246,12 @@ class Client: else: raise KeyboardInterrupt("Not working yet...") + def set_file_information(self, file_name=None, file_size=None): + if file_name is not None: + self._file_name = file_name + if file_size is not None: + self._file_size = file_size + def get_action(action): """Parse abbreviations for `action`.""" diff --git a/src/server.py b/src/server.py index fee5b3d..aef76f2 100644 --- a/src/server.py +++ b/src/server.py @@ -113,23 +113,27 @@ class Server: """ client_hello = await reader.readline() client_hello = client_hello.decode('utf-8').strip('\n').split('|') - peer_is_sender = client_hello[0] == 's' + if len(client_hello) not in (2, 4,): + await self.refuse_connection(writer=writer, + message="Invalid client_hello!") + return connection_token = client_hello[1] if connection_token not in self.connections: self.connections[connection_token] = dict( sender=False, receiver=False ) - if peer_is_sender: + if client_hello[0] == 's': if self.connections[connection_token]['sender']: - writer.write( - "Invalid token! " - "A sender client is already connected!\n".encode('utf-8') + await self.refuse_connection( + writer=writer, + message="Invalid token! " + "A sender client is already connected!\n" ) - await writer.drain() - writer.close() return self.connections[connection_token]['sender'] = True + self.connections[connection_token]['file_name'] = client_hello[2] + self.connections[connection_token]['file_size'] = client_hello[3] self.buffers[connection_token] = collections.deque() logging.info("Sender is connecting...") index, step = 0, 1 @@ -148,14 +152,13 @@ class Server: await self.run_reader(reader=reader, connection_token=connection_token) logging.info("Incoming transmission ended") - else: + else: # Receiver client connection if self.connections[connection_token]['receiver']: - writer.write( - "Invalid token! " - "A receiver client is already connected!\n".encode('utf-8') + await self.refuse_connection( + writer=writer, + message="Invalid token! " + "A receiver client is already connected!\n" ) - await writer.drain() - writer.close() return self.connections[connection_token]['receiver'] = True logging.info("Receiver is connecting...") @@ -168,7 +171,13 @@ class Server: step += 1 index = 0 await asyncio.sleep(.5) - # Send start signal to client + # Send file information and start signal to client + writer.write( + "info|" + f"{self.connections[connection_token]['file_name']}|" + f"{self.connections[connection_token]['file_size']}" + "\n".encode('utf-8') + ) writer.write("start!\n".encode('utf-8')) await writer.drain() await self.run_writer(writer=writer, @@ -176,7 +185,6 @@ class Server: logging.info("Outgoing transmission ended") del self.buffers[connection_token] del self.connections[connection_token] - return def run(self): loop = asyncio.get_event_loop() @@ -203,7 +211,18 @@ class Server: ) async with self.server: await self.server.serve_forever() - return + + @staticmethod + async def refuse_connection(writer: asyncio.StreamWriter, + message: str = None): + """Send a `message` via writer and close it.""" + if message is None: + message = "Connection refused!\n" + writer.write( + message.encode('utf-8') + ) + await writer.drain() + writer.close() def main():