Allow multiple client connections
This commit is contained in:
parent
c0dd046670
commit
e68ab4282c
@ -4,14 +4,15 @@ import collections
|
||||
import logging
|
||||
# import signal
|
||||
import os
|
||||
import random
|
||||
import ssl
|
||||
import string
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self, host='localhost', port=3001,
|
||||
buffer_chunk_size=10**4, buffer_length_limit=10**4,
|
||||
password=None):
|
||||
self._password = password
|
||||
password=None, token=None):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._stopping = False
|
||||
@ -23,6 +24,8 @@ class Client:
|
||||
self._buffer_length_limit = buffer_length_limit
|
||||
self._file_path = None
|
||||
self._working = False
|
||||
self._token = token
|
||||
self._password = password
|
||||
self._ssl_context = None
|
||||
self._encryption_complete = False
|
||||
|
||||
@ -61,6 +64,10 @@ class Client:
|
||||
def set_ssl_context(self, ssl_context: ssl.SSLContext):
|
||||
self._ssl_context = ssl_context
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
return self._token
|
||||
|
||||
@property
|
||||
def password(self):
|
||||
"""Password for file encryption or decryption."""
|
||||
@ -72,12 +79,31 @@ class Client:
|
||||
|
||||
async def run_sending_client(self, file_path='~/output.txt'):
|
||||
self._file_path = file_path
|
||||
reader, writer = await asyncio.open_connection(host=self.host,
|
||||
port=self.port,
|
||||
ssl=self.ssl_context)
|
||||
writer.write("sender\n".encode('utf-8'))
|
||||
file_name = os.path.basename(os.path.abspath(file_path))
|
||||
file_size = os.path.getsize(os.path.abspath(file_path))
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
ssl=self.ssl_context
|
||||
)
|
||||
except ConnectionRefusedError as exception:
|
||||
logging.error(exception)
|
||||
return
|
||||
writer.write(
|
||||
f"s|{self.token}|{file_name}|{file_size}\n".encode('utf-8')
|
||||
)
|
||||
await writer.drain()
|
||||
await reader.readline() # Wait for server start signal
|
||||
# Wait for server start signal
|
||||
while 1:
|
||||
server_hello = await reader.readline()
|
||||
if not server_hello:
|
||||
logging.info("Server disconnected.")
|
||||
return
|
||||
server_hello = server_hello.decode('utf-8').strip('\n')
|
||||
if server_hello == 'start!':
|
||||
break
|
||||
logging.info(f"Server said: {server_hello}")
|
||||
await self.send(writer=writer)
|
||||
|
||||
async def encrypt_file(self, input_file, output_file):
|
||||
@ -142,12 +168,27 @@ class Client:
|
||||
|
||||
async def run_receiving_client(self, file_path='~/input.txt'):
|
||||
self._file_path = file_path
|
||||
reader, writer = await asyncio.open_connection(host=self.host,
|
||||
port=self.port,
|
||||
ssl=self.ssl_context)
|
||||
writer.write("receiver\n".encode('utf-8'))
|
||||
try:
|
||||
reader, writer = await asyncio.open_connection(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
ssl=self.ssl_context
|
||||
)
|
||||
except ConnectionRefusedError as exception:
|
||||
logging.error(exception)
|
||||
return
|
||||
writer.write(f"r|{self.token}\n".encode('utf-8'))
|
||||
await writer.drain()
|
||||
await reader.readline() # Wait for server start signal
|
||||
# Wait for server start signal
|
||||
while 1:
|
||||
server_hello = await reader.readline()
|
||||
if not server_hello:
|
||||
logging.info("Server disconnected.")
|
||||
return
|
||||
server_hello = server_hello.decode('utf-8').strip('\n')
|
||||
if server_hello == 'start!':
|
||||
break
|
||||
logging.info(f"Server said: {server_hello}")
|
||||
await self.receive(reader=reader)
|
||||
|
||||
async def receive(self, reader: asyncio.StreamReader):
|
||||
@ -258,6 +299,11 @@ if __name__ == '__main__':
|
||||
default=None,
|
||||
required=False,
|
||||
help='Password for file encryption or decryption')
|
||||
cli_parser.add_argument('--token', '--t', '--session_token', type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help='Session token '
|
||||
'(must be the same for both clients)')
|
||||
cli_parser.add_argument('others',
|
||||
metavar='R or S',
|
||||
nargs='*',
|
||||
@ -268,6 +314,7 @@ if __name__ == '__main__':
|
||||
_action = get_action(args['action'])
|
||||
_file_path = args['path']
|
||||
_password = args['password']
|
||||
_token = args['token']
|
||||
|
||||
# If _host and _port are not provided from command-line, try to import them
|
||||
if _host is None:
|
||||
@ -303,6 +350,11 @@ if __name__ == '__main__':
|
||||
from config import password as _password
|
||||
except ImportError:
|
||||
_password = None
|
||||
if _token is None:
|
||||
try:
|
||||
from config import token as _token
|
||||
except ImportError:
|
||||
_token = None
|
||||
|
||||
# If import fails, prompt user for _host or _port
|
||||
while _host is None:
|
||||
@ -328,11 +380,29 @@ if __name__ == '__main__':
|
||||
"Your file will be unencoded unless you provide a password in "
|
||||
"config file."
|
||||
)
|
||||
if _token is None and _action == 'send':
|
||||
# Generate a random [6-10] chars-long alphanumerical token
|
||||
_token = ''.join(
|
||||
random.SystemRandom().choice(
|
||||
string.ascii_uppercase + string.digits
|
||||
)
|
||||
for _ in range(random.SystemRandom().randint(6, 10))
|
||||
)
|
||||
logging.info(
|
||||
"You have not provided a token for this connection.\n"
|
||||
f"A token has been generated for you:\t\t{_token}\n"
|
||||
"Your peer must be informed of this token.\n"
|
||||
"For future connections, you may provide a custom token writing "
|
||||
"it in config file."
|
||||
)
|
||||
while _token is None or not (6 <= len(_token) <= 10):
|
||||
_token = input("Please enter a 6-10 chars token.\t\t\t\t")
|
||||
loop = asyncio.get_event_loop()
|
||||
client = Client(
|
||||
host=_host,
|
||||
port=_port,
|
||||
password=_password
|
||||
password=_password,
|
||||
token=_token
|
||||
)
|
||||
try:
|
||||
from config import certificate
|
||||
|
137
src/server.py
137
src/server.py
@ -10,9 +10,9 @@ class Server:
|
||||
buffer_chunk_size=10**4, buffer_length_limit=10**4):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._stopping = False
|
||||
# Shared queue of bytes
|
||||
self.buffer = collections.deque()
|
||||
self.connections = collections.OrderedDict()
|
||||
# Dict of queues of bytes
|
||||
self.buffers = collections.OrderedDict()
|
||||
# How many bytes per chunk
|
||||
self._buffer_chunk_size = buffer_chunk_size
|
||||
# How many chunks in buffer
|
||||
@ -29,10 +29,6 @@ class Server:
|
||||
def port(self) -> int:
|
||||
return self._port
|
||||
|
||||
@property
|
||||
def stopping(self) -> bool:
|
||||
return self._stopping
|
||||
|
||||
@property
|
||||
def buffer_length_limit(self) -> int:
|
||||
return self._buffer_length_limit
|
||||
@ -53,28 +49,40 @@ class Server:
|
||||
def ssl_context(self) -> ssl.SSLContext:
|
||||
return self._ssl_context
|
||||
|
||||
@property
|
||||
def buffer_is_full(self):
|
||||
return (
|
||||
sum(len(buffer)
|
||||
for buffer in self.buffers.values())
|
||||
>= self.buffer_length_limit
|
||||
)
|
||||
|
||||
def set_ssl_context(self, ssl_context: ssl.SSLContext):
|
||||
self._ssl_context = ssl_context
|
||||
|
||||
async def run_reader(self, reader):
|
||||
while not self.stopping:
|
||||
async def run_reader(self, reader, connection_token):
|
||||
while 1:
|
||||
try:
|
||||
# Stop if buffer is full
|
||||
while len(self.buffer) >= self.buffer_length_limit:
|
||||
# Wait one second if buffer is full
|
||||
while self.buffer_is_full:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
input_data = await reader.read(self.buffer_chunk_size)
|
||||
self.buffer.append(input_data)
|
||||
if connection_token not in self.buffers:
|
||||
break
|
||||
self.buffers[connection_token].append(input_data)
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logging.error(e, exc_info=True)
|
||||
|
||||
async def run_writer(self, writer):
|
||||
async def run_writer(self, writer, connection_token):
|
||||
consecutive_interruptions = 0
|
||||
errors = 0
|
||||
while not self.stopping:
|
||||
while 1:
|
||||
try:
|
||||
try:
|
||||
input_data = self.buffer.popleft()
|
||||
if connection_token not in self.buffers:
|
||||
break
|
||||
input_data = self.buffers[connection_token].popleft()
|
||||
except IndexError:
|
||||
# Slow down if buffer is short
|
||||
consecutive_interruptions += 1
|
||||
@ -89,7 +97,7 @@ class Server:
|
||||
writer.write(input_data)
|
||||
await writer.drain()
|
||||
except Exception as e:
|
||||
logging.error(e)
|
||||
logging.error(e, exc_info=True)
|
||||
errors += 1
|
||||
if errors > 3:
|
||||
break
|
||||
@ -104,25 +112,70 @@ class Server:
|
||||
Decide whether client is sender or receiver and start transmission.
|
||||
"""
|
||||
client_hello = await reader.readline()
|
||||
peer_is_sender = client_hello.decode('utf-8') == 'sender\n'
|
||||
client_hello = client_hello.decode('utf-8').strip('\n').split('|')
|
||||
peer_is_sender = client_hello[0] == 's'
|
||||
connection_token = client_hello[1]
|
||||
if connection_token not in self.connections:
|
||||
self.connections[connection_token] = dict(
|
||||
sender=False,
|
||||
receiver=False
|
||||
)
|
||||
if peer_is_sender:
|
||||
self._working = True
|
||||
if self.connections[connection_token]['sender']:
|
||||
writer.write(
|
||||
"Invalid token! "
|
||||
"A sender client is already connected!\n".encode('utf-8')
|
||||
)
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
return
|
||||
self.connections[connection_token]['sender'] = True
|
||||
self.buffers[connection_token] = collections.deque()
|
||||
logging.info("Sender is connecting...")
|
||||
# Send start signal to client
|
||||
writer.write("Start!\n".encode('utf-8'))
|
||||
await writer.drain()
|
||||
await self.run_reader(reader=reader)
|
||||
logging.info("Incoming transmission ended")
|
||||
else:
|
||||
logging.info("Receiver is connecting...")
|
||||
while len(self.buffer) == 0:
|
||||
index, step = 0, 1
|
||||
while not self.connections[connection_token]['receiver']:
|
||||
index += 1
|
||||
if index >= step:
|
||||
writer.write("Waiting for receiver...\n".encode('utf-8'))
|
||||
await writer.drain()
|
||||
step += 1
|
||||
index = 0
|
||||
await asyncio.sleep(.5)
|
||||
# Send start signal to client
|
||||
writer.write("Start!\n".encode('utf-8'))
|
||||
writer.write("start!\n".encode('utf-8'))
|
||||
await writer.drain()
|
||||
await self.run_writer(writer=writer)
|
||||
logging.info("Incoming transmission starting...")
|
||||
await self.run_reader(reader=reader,
|
||||
connection_token=connection_token)
|
||||
logging.info("Incoming transmission ended")
|
||||
else:
|
||||
if self.connections[connection_token]['receiver']:
|
||||
writer.write(
|
||||
"Invalid token! "
|
||||
"A receiver client is already connected!\n".encode('utf-8')
|
||||
)
|
||||
await writer.drain()
|
||||
writer.close()
|
||||
return
|
||||
self.connections[connection_token]['receiver'] = True
|
||||
logging.info("Receiver is connecting...")
|
||||
index, step = 0, 1
|
||||
while not self.connections[connection_token]['sender']:
|
||||
index += 1
|
||||
if index >= step:
|
||||
writer.write("Waiting for sender...\n".encode('utf-8'))
|
||||
await writer.drain()
|
||||
step += 1
|
||||
index = 0
|
||||
await asyncio.sleep(.5)
|
||||
# Send start signal to client
|
||||
writer.write("start!\n".encode('utf-8'))
|
||||
await writer.drain()
|
||||
await self.run_writer(writer=writer,
|
||||
connection_token=connection_token)
|
||||
logging.info("Outgoing transmission ended")
|
||||
self._working = False
|
||||
del self.buffers[connection_token]
|
||||
del self.connections[connection_token]
|
||||
return
|
||||
|
||||
def run(self):
|
||||
@ -149,23 +202,11 @@ class Server:
|
||||
port=self.port,
|
||||
)
|
||||
async with self.server:
|
||||
try:
|
||||
await self.server.serve_forever()
|
||||
except KeyboardInterrupt:
|
||||
logging.info("Stopping...")
|
||||
self.server.close()
|
||||
await self.server.wait_closed()
|
||||
await self.server.serve_forever()
|
||||
return
|
||||
|
||||
def stop(self, *_):
|
||||
if self.working and not self.stopping:
|
||||
logging.info("Received interruption signal, stopping...")
|
||||
self._stopping = True
|
||||
else:
|
||||
raise KeyboardInterrupt("Not working yet...")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
def main():
|
||||
# noinspection SpellCheckingInspection
|
||||
log_formatter = logging.Formatter(
|
||||
"%(asctime)s [%(module)-15s %(levelname)-8s] %(message)s",
|
||||
@ -221,12 +262,16 @@ if __name__ == '__main__':
|
||||
port=_port,
|
||||
)
|
||||
try:
|
||||
# noinspection PyUnresolvedReferences
|
||||
from config import certificate, key
|
||||
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||
_ssl_context.check_hostname = False
|
||||
_ssl_context.load_cert_chain(certificate, key)
|
||||
server.set_ssl_context(_ssl_context)
|
||||
except ImportError:
|
||||
logging.info("Please consider using SSL.")
|
||||
certificate, key = None, None
|
||||
logging.warning("Please consider using SSL.")
|
||||
server.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user