Refactoring
This commit is contained in:
parent
3b7aa265ab
commit
3f5384f9e9
@ -9,16 +9,25 @@ import random
|
|||||||
import ssl
|
import ssl
|
||||||
import string
|
import string
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from . import utilities
|
from . import utilities
|
||||||
|
|
||||||
|
|
||||||
class Client:
|
class Client:
|
||||||
def __init__(self, host='localhost', port=3001,
|
def __init__(self, host='localhost', port=5000, ssl_context=None,
|
||||||
buffer_chunk_size=10**4, buffer_length_limit=10**4,
|
action=None,
|
||||||
password=None, token=None):
|
standalone=False,
|
||||||
|
buffer_chunk_size=10 ** 4,
|
||||||
|
buffer_length_limit=10 ** 4,
|
||||||
|
file_path=None,
|
||||||
|
password=None,
|
||||||
|
token=None):
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
|
self._ssl_context = ssl_context
|
||||||
|
self._action = action
|
||||||
|
self._standalone = standalone
|
||||||
self._stopping = False
|
self._stopping = False
|
||||||
self._reader = None
|
self._reader = None
|
||||||
self._writer = None
|
self._writer = None
|
||||||
@ -28,7 +37,7 @@ class Client:
|
|||||||
self._buffer_chunk_size = buffer_chunk_size
|
self._buffer_chunk_size = buffer_chunk_size
|
||||||
# How many chunks in buffer
|
# How many chunks in buffer
|
||||||
self._buffer_length_limit = buffer_length_limit
|
self._buffer_length_limit = buffer_length_limit
|
||||||
self._file_path = None
|
self._file_path = file_path
|
||||||
self._working = False
|
self._working = False
|
||||||
self._token = token
|
self._token = token
|
||||||
self._password = password
|
self._password = password
|
||||||
@ -36,6 +45,7 @@ class Client:
|
|||||||
self._encryption_complete = False
|
self._encryption_complete = False
|
||||||
self._file_name = None
|
self._file_name = None
|
||||||
self._file_size = None
|
self._file_size = None
|
||||||
|
self._file_size_string = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def host(self) -> str:
|
def host(self) -> str:
|
||||||
@ -45,6 +55,21 @@ class Client:
|
|||||||
def port(self) -> int:
|
def port(self) -> int:
|
||||||
return self._port
|
return self._port
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action(self) -> str:
|
||||||
|
"""Client role.
|
||||||
|
|
||||||
|
Possible values:
|
||||||
|
- `send`
|
||||||
|
- `receive`
|
||||||
|
"""
|
||||||
|
return self._action
|
||||||
|
|
||||||
|
@property
|
||||||
|
def standalone(self) -> bool:
|
||||||
|
"""Tell whether client should run as server as well."""
|
||||||
|
return self._standalone
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stopping(self) -> bool:
|
def stopping(self) -> bool:
|
||||||
return self._stopping
|
return self._stopping
|
||||||
@ -101,52 +126,128 @@ class Client:
|
|||||||
def file_size(self):
|
def file_size(self):
|
||||||
return self._file_size
|
return self._file_size
|
||||||
|
|
||||||
async def run_client(self, file_path, action):
|
@property
|
||||||
self._file_path = file_path
|
def file_size_string(self):
|
||||||
if action == 'send':
|
return self._file_size_string
|
||||||
file_name = os.path.basename(os.path.abspath(file_path))
|
|
||||||
file_size = os.path.getsize(os.path.abspath(file_path))
|
async def run_client(self) -> None:
|
||||||
|
if self.action == 'send':
|
||||||
|
file_name = os.path.basename(os.path.abspath(self.file_path))
|
||||||
|
file_size = os.path.getsize(os.path.abspath(self.file_path))
|
||||||
|
# File size increases after encryption
|
||||||
|
# "Salted_" (8 bytes) + salt (8 bytes)
|
||||||
|
# Then, 1-16 bytes are added to make file_size a multiple of 16
|
||||||
|
# i.e., (32 - file_size mod 16) bytes are added to original size
|
||||||
|
if self.password:
|
||||||
|
file_size += 32 - (file_size % 16)
|
||||||
self.set_file_information(
|
self.set_file_information(
|
||||||
file_name=file_name,
|
file_name=file_name,
|
||||||
file_size=file_size
|
file_size=file_size
|
||||||
)
|
)
|
||||||
try:
|
if self.standalone:
|
||||||
reader, writer = await asyncio.open_connection(
|
server = await asyncio.start_server(
|
||||||
|
ssl=self.ssl_context,
|
||||||
|
client_connected_cb=self._connect,
|
||||||
host=self.host,
|
host=self.host,
|
||||||
port=self.port,
|
port=self.port,
|
||||||
ssl=self.ssl_context
|
|
||||||
)
|
)
|
||||||
self._reader = reader
|
async with server:
|
||||||
self._writer = writer
|
logging.info("Running at `{s.host}:{s.port}`".format(s=self))
|
||||||
except ConnectionRefusedError as exception:
|
await server.serve_forever()
|
||||||
logging.error(exception)
|
else:
|
||||||
return
|
try:
|
||||||
writer.write(
|
reader, writer = await asyncio.open_connection(
|
||||||
(
|
host=self.host,
|
||||||
f"s|{self.token}|"
|
port=self.port,
|
||||||
f"{self.file_name}|{self.file_size}\n".encode('utf-8')
|
ssl=self.ssl_context
|
||||||
) if action == 'send'
|
)
|
||||||
else f"r|{self.token}\n".encode('utf-8')
|
except (ConnectionRefusedError, ConnectionResetError) as exception:
|
||||||
)
|
logging.error(f"Connection error: {exception}")
|
||||||
await writer.drain()
|
return
|
||||||
|
await self.connect(reader=reader, writer=writer)
|
||||||
|
|
||||||
|
async def _connect(self, reader: asyncio.StreamReader,
|
||||||
|
writer: asyncio.StreamWriter):
|
||||||
|
try:
|
||||||
|
return await self.connect(reader, writer)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print()
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(e)
|
||||||
|
|
||||||
|
async def connect(self,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
writer: asyncio.StreamWriter):
|
||||||
|
self._reader = reader
|
||||||
|
self._writer = writer
|
||||||
|
|
||||||
|
async def _write(message: Union[list, str, bytes],
|
||||||
|
terminate_line=True) -> int:
|
||||||
|
"""Framework for `asyncio.StreamWriter.write` method.
|
||||||
|
|
||||||
|
Create string from list, encode it, send and drain writer.
|
||||||
|
Return 0 on success, 1 on error.
|
||||||
|
"""
|
||||||
|
# Adapt
|
||||||
|
if type(message) is list:
|
||||||
|
message = '|'.join(map(str, message))
|
||||||
|
if type(message) is str:
|
||||||
|
if terminate_line:
|
||||||
|
message += '\n'
|
||||||
|
message = message.encode('utf-8')
|
||||||
|
if type(message) is not bytes:
|
||||||
|
return 1
|
||||||
|
try:
|
||||||
|
writer.write(message)
|
||||||
|
await writer.drain()
|
||||||
|
except ConnectionResetError:
|
||||||
|
logging.error("Client disconnected.")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Unexpected exception:\n{e}", exc_info=True)
|
||||||
|
else:
|
||||||
|
return 0 # On success, return 0
|
||||||
|
# On exception, return 1
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if self.action == 'send' or not self.standalone:
|
||||||
|
if await _write(
|
||||||
|
[self.action[0], self.token,
|
||||||
|
self.file_name, self.file_size]
|
||||||
|
):
|
||||||
|
return
|
||||||
# Wait for server start signal
|
# Wait for server start signal
|
||||||
while 1:
|
while 1:
|
||||||
server_hello = await reader.readline()
|
server_hello = await self.reader.readline()
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
logging.info("Server disconnected.")
|
logging.error("Server disconnected.")
|
||||||
return
|
return
|
||||||
server_hello = server_hello.decode('utf-8').strip('\n').split('|')
|
server_hello = server_hello.decode('utf-8').strip('\n').split('|')
|
||||||
if action == 'receive' and server_hello[0] == 's':
|
if self.action == 'receive' and server_hello[0] == 's':
|
||||||
|
|
||||||
self.set_file_information(file_name=server_hello[2],
|
self.set_file_information(file_name=server_hello[2],
|
||||||
file_size=server_hello[3])
|
file_size=server_hello[3])
|
||||||
|
elif (
|
||||||
|
self.standalone
|
||||||
|
and self.action == 'send'
|
||||||
|
and server_hello[0] == 'r'
|
||||||
|
):
|
||||||
|
# Check token
|
||||||
|
if server_hello[1] != self.token:
|
||||||
|
if await _write("Invalid session token!"):
|
||||||
|
return
|
||||||
|
return
|
||||||
elif server_hello[0] == 'start!':
|
elif server_hello[0] == 'start!':
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logging.info(f"Server said: {'|'.join(server_hello)}")
|
logging.info(f"Server said: {'|'.join(server_hello)}")
|
||||||
if action == 'send':
|
if self.standalone:
|
||||||
await self.send(writer=writer)
|
if await _write("start!"):
|
||||||
|
return
|
||||||
|
break
|
||||||
|
if self.action == 'send':
|
||||||
|
await self.send(writer=self.writer)
|
||||||
else:
|
else:
|
||||||
await self.receive(reader=reader)
|
await self.receive(reader=self.reader)
|
||||||
|
|
||||||
async def encrypt_file(self, input_file, output_file):
|
async def encrypt_file(self, input_file, output_file):
|
||||||
self._encryption_complete = False
|
self._encryption_complete = False
|
||||||
@ -201,28 +302,27 @@ class Client:
|
|||||||
break
|
break
|
||||||
try:
|
try:
|
||||||
writer.write(output_data)
|
writer.write(output_data)
|
||||||
await writer.drain()
|
await asyncio.wait_for(writer.drain(), timeout=3.0)
|
||||||
except ConnectionResetError:
|
except ConnectionResetError:
|
||||||
|
print() # New line after progress_bar
|
||||||
logging.error('Server closed the connection.')
|
logging.error('Server closed the connection.')
|
||||||
self.stop()
|
self.stop()
|
||||||
break
|
break
|
||||||
bytes_sent += self.buffer_chunk_size
|
except asyncio.exceptions.TimeoutError:
|
||||||
|
print() # New line after progress_bar
|
||||||
|
logging.error('Server closed the connection.')
|
||||||
|
self.stop()
|
||||||
|
break
|
||||||
|
bytes_sent += len(output_data)
|
||||||
new_progress = min(
|
new_progress = min(
|
||||||
int(bytes_sent / self.file_size * 100),
|
int(bytes_sent / self.file_size * 100),
|
||||||
100
|
100
|
||||||
)
|
)
|
||||||
progress_showed = (new_progress // 10) * 10
|
self.print_progress_bar(
|
||||||
sys.stdout.write(
|
progress=new_progress,
|
||||||
f"\t\t\tSending `{self.file_name}`: "
|
bytes_=bytes_sent,
|
||||||
f"{'#' * (progress_showed // 10)}"
|
|
||||||
f"{'.' * ((100 - progress_showed) // 10)}\t"
|
|
||||||
f"{new_progress}% completed "
|
|
||||||
f"({min(bytes_sent, self.file_size) // 1000} "
|
|
||||||
f"of {self.file_size // 1000} KB)\r"
|
|
||||||
)
|
)
|
||||||
sys.stdout.flush()
|
print() # New line after progress_bar
|
||||||
sys.stdout.write('\n')
|
|
||||||
sys.stdout.flush()
|
|
||||||
writer.close()
|
writer.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -242,26 +342,19 @@ class Client:
|
|||||||
bytes_received = 0
|
bytes_received = 0
|
||||||
while not self.stopping:
|
while not self.stopping:
|
||||||
input_data = await reader.read(self.buffer_chunk_size)
|
input_data = await reader.read(self.buffer_chunk_size)
|
||||||
bytes_received += self.buffer_chunk_size
|
bytes_received += len(input_data)
|
||||||
new_progress = min(
|
new_progress = min(
|
||||||
int(bytes_received / self.file_size * 100),
|
int(bytes_received / self.file_size * 100),
|
||||||
100
|
100
|
||||||
)
|
)
|
||||||
progress_showed = (new_progress // 10) * 10
|
self.print_progress_bar(
|
||||||
sys.stdout.write(
|
progress=new_progress,
|
||||||
f"\t\t\tReceiving `{self.file_name}`: "
|
bytes_=bytes_received
|
||||||
f"{'#' * (progress_showed // 10)}"
|
|
||||||
f"{'.' * ((100 - progress_showed) // 10)}\t"
|
|
||||||
f"{new_progress}% completed "
|
|
||||||
f"({min(bytes_received, self.file_size) // 1000} "
|
|
||||||
f"of {self.file_size // 1000} KB)\r"
|
|
||||||
)
|
)
|
||||||
sys.stdout.flush()
|
|
||||||
if not input_data:
|
if not input_data:
|
||||||
break
|
break
|
||||||
file_to_receive.write(input_data)
|
file_to_receive.write(input_data)
|
||||||
sys.stdout.write('\n')
|
print() # New line after sys.stdout.write
|
||||||
sys.stdout.flush()
|
|
||||||
logging.info("File received.")
|
logging.info("File received.")
|
||||||
if self.password:
|
if self.password:
|
||||||
logging.info("Decrypting file...")
|
logging.info("Decrypting file...")
|
||||||
@ -289,7 +382,8 @@ class Client:
|
|||||||
if self.working:
|
if self.working:
|
||||||
logging.info("Received interruption signal, stopping...")
|
logging.info("Received interruption signal, stopping...")
|
||||||
self._stopping = True
|
self._stopping = True
|
||||||
self.writer.close()
|
if self.writer:
|
||||||
|
self.writer.close()
|
||||||
else:
|
else:
|
||||||
raise KeyboardInterrupt("Not working yet...")
|
raise KeyboardInterrupt("Not working yet...")
|
||||||
|
|
||||||
@ -298,24 +392,62 @@ class Client:
|
|||||||
self._file_name = file_name
|
self._file_name = file_name
|
||||||
if file_size is not None:
|
if file_size is not None:
|
||||||
self._file_size = int(file_size)
|
self._file_size = int(file_size)
|
||||||
|
self._file_size_string = utilities.get_file_size_representation(
|
||||||
|
self.file_size
|
||||||
|
)
|
||||||
|
|
||||||
def run(self, file_path, action):
|
def run(self):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(
|
loop.run_until_complete(
|
||||||
self.run_client(file_path=file_path,
|
self.run_client()
|
||||||
action=action)
|
|
||||||
)
|
)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
print()
|
||||||
logging.error("Interrupted")
|
logging.error("Interrupted")
|
||||||
for task in asyncio.all_tasks(loop):
|
for task in asyncio.all_tasks(loop):
|
||||||
task.cancel()
|
task.cancel()
|
||||||
self.writer.close()
|
if self.writer:
|
||||||
|
self.writer.close()
|
||||||
loop.run_until_complete(
|
loop.run_until_complete(
|
||||||
self.writer.wait_closed()
|
self.wait_closed()
|
||||||
)
|
)
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
|
def print_progress_bar(self, progress: int, bytes_: int):
|
||||||
|
"""Print client progress bar.
|
||||||
|
|
||||||
|
`progress` % = `bytes_string` transferred
|
||||||
|
out of `self.file_size_string`.
|
||||||
|
"""
|
||||||
|
action = {
|
||||||
|
'send': "Sending",
|
||||||
|
'receive': "Receiving"
|
||||||
|
}[self.action]
|
||||||
|
bytes_string = utilities.get_file_size_representation(
|
||||||
|
bytes_
|
||||||
|
)
|
||||||
|
utilities.print_progress_bar(
|
||||||
|
prefix=f"\t\t\t{action} `{self.file_name}`: ",
|
||||||
|
done_symbol='#',
|
||||||
|
pending_symbol='.',
|
||||||
|
progress=progress,
|
||||||
|
scale=5,
|
||||||
|
suffix=(
|
||||||
|
" completed "
|
||||||
|
f"({bytes_string} "
|
||||||
|
f"of {self.file_size_string})"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def wait_closed() -> None:
|
||||||
|
"""Give time to cancelled tasks to end properly.
|
||||||
|
|
||||||
|
Sleep .1 second and return.
|
||||||
|
"""
|
||||||
|
await asyncio.sleep(.1)
|
||||||
|
|
||||||
|
|
||||||
def get_action(action):
|
def get_action(action):
|
||||||
"""Parse abbreviations for `action`."""
|
"""Parse abbreviations for `action`."""
|
||||||
@ -380,6 +512,11 @@ def main():
|
|||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
help='server SSL certificate')
|
help='server SSL certificate')
|
||||||
|
cli_parser.add_argument('--key', type=str,
|
||||||
|
default=None,
|
||||||
|
required=False,
|
||||||
|
help='server SSL key (required only for '
|
||||||
|
'SSL-secured standalone client)')
|
||||||
cli_parser.add_argument('--action', type=str,
|
cli_parser.add_argument('--action', type=str,
|
||||||
default=None,
|
default=None,
|
||||||
required=False,
|
required=False,
|
||||||
@ -397,6 +534,9 @@ def main():
|
|||||||
required=False,
|
required=False,
|
||||||
help='Session token '
|
help='Session token '
|
||||||
'(must be the same for both clients)')
|
'(must be the same for both clients)')
|
||||||
|
cli_parser.add_argument('--standalone',
|
||||||
|
action='store_true',
|
||||||
|
help='Run both as client and server')
|
||||||
cli_parser.add_argument('others',
|
cli_parser.add_argument('others',
|
||||||
metavar='R or S',
|
metavar='R or S',
|
||||||
nargs='*',
|
nargs='*',
|
||||||
@ -405,10 +545,12 @@ def main():
|
|||||||
host = args['host']
|
host = args['host']
|
||||||
port = args['port']
|
port = args['port']
|
||||||
certificate = args['certificate']
|
certificate = args['certificate']
|
||||||
|
key = args['key']
|
||||||
action = get_action(args['action'])
|
action = get_action(args['action'])
|
||||||
file_path = args['path']
|
file_path = args['path']
|
||||||
password = args['password']
|
password = args['password']
|
||||||
token = args['token']
|
token = args['token']
|
||||||
|
standalone = args['standalone']
|
||||||
|
|
||||||
# If host and port are not provided from command-line, try to import them
|
# If host and port are not provided from command-line, try to import them
|
||||||
sys.path.append(os.path.abspath('.'))
|
sys.path.append(os.path.abspath('.'))
|
||||||
@ -455,6 +597,11 @@ def main():
|
|||||||
from config import certificate
|
from config import certificate
|
||||||
except ImportError:
|
except ImportError:
|
||||||
certificate = None
|
certificate = None
|
||||||
|
if key is None or not os.path.isfile(key):
|
||||||
|
try:
|
||||||
|
from config import key
|
||||||
|
except ImportError:
|
||||||
|
key = None
|
||||||
|
|
||||||
# If import fails, prompt user for host or port
|
# If import fails, prompt user for host or port
|
||||||
new_settings = {} # After getting these settings, offer to store them
|
new_settings = {} # After getting these settings, offer to store them
|
||||||
@ -540,17 +687,18 @@ def main():
|
|||||||
logging.info("Configuration values stored.")
|
logging.info("Configuration values stored.")
|
||||||
else:
|
else:
|
||||||
logging.info("Proceeding without storing values...")
|
logging.info("Proceeding without storing values...")
|
||||||
client = Client(
|
ssl_context = None
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
password=password,
|
|
||||||
token=token
|
|
||||||
)
|
|
||||||
if certificate is not None:
|
if certificate is not None:
|
||||||
_ssl_context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
if key is None: # Server-dependent client
|
||||||
_ssl_context.check_hostname = False
|
ssl_context = ssl.create_default_context(
|
||||||
_ssl_context.load_verify_locations(certificate)
|
purpose=ssl.Purpose.SERVER_AUTH
|
||||||
client.set_ssl_context(_ssl_context)
|
)
|
||||||
|
ssl_context.load_verify_locations(certificate)
|
||||||
|
else: # Standalone client
|
||||||
|
ssl_context = ssl.create_default_context(
|
||||||
|
purpose=ssl.Purpose.CLIENT_AUTH
|
||||||
|
)
|
||||||
|
ssl_context.load_cert_chain(certificate, key)
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"Please consider using SSL. To do so, add in `config.py` or "
|
"Please consider using SSL. To do so, add in `config.py` or "
|
||||||
@ -559,10 +707,17 @@ def main():
|
|||||||
"certificate = 'path/to/certificate.crt'"
|
"certificate = 'path/to/certificate.crt'"
|
||||||
)
|
)
|
||||||
logging.info("Starting client...")
|
logging.info("Starting client...")
|
||||||
client.run(
|
client = Client(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
ssl_context=ssl_context,
|
||||||
|
action=action,
|
||||||
|
standalone=standalone,
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
action=action
|
password=password,
|
||||||
|
token=token
|
||||||
)
|
)
|
||||||
|
client.run()
|
||||||
logging.info("Stopped client")
|
logging.info("Stopped client")
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,10 +13,11 @@ from typing import Union
|
|||||||
|
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
def __init__(self, host='localhost', port=5000,
|
def __init__(self, host='localhost', port=5000, ssl_context=None,
|
||||||
buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
|
buffer_chunk_size=10 ** 4, buffer_length_limit=10 ** 4):
|
||||||
self._host = host
|
self._host = host
|
||||||
self._port = port
|
self._port = port
|
||||||
|
self._ssl_context = ssl_context
|
||||||
self.connections = collections.OrderedDict()
|
self.connections = collections.OrderedDict()
|
||||||
# Dict of queues of bytes
|
# Dict of queues of bytes
|
||||||
self.buffers = collections.OrderedDict()
|
self.buffers = collections.OrderedDict()
|
||||||
@ -87,27 +88,24 @@ class Server:
|
|||||||
async def run_writer(self, writer, connection_token):
|
async def run_writer(self, writer, connection_token):
|
||||||
consecutive_interruptions = 0
|
consecutive_interruptions = 0
|
||||||
errors = 0
|
errors = 0
|
||||||
while 1:
|
while connection_token in self.buffers:
|
||||||
try:
|
try:
|
||||||
try:
|
input_data = self.buffers[connection_token].popleft()
|
||||||
if connection_token not in self.buffers:
|
except IndexError:
|
||||||
break
|
# Slow down if buffer is empty; after 1.5 s of silence, break
|
||||||
input_data = self.buffers[connection_token].popleft()
|
consecutive_interruptions += 1
|
||||||
except IndexError:
|
if consecutive_interruptions > 3:
|
||||||
# Slow down if buffer is short
|
|
||||||
consecutive_interruptions += 1
|
|
||||||
if consecutive_interruptions > 3:
|
|
||||||
break
|
|
||||||
await asyncio.sleep(.5)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
consecutive_interruptions = 0
|
|
||||||
if not input_data:
|
|
||||||
break
|
break
|
||||||
|
await asyncio.sleep(.5)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
consecutive_interruptions = 0
|
||||||
|
if not input_data:
|
||||||
|
break
|
||||||
|
try:
|
||||||
writer.write(input_data)
|
writer.write(input_data)
|
||||||
await writer.drain()
|
await writer.drain()
|
||||||
except ConnectionResetError as e:
|
except ConnectionResetError as e:
|
||||||
logging.error("Here")
|
|
||||||
logging.error(e)
|
logging.error(e)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -127,7 +125,7 @@ class Server:
|
|||||||
"""
|
"""
|
||||||
client_hello = await reader.readline()
|
client_hello = await reader.readline()
|
||||||
client_hello = client_hello.decode('utf-8').strip('\n').split('|')
|
client_hello = client_hello.decode('utf-8').strip('\n').split('|')
|
||||||
if len(client_hello) not in (2, 4,):
|
if len(client_hello) != 4:
|
||||||
await self.refuse_connection(writer=writer,
|
await self.refuse_connection(writer=writer,
|
||||||
message="Invalid client_hello!")
|
message="Invalid client_hello!")
|
||||||
return
|
return
|
||||||
@ -142,7 +140,7 @@ class Server:
|
|||||||
terminate_line=True) -> int:
|
terminate_line=True) -> int:
|
||||||
# Adapt
|
# Adapt
|
||||||
if type(message) is list:
|
if type(message) is list:
|
||||||
message = '|'.join(message)
|
message = '|'.join(map(str, message))
|
||||||
if type(message) is str:
|
if type(message) is str:
|
||||||
if terminate_line:
|
if terminate_line:
|
||||||
message += '\n'
|
message += '\n'
|
||||||
@ -211,12 +209,13 @@ class Server:
|
|||||||
index = 0
|
index = 0
|
||||||
await asyncio.sleep(.5)
|
await asyncio.sleep(.5)
|
||||||
# Send file information and start signal to client
|
# Send file information and start signal to client
|
||||||
writer.write(
|
if await _write(
|
||||||
"s|hidden_token|"
|
['s',
|
||||||
f"{self.connections[connection_token]['file_name']}|"
|
'hidden_token',
|
||||||
f"{self.connections[connection_token]['file_size']}"
|
self.connections[connection_token]['file_name'],
|
||||||
"\n".encode('utf-8')
|
self.connections[connection_token]['file_size']]
|
||||||
)
|
):
|
||||||
|
return
|
||||||
if await _write("start!"):
|
if await _write("start!"):
|
||||||
return
|
return
|
||||||
await self.run_writer(writer=writer,
|
await self.run_writer(writer=writer,
|
||||||
@ -238,6 +237,7 @@ class Server:
|
|||||||
try:
|
try:
|
||||||
loop.run_until_complete(self.run_server())
|
loop.run_until_complete(self.run_server())
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
|
print()
|
||||||
logging.info("Stopping...")
|
logging.info("Stopping...")
|
||||||
# Cancel connection tasks (they should be done but are pending)
|
# Cancel connection tasks (they should be done but are pending)
|
||||||
for task in asyncio.all_tasks(loop):
|
for task in asyncio.all_tasks(loop):
|
||||||
@ -336,10 +336,6 @@ def main():
|
|||||||
logging.info("Invalid port. Enter a valid port number!")
|
logging.info("Invalid port. Enter a valid port number!")
|
||||||
port = None
|
port = None
|
||||||
|
|
||||||
server = Server(
|
|
||||||
host=host,
|
|
||||||
port=port,
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
if certificate is None or not os.path.isfile(certificate):
|
if certificate is None or not os.path.isfile(certificate):
|
||||||
from config import certificate
|
from config import certificate
|
||||||
@ -350,12 +346,12 @@ def main():
|
|||||||
if not os.path.isfile(key):
|
if not os.path.isfile(key):
|
||||||
key = None
|
key = None
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
certificate = None
|
||||||
|
key = None
|
||||||
|
ssl_context = None
|
||||||
if certificate and key:
|
if certificate and key:
|
||||||
_ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
|
||||||
_ssl_context.check_hostname = False
|
ssl_context.load_cert_chain(certificate, key)
|
||||||
_ssl_context.load_cert_chain(certificate, key)
|
|
||||||
server.set_ssl_context(_ssl_context)
|
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
"Please consider using SSL. To do so, add in `config.py` or "
|
"Please consider using SSL. To do so, add in `config.py` or "
|
||||||
@ -364,6 +360,11 @@ def main():
|
|||||||
"key = 'path/to/secret.key'\n"
|
"key = 'path/to/secret.key'\n"
|
||||||
"certificate = 'path/to/certificate.crt'"
|
"certificate = 'path/to/certificate.crt'"
|
||||||
)
|
)
|
||||||
|
server = Server(
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
ssl_context=ssl_context
|
||||||
|
)
|
||||||
server.run()
|
server.run()
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,11 +2,51 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import signal
|
import signal
|
||||||
|
import sys
|
||||||
|
|
||||||
|
units_of_measurements = {
|
||||||
|
1: 'bytes',
|
||||||
|
1000: 'KB',
|
||||||
|
1000 * 1000: 'MB',
|
||||||
|
1000 * 1000 * 1000: 'GB',
|
||||||
|
1000 * 1000 * 1000 * 1000: 'TB',
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_file_size_representation(file_size):
|
||||||
|
scale, unit = get_scale_and_unit(file_size=file_size)
|
||||||
|
if scale < 10:
|
||||||
|
return f"{file_size} {unit}"
|
||||||
|
return f"{(file_size // (scale / 100)) / 100:.2f} {unit}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_scale_and_unit(file_size):
|
||||||
|
scale, unit = min(units_of_measurements.items())
|
||||||
|
for scale, unit in sorted(units_of_measurements.items(), reverse=True):
|
||||||
|
if file_size > scale:
|
||||||
|
break
|
||||||
|
return scale, unit
|
||||||
|
|
||||||
|
|
||||||
|
def print_progress_bar(prefix='',
|
||||||
|
suffix='',
|
||||||
|
done_symbol="#",
|
||||||
|
pending_symbol=".",
|
||||||
|
progress=0,
|
||||||
|
scale=10):
|
||||||
|
progress_showed = (progress // scale) * scale
|
||||||
|
sys.stdout.write(
|
||||||
|
f"{prefix}"
|
||||||
|
f"{done_symbol * (progress_showed // scale)}"
|
||||||
|
f"{pending_symbol * ((100 - progress_showed) // scale)}\t"
|
||||||
|
f"{progress}%"
|
||||||
|
f"{suffix} \r"
|
||||||
|
)
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
def timed_input(message: str = None,
|
def timed_input(message: str = None,
|
||||||
timeout: int = 5):
|
timeout: int = 5):
|
||||||
|
|
||||||
class TimeoutExpired(Exception):
|
class TimeoutExpired(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user