diff --git a/example_publisher/provider.py b/example_publisher/provider.py index 1238532..650b9a9 100644 --- a/example_publisher/provider.py +++ b/example_publisher/provider.py @@ -13,12 +13,14 @@ class Price: class Provider(ABC): + _update_loop_task = None + @abstractmethod def upd_products(self, product_symbols: List[Symbol]): ... def start(self) -> None: - asyncio.create_task(self._update_loop()) + self._update_loop_task = asyncio.create_task(self._update_loop()) @abstractmethod async def _update_loop(self): diff --git a/example_publisher/providers/pyth_replicator.py b/example_publisher/providers/pyth_replicator.py index 816319e..c4d1a61 100644 --- a/example_publisher/providers/pyth_replicator.py +++ b/example_publisher/providers/pyth_replicator.py @@ -31,6 +31,7 @@ def __init__(self, config: PythReplicatorConfig) -> None: self._prices: Dict[ str, Tuple[float | None, float | None, UnixTimestamp | None] ] = {} + self._update_accounts_task: asyncio.Task | None = None async def _update_loop(self) -> None: self._ws = self._client.create_watch_session() @@ -41,7 +42,7 @@ async def _update_loop(self) -> None: self._config.program_key, await self._client.get_all_accounts() ) - asyncio.create_task(self._update_accounts_loop()) + self._update_accounts_task = asyncio.create_task(self._update_accounts_loop()) while True: update = await self._ws.next_update() diff --git a/example_publisher/publisher.py b/example_publisher/publisher.py index d6b4858..8ef0fbe 100644 --- a/example_publisher/publisher.py +++ b/example_publisher/publisher.py @@ -27,6 +27,7 @@ class Product: class Publisher: def __init__(self, config: Config) -> None: self.config: Config = config + self._product_update_task: asyncio.Task | None = None if not getattr(self.config, self.config.provider_engine): raise ValueError(f"Missing {self.config.provider_engine} config") @@ -48,7 +49,9 @@ def __init__(self, config: Config) -> None: async def start(self): await self.pythd.connect() - asyncio.create_task(self._start_product_update_loop()) + self._product_update_task = asyncio.create_task( + self._start_product_update_loop() + ) async def _start_product_update_loop(self): await self._upd_products() diff --git a/example_publisher/pythd.py b/example_publisher/pythd.py index b560646..51e7b57 100644 --- a/example_publisher/pythd.py +++ b/example_publisher/pythd.py @@ -43,6 +43,7 @@ def __init__( self.address = address self.server: Server = None self.on_notify_price_sched = on_notify_price_sched + self._notify_price_sched_tasks = set() async def connect(self) -> Server: self.server = Server(self.address) @@ -69,7 +70,11 @@ async def subscribe_price_sched(self, account: str) -> int: def _notify_price_sched(self, subscription: int) -> None: log.debug("notify_price_sched RPC call received", subscription=subscription) - asyncio.get_event_loop().create_task(self.on_notify_price_sched(subscription)) + task = asyncio.get_event_loop().create_task( + self.on_notify_price_sched(subscription) + ) + self._notify_price_sched_tasks.add(task) + task.add_done_callback(lambda: self._notify_price_sched_tasks.remove(task)) async def all_products(self) -> List[Product]: result = await self.server.get_product_list() diff --git a/pyproject.toml b/pyproject.toml index f8ea5cc..1c491c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "example-publisher" -version = "1.0.1" +version = "1.0.2" description = "" authors = [] license = "Apache-2"