Allow multiple client connections

This commit is contained in:
Davte 2020-04-11 19:59:09 +02:00
parent c0dd046670
commit e68ab4282c
2 changed files with 174 additions and 59 deletions

View File

@ -4,14 +4,15 @@ import collections
import logging
# import signal
import os
import random
import ssl
import string
class Client:
def __init__(self, host='localhost', port=3001,
buffer_chunk_size=10**4, buffer_length_limit=10**4,
password=None):
self._password = password
password=None, token=None):
self._host = host
self._port = port
self._stopping = False
@ -23,6 +24,8 @@ class Client:
self._buffer_length_limit = buffer_length_limit
self._file_path = None
self._working = False
self._token = token
self._password = password
self._ssl_context = None
self._encryption_complete = False
@ -61,6 +64,10 @@ class Client:
def set_ssl_context(self, ssl_context: ssl.SSLContext):
self._ssl_context = ssl_context
@property
def token(self):
return self._token
@property
def password(self):
"""Password for file encryption or decryption."""
@ -72,12 +79,31 @@ class Client:
async def run_sending_client(self, file_path='~/output.txt'):
self._file_path = file_path
reader, writer = await asyncio.open_connection(host=self.host,
port=self.port,
ssl=self.ssl_context)
writer.write("sender\n".encode('utf-8'))
file_name = os.path.basename(os.path.abspath(file_path))
file_size = os.path.getsize(os.path.abspath(file_path))
try:
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
ssl=self.ssl_context
)
except ConnectionRefusedError as exception:
logging.error(exception)
return
writer.write(
f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8')
)
await writer.drain()
await reader.readline() # Wait for server start signal
# Wait for server start signal
while 1:
server_hello = await reader.readline()
if not server_hello:
logging.info("Server disconnected.")
return
server_hello = server_hello.decode('utf-8').strip('\n')
if server_hello == 'start!':
break
logging.info(f"Server said: {server_hello}")
await self.send(writer=writer)
async def encrypt_file(self, input_file, output_file):
@ -142,12 +168,27 @@ class Client:
async def run_receiving_client(self, file_path='~/input.txt'):
self._file_path = file_path
reader, writer = await asyncio.open_connection(host=self.host,
port=self.port,
ssl=self.ssl_context)
writer.write("receiver\n".encode('utf-8'))
try:
reader, writer = await asyncio.open_connection(
host=self.host,
port=self.port,
ssl=self.ssl_context
)
except ConnectionRefusedError as exception:
logging.error(exception)
return
writer.write(f"r|{self.token}\n".encode('utf-8'))
await writer.drain()
await reader.readline() # Wait for server start signal
# Wait for server start signal
while 1:
server_hello = await reader.readline()
if not server_hello:
logging.info("Server disconnected.")
return
server_hello = server_hello.decode('utf-8').strip('\n')
if server_hello == 'start!':
break
logging.info(f"Server said: {server_hello}")
await self.receive(reader=reader)
async def receive(self, reader: asyncio.StreamReader):
@ -258,6 +299,11 @@ if __name__ == '__main__':
default=None,
required=False,
help='Password for file encryption or decryption')
cli_parser.add_argument('--token', '--t', '--session_token', type=str,
default=None,
required=False,
help='Session token '
'(must be the same for both clients)')
cli_parser.add_argument('others',
metavar='R or S',
nargs='*',
@ -268,6 +314,7 @@ if __name__ == '__main__':
_action = get_action(args['action'])
_file_path = args['path']
_password = args['password']
_token = args['token']
# If _host and _port are not provided from command-line, try to import them
if _host is None:
@ -303,6 +350,11 @@ if __name__ == '__main__':
from config import password as _password
except ImportError:
_password = None
if _token is None:
try:
from config import token as _token
except ImportError:
_token = None
# If import fails, prompt user for _host or _port
while _host is None:
@ -328,11 +380,29 @@ if __name__ == '__main__':
"Your file will be unencoded unless you provide a password in "
"config file."
)
if _token is None and _action == 'send':
# Generate a random [6-10] chars-long alphanumerical token
_token = ''.join(
random.SystemRandom().choice(
string.ascii_uppercase + string.digits
)
for _ in range(random.SystemRandom().randint(6, 10))
)
logging.info(
"You have not provided a token for this connection.\n"
f"A token has been generated for you:\t\t{_token}\n"
"Your peer must be informed of this token.\n"
"For future connections, you may provide a custom token writing "
"it in config file."
)
while _token is None or not (6 <= len(_token) <= 10):
_token = input("Please enter a 6-10 chars token.\t\t\t\t")
loop = asyncio.get_event_loop()
client = Client(
host=_host,
port=_port,
password=_password
password=_password,
token=_token
)
try:
from config import certificate

View File

