This commit is contained in:
Davte 2020-04-17 21:09:46 +02:00
parent 2ff847b44b
commit f21a7fdfb9
2 changed files with 28 additions and 14 deletions

View File

@ -40,25 +40,27 @@ class Client:
buffer_length_limit=10 ** 4, buffer_length_limit=10 ** 4,
file_path=None, file_path=None,
password=None, password=None,
token=None): token=None,
ssl_handshake_timeout=None):
self._host = host self._host = host
self._port = port self._port = port
self._ssl_context = ssl_context self._ssl_context = ssl_context
self._action = action self._action = action
self._standalone = standalone self._standalone = standalone
self._stopping = False self._file_path = file_path
self._reader = None self._password = password
self._writer = None self._token = token
# Shared queue of bytes self._ssl_handshake_timeout = ssl_handshake_timeout
self.buffer = collections.deque()
# How many bytes per chunk # How many bytes per chunk
self._buffer_chunk_size = buffer_chunk_size self._buffer_chunk_size = buffer_chunk_size
# How many chunks in buffer # How many chunks in buffer
self._buffer_length_limit = buffer_length_limit self._buffer_length_limit = buffer_length_limit
self._file_path = file_path # Shared queue of bytes
self.buffer = collections.deque()
self._working = False self._working = False
self._token = token self._stopping = False
self._password = password self._reader = None
self._writer = None
self._encryption_complete = False self._encryption_complete = False
self._file_name = None self._file_name = None
self._file_size = None self._file_size = None
@ -138,6 +140,16 @@ class Client:
def set_ssl_context(self, ssl_context: ssl.SSLContext): def set_ssl_context(self, ssl_context: ssl.SSLContext):
self._ssl_context = ssl_context self._ssl_context = ssl_context
@property
def ssl_handshake_timeout(self) -> Union[int, None]:
"""Return SSL handshake timeout.
If SSL context is not set, return None.
Otherwise, return seconds to wait before considering handshake failed.
"""
if self.ssl_context:
return self._ssl_handshake_timeout
@property @property
def token(self): def token(self):
"""Session token. """Session token.
@ -199,13 +211,13 @@ class Client:
host=self.host, host=self.host,
port=self.port, port=self.port,
ssl=self.ssl_context, ssl=self.ssl_context,
ssl_handshake_timeout=5 ssl_handshake_timeout=self.ssl_handshake_timeout
) )
except (ConnectionRefusedError, ConnectionResetError, except (ConnectionRefusedError, ConnectionResetError,
ConnectionAbortedError) as exception: ConnectionAbortedError) as exception:
logging.error(f"Connection error: {exception}") logging.error(f"Connection error: {exception}")
return return
except ssl.SSLCertVerificationError as exception: except ssl.SSLError as exception:
logging.error(f"SSL error: {exception}") logging.error(f"SSL error: {exception}")
return return
await self.connect(reader=reader, writer=writer) await self.connect(reader=reader, writer=writer)
@ -276,7 +288,7 @@ class Client:
while 1: while 1:
server_hello = await self.reader.readline() server_hello = await self.reader.readline()
if not server_hello: if not server_hello:
logging.error("Server disconnected.") logging.error("Server refused connection.")
return return
server_hello = server_hello.decode('utf-8').strip('\n').split('|') server_hello = server_hello.decode('utf-8').strip('\n').split('|')
if self.action == 'receive' and server_hello[0] == 's': if self.action == 'receive' and server_hello[0] == 's':

View File

@ -237,8 +237,10 @@ class Server:
return return
def disconnect(self, connection_token: str) -> None: def disconnect(self, connection_token: str) -> None:
del self.buffers[connection_token] if connection_token in self.buffers:
del self.connections[connection_token] del self.buffers[connection_token]
if connection_token in self.connections:
del self.connections[connection_token]
def run(self): def run(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()