Skip to content

Commit

Permalink
Added support for context managers
Browse files Browse the repository at this point in the history
  • Loading branch information
ForceFledgling committed Nov 26, 2023
1 parent 06e2e90 commit abb610e
Showing 1 changed file with 40 additions and 20 deletions.
60 changes: 40 additions & 20 deletions asyncio_telnet/asyncio_telnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def __init__(self, timeout=GLOBAL_DEFAULT_TIMEOUT):
"""
self.debuglevel = 0
self.timeout = timeout
self.fast_read_timeout = 0.1
self.read_timeout = timeout
self.fast_read_timeout = 2
self.reader = None
self.writer = None

Expand Down Expand Up @@ -107,21 +108,6 @@ async def read_until_eof(self, fast_mode=True):
break
return response

async def __aenter__(self):
"""
Is part of an asynchronous context manager.
"""
return self

async def __aexit__(self, exc_type, exc_value, traceback):
"""
Is part of an asynchronous context manager.
"""
if self.writer:
await self.writer.drain()
self.writer.close()
await self.writer.wait_closed()

async def filter_telnet_data(self, data):
"""
Removes TELNET control characters from the data.
Expand Down Expand Up @@ -172,8 +158,9 @@ class AsyncToSyncWrapper:
"""
Wraps an asynchronous class instance, allowing synchronous access to its methods.
"""
def __init__(self, async_class_instance):
def __init__(self, async_class_instance, loop=None):
self.async_class_instance = async_class_instance
self.loop = loop or asyncio.get_event_loop()

def __getattr__(self, name):
"""
Expand Down Expand Up @@ -206,8 +193,12 @@ def _sync_method(self, async_method, *args, **kwargs):
Returns:
The result of the asynchronous method.
"""
loop = asyncio.get_event_loop()
return loop.run_until_complete(async_method(*args, **kwargs))
try:
return self.loop.run_until_complete(async_method(*args, **kwargs))
finally:
if self.loop.is_running():
asyncio.run_coroutine_threadsafe(self.async_class_instance.close(), self.loop)
self.loop.stop()


class Telnet:
Expand All @@ -220,7 +211,7 @@ def __init__(self, timeout=GLOBAL_DEFAULT_TIMEOUT, sync_mode=False):
Args:
timeout: Timeout value for the Telnet connection.
sync_mode: If True, operates in synchronous mode using AsyncToSyncWrapper.
sync_mode: If True, operates in synchronous mode.
"""
if sync_mode:
self._instance = AsyncToSyncWrapper(AsyncTelnet(timeout))
Expand All @@ -238,3 +229,32 @@ def __getattr__(self, name):
The attribute from the underlying Telnet instance.
"""
return getattr(self._instance, name)

def __enter__(self):
"""
Implements the context manager protocol for entering the 'with' statement.
"""
return self

def __exit__(self, exc_type, exc_value, traceback):
"""
Implements the context manager protocol for exiting the 'with' statement.
"""
if self._instance.writer:
loop = asyncio.get_event_loop()
if not self._instance.writer.is_closing():
loop.run_until_complete(self._instance.writer.drain())
self._instance.writer.close()
loop.run_until_complete(self._instance.writer.wait_closed())

async def __aenter__(self):
"""
Implements the asynchronous context manager protocol for entering the 'async with' statement.
"""
return self

async def __aexit__(self, exc_type, exc_value, traceback):
"""
Implements the asynchronous context manager protocol for exiting the 'async with' statement.
"""
await self.close()

0 comments on commit abb610e

Please sign in to comment.