Refactoring

This commit is contained in:
Davte 2020-04-13 19:45:05 +02:00
parent 3b7aa265ab
commit 3f5384f9e9
3 changed files with 305 additions and 109 deletions

View File

@ -9,16 +9,25 @@ import random
import ssl import ssl
import string import string
import sys import sys
from typing import Union
from . import utilities from . import utilities
class Client: class Client:
def __init__(self, host='localhost', port=3001, def __init__(self, host='localhost', port=5000, ssl_context=None,
buffer_chunk_size=10**4, buffer_length_limit=10**4, action=None,
password=None, token=None): standalone=False,
buffer_chunk_size=10 ** 4,
buffer_length_limit=10 ** 4,
file_path=None,
password=None,
token=None):
self._host = host self._host = host
self._port = port self._port = port
self._ssl_context = ssl_context
self._action = action
self._standalone = standalone
self._stopping = False self._stopping = False
self._reader = None self._reader = None
self._writer = None self._writer = None
@ -28,7 +37,7 @@ class Client:
self._buffer_chunk_size = buffer_chunk_size self._buffer_chunk_size = buffer_chunk_size
# How many chunks in buffer # How many chunks in buffer
self._buffer_length_limit = buffer_length_limit self._buffer_length_limit = buffer_length_limit
self._file_path = None self._file_path = file_path
self._working = False self._working = False
self._token = token self._token = token
self._password = password self._password = password
@ -36,6 +45,7 @@ class Client:
self._encryption_complete = False self._encryption_complete = False
self._file_name = None self._file_name = None
self._file_size = None self._file_size = None
self._file_size_string = None
@property @property
def host(self) -> str: def host(self) -> str:
@ -45,6 +55,21 @@ class Client:
def port(self) -> int: def port(self) -> int:
return self._port return self._port
@property
def action(self) -> str:
"""Client role.
Possible values:
- `send`
- `receive`
"""
return self._action
@property
def standalone(self) -> bool:
"""Tell whether client should run as server as well."""
return self._standalone
@property @property
def stopping(self) -> bool: def stopping(self) -> bool:
return self._stopping return self._stopping
@ -101,52 +126,128 @@ class Client:
def file_size(self): def file_size(self):
return self._file_size return self._file_size
async def run_client(self, file_path, action): @property
self._file_path = file_path def file_size_string(self):
if action == 'send': return self._file_size_string
file_name = os.path.basename(os.path.abspath(file_path))
file_size = os.path.getsize(os.path.abspath(file_path)) async def run_client(self) -> None:
if self.action == 'send':
file_name = os.path.basename(os.path.abspath(self.file_path))
file_size = os.path.getsize(os.path.abspath(self.file_path))
# File size increases after encryption
# "Salted_" (8 bytes) + salt (8 bytes)
# Then, 1-16 bytes are added to make file_size a multiple of 16
# i.e., (32 - file_size mod 16) bytes are added to original size
if self.password:
file_size += 32 - (file_size % 16)
self.set_file_information( self.set_file_information(
file_name=file_name, file_name=file_name,
file_size=file_size file_size=file_size
) )
try: if self.standalone:
reader, writer = await asyncio.open_connection( server = await asyncio.start_server(
ssl=self.ssl_context,
client_connected_cb=self._connect,
host=self.host, host=self.host,
port=self.port, port=self.port,
ssl=self.ssl_context
) )
self._reader = reader async with server:
self._writer = writer logging.info("Running at `{s.host}:{s.port}`".format(s=self))
except ConnectionRefusedError as exception: await server.serve_forever()
logging.error(exception) else:
return try:
writer.write( reader, writer = await asyncio.open_connection(
( host=self.host,
f"s|{self.token}|" port=self.port,
f"{self.file_name}|{self.file_size}\n".encode('utf-8') ssl=self.ssl_context
) if action == 'send' )
else f"r|{self.token}\n".encode('utf-8') except (ConnectionRefusedError, ConnectionResetError) as exception:
) logging.error(f"Connection error: {exception}")
await writer.drain() return
await self.connect(reader=reader, writer=writer)
async def _connect(self, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter):
try:
return await self.connect(reader, writer)
except KeyboardInterrupt:
print()
except Exception as e:
logging.error(e)
async def connect(self,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter):
self._reader = reader
self._writer = writer
async def _write(message: Union[list, str, bytes],
terminate_line=True) -> int:
"""Framework for `asyncio.StreamWriter.write` method.
Create string from list, encode it, send and drain writer.
Return 0 on success, 1 on error.
"""
# Adapt
if type(message) is list:
message = '|'.join(map(str, message))
if type(message) is str:
if terminate_line:
message += '\n'
message = message.encode('utf-8')
if type(message) is not bytes:
return 1
try:
writer.write(message)
await writer.drain()
except ConnectionResetError:
logging.error("Client disconnected.")
except Exception as e:
logging.error(f"Unexpected exception:\n{e}", exc_info=True)
else:
return 0 # On success, return 0
# On exception, return 1
return 1
if self.action == 'send' or not self.standalone:
if await _write(
[self.action[0], self.token,
self.file_name, self.file_size]
):
return
# Wait for server start signal # Wait for server start signal
while 1: while 1:
server_hello = await reader.readline() server_hello = await self.reader.readline()
if not server_hello: if not server_hello:
logging.info("Server disconnected.") logging.error("Server disconnected.")
return return
server_hello = server_hello.decode('utf-8').strip('\n').split('|') server_hello = server_hello.decode('utf-8').strip('\n').split('|')
if action == 'receive' and server_hello[0] == 's': if self.action == 'receive' and server_hello[0] == 's':
self.set_file_information(file_name=server_hello[2], self.set_file_information(file_name=server_hello[2],
file_size=server_hello[3]) file_size=server_hello[3])
elif (
self.standalone
and self.action == 'send'
and server_hello[0] == 'r'
):
# Check token
if server_hello[1] != self.token:
if await _write("Invalid session token!"):
return
return
elif server_hello[0] == 'start!': elif server_hello[0] == 'start!':
break break
else: else:
logging.info(f"Server said: {'|'.join(server_hello)}") logging.info(f"Server said: {'|'.join(server_hello)}")
if action == 'send': if self.standalone:
await self.send(writer=writer) if await _write("start!"):
return
break
if self.action == 'send':
await self.send(writer=self.writer)
else: else:
await self.receive(reader=reader) await self.receive(reader=self.reader)
async def encrypt_file(self, input_file, output_file): async def encrypt_file(self, input_file, output_file):
self._encryption_complete = False self._encryption_complete = False
@ -201,28 +302,27 @@ class Client:
break break
try: try:
writer.write(output_data) writer.write(output_data)
await writer.drain() await asyncio.wait_for(writer.drain(), timeout=3.0)
except ConnectionResetError: except ConnectionResetError:
print() # New line after progress_bar
logging.error('Server closed the connection.') logging.error('Server closed the connection.')
self.stop() self.stop()
break break
bytes_sent += self.buffer_chunk_size except asyncio.exceptions.TimeoutError:
print() # New line after progress_bar
logging.error('Server closed the connection.')
self.stop()
break
bytes_sent += len(output_data)
new_progress = min( new_progress = min(
int(bytes_sent / self.file_size * 100), int(bytes_sent / self.file_size * 100),
100 100
) )
progress_showed = (new_progress // 10) * 10 self.print_progress_bar(
sys.stdout.write( progress=new_progress,
f"\t\t\tSending `{self.file_name}`: " bytes_=bytes_sent,
f"{'#' * (progress_showed // 10)}"
f"{'.' * ((100 - progress_showed) // 10)}\t"
f"{new_progress}% completed "
f"({min(bytes_sent, self.file_size) // 1000} "
f"of {self.file_size // 1000} KB)\r"
) )
sys.stdout.flush() print() # New line after progress_bar
sys.stdout.write('\n')
sys.stdout.flush()
writer.close() writer.close()
return return
@ -242,26 +342,19 @@ class Client:
bytes_received = 0 bytes_received = 0
while not self.stopping: while not self.stopping:
input_data = await reader.read(self.buffer_chunk_size) input_data = await reader.read(self.buffer_chunk_size)
bytes_received += self.buffer_chunk_size bytes_received += len(input_data)
new_progress = min( new_progress = min(
int(bytes_received / self.file_size * 100), int(bytes_received / self.file_size * 100),
100 100
) )
progress_showed = (new_progress // 10) * 10 self.print_progress_bar(
sys.stdout.write( progress=new_progress,
f"\t\t\tReceiving `{self.file_name}`: " bytes_=bytes_received
f"{'#' * (progress_showed // 10)}"
f"{'.' * ((100 - progress_showed) // 10)}\t"
f"{new_progress}% completed "
f"({min(bytes_received, self.file_size) // 1000} "
f"of {self.file_size // 1000} KB)\r"
) )
sys.stdout.flush()
if not input_data: if not input_data:
break break
file_to_receive.write(input_data) file_to_receive.write(input_data)
sys.stdout.write('\n') print() # New line after sys.stdout.write
sys.stdout.flush()
logging.info("File received.") logging.info("File received.")
if self.password: if self.password:
logging.info("Decrypting file...") logging.info("Decrypting file...")
@ -289,7 +382,8 @@ class Client:
if self.working: if self.working:
logging.info("Received interruption signal, stopping...") logging.info("Received interruption signal, stopping...")
self._stopping = True self._stopping = True
self.writer.close() if self.writer:
self.writer.close()
else: else:
raise KeyboardInterrupt("Not working yet...") raise KeyboardInterrupt("Not working yet...")
@ -298,24 +392,62 @@ class Client:
self._file_name = file_name self._file_name = file_name
if file_size is not None: if file_size is not None:
self._file_size = int(file_size) self._file_size = int(file_size)
self._file_size_string = utilities.get_file_size_representation(
self.file_size
)
def run(self, file_path, action): def run(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
try: try:
loop.run_until_complete( loop.run_until_complete(
self.run_client(file_path=file_path, self.run_client()
action=action)
) )
except KeyboardInterrupt: except KeyboardInterrupt:
print()
logging.error("Interrupted") logging.error("Interrupted")
for task in asyncio.all_tasks(loop): for task in asyncio.all_tasks(loop):
task.cancel() task.cancel()
self.writer.close() if self.writer:
self.writer.close()
loop.run_until_complete( loop.run_until_complete(
self.writer.wait_closed() self.wait_closed()
) )
loop.close() loop.close()
def print_progress_bar(self, progress: int, bytes_: int):
"""Print client progress bar.
`progress` % = `bytes_string` transferred
out of `self.file_size_string`.
"""
action = {
'send': "Sending",
'receive': "Receiving"
}[self.action]
bytes_string = utilities.get_file_size_representation(
bytes_
)
utilities.print_progress_bar(
prefix=f"\t\t\t{action} `{self.file_name}`: ",
done_symbol='#',
pending_symbol='.',
progress=progress,
scale=5,
suffix=(
" completed "
f"({bytes_string} "
f"of {self.file_size_string})"
)
)
@staticmethod
async def wait_closed() -> None:
"""Give time to cancelled tasks to end properly.
Sleep .1 second and return.
"""
await asyncio.sleep(.1)
def get_action(action): def get_action(action):
"""Parse abbreviations for `action`.""" """Parse abbreviations for `action`."""
@ -380,6 +512,11 @@ def main():
default=None, default=None,
required=False, required=False,
help='server SSL certificate') help='server SSL certificate')
cli_parser.add_argument('--key', type=str,
default=None,
required=False,
help='server SSL key (required only for '
'SSL-secured standalone client)')
cli_parser.add_argument('--action', type=str, cli_parser.add_argument('--action', type=str,
default=None, default=None,
required=False, required=False,
@ -397,6 +534,9 @@ def main():
required=False, required=False,
help='Session token ' help='Session token '
'(must be the same for both clients)') '(must be the same for both clients)')
cli_parser.add_argument('--standalone',
action='store_true',
help='Run both as client and server')
cli_parser.add_argument('others', cli_parser.add_argument('others',
metavar='R or S', metavar='R or S',
nargs='*', nargs='*',
@ -405,10 +545,12 @@ def main():
host = args['host'] host = args['host']
port = args['port'] port = args['port']
certificate = args['certificate'] certificate = args['certificate']
key = args['key']
action = get_action(args['action']) action = get_action(args['action'])
file_path = args['path'] file_path = args['path']
password = args['password'] password = args['password']
token = args['token'] token = args['token']
standalone = args['standalone']
# If host and port are not provided from command-line, try to import them # If host and port are not provided from command-line, try to import them
sys.path.append(os.path.abspath('.')) sys.path.append(os.path.abspath('.'))
@ -455,6 +597,11 @@ def main():
from config import certificate from config import certificate
except ImportError: except ImportError:
certificate = None certificate = None
if key is None or not os.path.isfile(key):
try:
from config import key
except ImportError:
key = None
# If import fails, prompt user for host or port # If import fails, prompt user for host or port
new_settings = {} # After getting these settings, offer to store them new_settings = {} # After getting these settings, offer to store them
@ -540,17 +687,18 @@ def main():
logging.info("Configuration values stored.") logging.info("Configuration values stored.")
else: else:
logging.info("Proceeding without storing values...") logging.info("Proceeding without storing values...")
client = Client( ssl_context = None
host=host,
port=port,
password=password,
token=token
)
if certificate is not None: if certificate is not None:
_ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) if key is None: # Server-dependent client
_ssl_context.check_hostname = False ssl_context = ssl.create_default_context(
_ssl_context.load_verify_locations(certificate) purpose=ssl.Purpose.SERVER_AUTH
client.set_ssl_context(_ssl_context) )
ssl_context.load_verify_locations(certificate)
else: # Standalone client
ssl_context = ssl.create_default_context(
purpose=ssl.Purpose.CLIENT_AUTH
)
ssl_context.load_cert_chain(certificate, key)
else: else:
logging.warning( logging.warning(
"Please consider using SSL. To do so, add in `config.py` or " "Please consider using SSL. To do so, add in `config.py` or "
@ -559,10 +707,17 @@ def main():
"certificate = 'path/to/certificate.crt'" "certificate = 'path/to/certificate.crt'"
) )
logging.info("Starting client...") logging.info("Starting client...")
client.run( client = Client(
host=host,
port=port,
ssl_context=ssl_context,
action=action,
standalone=standalone,
file_path=file_path, file_path=file_path,
action=action password=password,
token=token
) )
client.run()
logging.info("Stopped client") logging.info("Stopped client")

View File

@ -13,10 +13,11 @@ from typing import Union
class Server: class Server:
def __init__(self, host='localhost', port=5000, def __init__(self, host='localhost', port=5000, ssl_context=None,
buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4): buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
self._host = host self._host = host
self._port = port self._port = port
self._ssl_context = ssl_context
self.connections = collections.OrderedDict() self.connections = collections.OrderedDict()
# Dict of queues of bytes # Dict of queues of bytes
self.buffers = collections.OrderedDict() self.buffers = collections.OrderedDict()
@ -87,27 +88,24 @@ class Server:
async def run_writer(self, writer, connection_token): async def run_writer(self, writer, connection_token):
consecutive_interruptions = 0 consecutive_interruptions = 0
errors = 0 errors = 0
while 1: while connection_token in self.buffers:
try: try:
try: input_data = self.buffers[connection_token].popleft()
if connection_token not in self.buffers: except IndexError:
break # Slow down if buffer is empty; after 1.5 s of silence, break
input_data = self.buffers[connection_token].popleft() consecutive_interruptions += 1
except IndexError: if consecutive_interruptions > 3:
# Slow down if buffer is short
consecutive_interruptions += 1
if consecutive_interruptions > 3:
break
await asyncio.sleep(.5)
continue
else:
consecutive_interruptions = 0
if not input_data:
break break
await asyncio.sleep(.5)
continue
else:
consecutive_interruptions = 0
if not input_data:
break
try:
writer.write(input_data) writer.write(input_data)
await writer.drain() await writer.drain()
except ConnectionResetError as e: except ConnectionResetError as e:
logging.error("Here")
logging.error(e) logging.error(e)
break break
except Exception as e: except Exception as e:
@ -127,7 +125,7 @@ class Server:
""" """
client_hello = await reader.readline() client_hello = await reader.readline()
client_hello = client_hello.decode('utf-8').strip('\n').split('|') client_hello = client_hello.decode('utf-8').strip('\n').split('|')
if len(client_hello) not in (2, 4,): if len(client_hello) != 4:
await self.refuse_connection(writer=writer, await self.refuse_connection(writer=writer,
message="Invalid client_hello!") message="Invalid client_hello!")
return return
@ -142,7 +140,7 @@ class Server:
terminate_line=True) -> int: terminate_line=True) -> int:
# Adapt # Adapt
if type(message) is list: if type(message) is list:
message = '|'.join(message) message = '|'.join(map(str, message))
if type(message) is str: if type(message) is str:
if terminate_line: if terminate_line:
message += '\n' message += '\n'
@ -211,12 +209,13 @@ class Server:
index = 0 index = 0
await asyncio.sleep(.5) await asyncio.sleep(.5)
# Send file information and start signal to client # Send file information and start signal to client
writer.write( if await _write(
"s|hidden_token|" ['s',
f"{self.connections[connection_token]['file_name']}|" 'hidden_token',
f"{self.connections[connection_token]['file_size']}" self.connections[connection_token]['file_name'],
"\n".encode('utf-8') self.connections[connection_token]['file_size']]
) ):
return
if await _write("start!"): if await _write("start!"):
return return
await self.run_writer(writer=writer, await self.run_writer(writer=writer,
@ -238,6 +237,7 @@ class Server:
try: try:
loop.run_until_complete(self.run_server()) loop.run_until_complete(self.run_server())
except KeyboardInterrupt: except KeyboardInterrupt:
print()
logging.info("Stopping...") logging.info("Stopping...")
# Cancel connection tasks (they should be done but are pending) # Cancel connection tasks (they should be done but are pending)
for task in asyncio.all_tasks(loop): for task in asyncio.all_tasks(loop):
@ -336,10 +336,6 @@ def main():
logging.info("Invalid port. Enter a valid port number!") logging.info("Invalid port. Enter a valid port number!")
port = None port = None
server = Server(
host=host,
port=port,
)
try: try:
if certificate is None or not os.path.isfile(certificate): if certificate is None or not os.path.isfile(certificate):
from config import certificate from config import certificate
@ -350,12 +346,12 @@ def main():
if not os.path.isfile(key): if not os.path.isfile(key):
key = None key = None
except ImportError: except ImportError:
pass certificate = None
key = None
ssl_context = None
if certificate and key: if certificate and key:
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
_ssl_context.check_hostname = False ssl_context.load_cert_chain(certificate, key)
_ssl_context.load_cert_chain(certificate, key)
server.set_ssl_context(_ssl_context)
else: else:
logging.warning( logging.warning(
"Please consider using SSL. To do so, add in `config.py` or " "Please consider using SSL. To do so, add in `config.py` or "
@ -364,6 +360,11 @@ def main():
"key = 'path/to/secret.key'\n" "key = 'path/to/secret.key'\n"
"certificate = 'path/to/certificate.crt'" "certificate = 'path/to/certificate.crt'"
) )
server = Server(
host=host,
port=port,
ssl_context=ssl_context
)
server.run() server.run()

View File

@ -2,11 +2,51 @@
import logging import logging
import signal import signal
import sys
units_of_measurements = {
1: 'bytes',
1000: 'KB',
1000 * 1000: 'MB',
1000 * 1000 * 1000: 'GB',
1000 * 1000 * 1000 * 1000: 'TB',
}
def get_file_size_representation(file_size):
scale, unit = get_scale_and_unit(file_size=file_size)
if scale < 10:
return f"{file_size} {unit}"
return f"{(file_size // (scale / 100)) / 100:.2f} {unit}"
def get_scale_and_unit(file_size):
scale, unit = min(units_of_measurements.items())
for scale, unit in sorted(units_of_measurements.items(), reverse=True):
if file_size > scale:
break
return scale, unit
def print_progress_bar(prefix='',
suffix='',
done_symbol="#",
pending_symbol=".",
progress=0,
scale=10):
progress_showed = (progress // scale) * scale
sys.stdout.write(
f"{prefix}"
f"{done_symbol * (progress_showed // scale)}"
f"{pending_symbol * ((100 - progress_showed) // scale)}\t"
f"{progress}%"
f"{suffix} \r"
)
sys.stdout.flush()
def timed_input(message: str = None, def timed_input(message: str = None,
timeout: int = 5): timeout: int = 5):
class TimeoutExpired(Exception): class TimeoutExpired(Exception):
pass pass