@ -10,9 +10,9 @@ class Server:
buffer_chunk_size=10**4, buffer_length_limit=10**4):
self._host = host
self._port = port
self._stopping = False
# Shared queue of bytes
self.buffer = collections.deque()
self.connections = collections.OrderedDict()
# Dict of queues of bytes
self.buffers = collections.OrderedDict()
# How many bytes per chunk
self._buffer_chunk_size = buffer_chunk_size
# How many chunks in buffer
@ -29,10 +29,6 @@ class Server:
def port(self) -> int:
return self._port
@property
def stopping(self) -> bool:
return self._stopping
@property
def buffer_length_limit(self) -> int:
return self._buffer_length_limit
@ -53,28 +49,40 @@ class Server:
def ssl_context(self) -> ssl.SSLContext:
return self._ssl_context
@property
def buffer_is_full(self):
return (
sum(len(buffer)
for buffer in self.buffers.values())
>= self.buffer_length_limit
)
def set_ssl_context(self, ssl_context: ssl.SSLContext):
self._ssl_context = ssl_context
async def run_reader(self, reader):
while not self.stopping:
async def run_reader(self, reader, connection_token):
while 1:
try:
# Stop if buffer is full
while len(self.buffer) >= self.buffer_length_limit:
# Wait one second if buffer is full
while self.buffer_is_full:
await asyncio.sleep(1)
continue
input_data = await reader.read(self.buffer_chunk_size)
self.buffer.append(input_data)
if connection_token not in self.buffers:
break
self.buffers[connection_token].append(input_data)
except Exception as e:
logging.error(e)
logging.error(e, exc_info=True)
async def run_writer(self, writer):
async def run_writer(self, writer, connection_token):
consecutive_interruptions = 0
errors = 0
while not self.stopping:
while 1:
try:
try:
input_data = self.buffer.popleft()
if connection_token not in self.buffers:
break
input_data = self.buffers[connection_token].popleft()
except IndexError:
# Slow down if buffer is short
consecutive_interruptions += 1
@ -89,7 +97,7 @@ class Server:
writer.write(input_data)
await writer.drain()
except Exception as e:
logging.error(e)
logging.error(e, exc_info=True)
errors += 1
if errors > 3:
break
@ -104,25 +112,70 @@ class Server:
Decide whether client is sender or receiver and start transmission.
"""
client_hello = await reader.readline()
peer_is_sender = client_hello.decode('utf-8') == 'sender\n'
client_hello = client_hello.decode('utf-8').strip('\n').split('|')
peer_is_sender = client_hello[0] == 's'
connection_token = client_hello[1]
if connection_token not in self.connections:
self.connections[connection_token] = dict(
sender=False,
receiver=False
)
if peer_is_sender:
self._working = True
if self.connections[connection_token]['sender']:
writer.write(
"Invalid token! "
"A sender client is already connected!\n".encode('utf-8')
)
await writer.drain()
writer.close()
return
self.connections[connection_token]['sender'] = True
self.buffers[connection_token] = collections.deque()
logging.info("Sender is connecting...")
# Send start signal to client
writer.write("Start!\n".encode('utf-8'))
await writer.drain()
await self.run_reader(reader=reader)
logging.info("Incoming transmission ended")
else:
logging.info("Receiver is connecting...")
while len(self.buffer) == 0:
index, step = 0, 1
while not self.connections[connection_token]['receiver']:
index += 1
if index >= step:
writer.write("Waiting for receiver...\n".encode('utf-8'))
await writer.drain()
step += 1
index = 0
await asyncio.sleep(.5)
# Send start signal to client
writer.write("Start!\n".encode('utf-8'))
writer.write("start!\n".encode('utf-8'))
await writer.drain()
await self.run_writer(writer=writer)
logging.info("Incoming transmission starting...")
await self.run_reader(reader=reader,
connection_token=connection_token)
logging.info("Incoming transmission ended")
else:
if self.connections[connection_token]['receiver']:
writer.write(
"Invalid token! "
"A receiver client is already connected!\n".encode('utf-8')
)
await writer.drain()
writer.close()
return
self.connections[connection_token]['receiver'] = True
logging.info("Receiver is connecting...")
index, step = 0, 1
while not self.connections[connection_token]['sender']:
index += 1
if index >= step:
writer.write("Waiting for sender...\n".encode('utf-8'))
await writer.drain()
step += 1
index = 0
await asyncio.sleep(.5)
# Send start signal to client
writer.write("start!\n".encode('utf-8'))
await writer.drain()
await self.run_writer(writer=writer,
connection_token=connection_token)
logging.info("Outgoing transmission ended")
self._working = False
del self.buffers[connection_token]
del self.connections[connection_token]
return
def run(self):
@ -149,23 +202,11 @@ class Server:
port=self.port,
)
async with self.server:
try:
await self.server.serve_forever()
except KeyboardInterrupt:
logging.info("Stopping...")
self.server.close()
await self.server.wait_closed()
await self.server.serve_forever()
return
def stop(self, *_):
if self.working and not self.stopping:
logging.info("Received interruption signal, stopping...")
self._stopping = True
else:
raise KeyboardInterrupt("Not working yet...")
if __name__ == '__main__':
def main():
# noinspection SpellCheckingInspection
log_formatter = logging.Formatter(
"%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
@ -221,12 +262,16 @@ if __name__ == '__main__':
port=_port,
)
try:
# noinspection PyUnresolvedReferences
from config import certificate, key
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
_ssl_context.check_hostname = False
_ssl_context.load_cert_chain(certificate, key)
server.set_ssl_context(_ssl_context)
except ImportError:
logging.info("Please consider using SSL.")
certificate, key = None, None
logging.warning("Please consider using SSL.")
server.run()
if __name__ == '__main__':
main()