Commit b69ec97

fix: make workflow run listener more resilient (hatchet-dev#56)
* fix: update requirements to indicate python version >=3.10
* fix: improve the workflow run listener
* chore: linting
1 parent 464bc0f commit b69ec97

4 files changed: +188 -65 lines changed

hatchet_sdk/clients/dispatcher.py (+8 -22)
@@ -4,11 +4,12 @@
 import random
 import threading
 import time
-from typing import Any, AsyncGenerator, Callable, List, Union
+from typing import Any, AsyncGenerator, List

 import grpc
 from grpc._cython import cygrpc

+from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
 from hatchet_sdk.connection import new_conn

 from ..dispatcher_pb2 import (
@@ -110,25 +111,6 @@ def unregister(self):
 START_GET_GROUP_KEY = 2


-class Event_ts(asyncio.Event):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        if self._loop is None:
-            self._loop = asyncio.get_event_loop()
-
-    def set(self):
-        self._loop.call_soon_threadsafe(super().set)
-
-    def clear(self):
-        self._loop.call_soon_threadsafe(super().clear)
-
-
-async def read_action(listener: Any, interrupt: Event_ts):
-    assigned_action = await listener.read()
-    interrupt.set()
-    return assigned_action
-
-
 async def exp_backoff_sleep(attempt: int, max_sleep_time: float = 5):
     base_time = 0.1  # starting sleep time in seconds (100 milliseconds)
     jitter = random.uniform(0, base_time)  # add random jitter
@@ -220,12 +202,16 @@ async def _generator(self) -> AsyncGenerator[Action, None]:
         try:
             while True:
                 self.interrupt = Event_ts()
-                t = asyncio.create_task(read_action(listener, self.interrupt))
+                t = asyncio.create_task(
+                    read_with_interrupt(listener, self.interrupt)
+                )
                 await self.interrupt.wait()

                 if not t.done():
                     # print a warning
-                    logger.warning("Interrupted read_action task")
+                    logger.warning(
+                        "Interrupted read_with_interrupt task of action listener"
+                    )

                     t.cancel()
                     listener.cancel()

hatchet_sdk/clients/event_ts.py (+25)
@@ -0,0 +1,25 @@
+import asyncio
+from typing import Any
+
+
+class Event_ts(asyncio.Event):
+    """
+    Event_ts is a subclass of asyncio.Event that allows for thread-safe setting and clearing of the event.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self._loop is None:
+            self._loop = asyncio.get_event_loop()
+
+    def set(self):
+        self._loop.call_soon_threadsafe(super().set)
+
+    def clear(self):
+        self._loop.call_soon_threadsafe(super().clear)
+
+
+async def read_with_interrupt(listener: Any, interrupt: Event_ts):
+    result = await listener.read()
+    interrupt.set()
+    return result
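
Note: both the action listener and the workflow run listener now share this helper. As a rough sketch of the intended usage (FakeStream below is a hypothetical stand-in for a gRPC stream, not part of the SDK): the read runs as a task that sets the shared Event_ts when it completes, an interrupter can set the same event from elsewhere, and the caller checks t.done() to tell the two wake-ups apart.

import asyncio

from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt


class FakeStream:
    """Hypothetical stand-in for a gRPC stream exposing an async read()."""

    async def read(self) -> str:
        await asyncio.sleep(0.1)  # simulate waiting on the server
        return "event"


async def main() -> None:
    listener = FakeStream()
    interrupt = Event_ts()

    # read_with_interrupt sets `interrupt` as soon as the read returns
    t = asyncio.create_task(read_with_interrupt(listener, interrupt))
    await interrupt.wait()

    if t.done():
        print("read completed:", t.result())
    else:
        # the event was set by an interrupter rather than a completed read:
        # cancel the stale read and reconnect
        t.cancel()


asyncio.run(main())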

hatchet_sdk/clients/workflow_listener.py (+154 -42)
@@ -1,10 +1,12 @@
 import asyncio
 import json
+import time
 from collections.abc import AsyncIterator
 from typing import AsyncGenerator

 import grpc

+from hatchet_sdk.clients.event_ts import Event_ts, read_with_interrupt
 from hatchet_sdk.connection import new_conn

 from ..dispatcher_pb2 import SubscribeToWorkflowRunsRequest, WorkflowRunEvent
@@ -15,14 +17,52 @@

 DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL = 1  # seconds
 DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT = 5
+DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL = 1800  # 30 minutes
+
+
+class _Subscription:
+    def __init__(self, id: int, workflow_run_id: str):
+        self.id = id
+        self.workflow_run_id = workflow_run_id
+        self.queue: asyncio.Queue[WorkflowRunEvent | None] = asyncio.Queue()
+
+    async def __aiter__(self):
+        return self
+
+    async def __anext__(self) -> WorkflowRunEvent:
+        return await self.queue.get()
+
+    async def get(self) -> WorkflowRunEvent:
+        event = await self.queue.get()
+
+        if event is None:
+            raise StopAsyncIteration
+
+        return event
+
+    async def put(self, item: WorkflowRunEvent):
+        await self.queue.put(item)
+
+    async def close(self):
+        await self.queue.put(None)


 class PooledWorkflowRunListener:
+    # list of all active subscriptions, mapping from a subscription id to a workflow run id
+    subscriptionsToWorkflows: dict[int, str] = {}
+
+    # list of workflow run ids mapped to an array of subscription ids
+    workflowsToSubscriptions: dict[str, list[int]] = {}
+
+    subscription_counter: int = 0
+    subscription_counter_lock: asyncio.Lock = asyncio.Lock()
+
     requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue()

     listener: AsyncGenerator[WorkflowRunEvent, None] = None

-    events: dict[str, asyncio.Queue[WorkflowRunEvent]] = {}
+    # events have keys of the format workflow_run_id + subscription_id
+    events: dict[int, _Subscription] = {}

     def __init__(self, config: ClientConfig):
         conn = new_conn(config, True)
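
A toy illustration of this bookkeeping (self-contained; plain dicts and queues stand in for the listener class and _Subscription): two subscriptions to the same run share one entry in workflowsToSubscriptions, so a single incoming event fans out to every subscriber's queue.

import asyncio


async def main() -> None:
    subscriptions_to_workflows: dict[int, str] = {}
    workflows_to_subscriptions: dict[str, list[int]] = {}
    queues: dict[int, asyncio.Queue] = {}

    # two callers subscribe to the same workflow run
    for subscription_id in (1, 2):
        subscriptions_to_workflows[subscription_id] = "run-abc"
        workflows_to_subscriptions.setdefault("run-abc", []).append(subscription_id)
        queues[subscription_id] = asyncio.Queue()

    # one event arrives for that run and fans out to every subscriber
    event = {"workflowRunId": "run-abc"}
    for subscription_id in workflows_to_subscriptions.get("run-abc", []):
        await queues[subscription_id].put(event)

    assert queues[1].qsize() == queues[2].qsize() == 1


asyncio.run(main())
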
@@ -35,30 +75,77 @@ def abort(self):
         self.stop_signal = True
         self.requests.put_nowait(False)

+    async def _interrupter(self):
+        """
+        _interrupter runs in a separate thread and interrupts the listener according to a configurable duration.
+        """
+        await asyncio.sleep(DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL)
+
+        if self.interrupt is not None:
+            self.interrupt.set()
+
     async def _init_producer(self):
         try:
             if not self.listener:
-                self.listener = await self._retry_subscribe()
-                logger.debug(f"Workflow run listener connected.")
-                async for workflow_event in self.listener:
-                    if workflow_event.workflowRunId in self.events:
-                        self.events[workflow_event.workflowRunId].put_nowait(
-                            workflow_event
-                        )
-                    else:
-                        logger.warning(
-                            f"Received event for unknown workflow: {workflow_event.workflowRunId}"
-                        )
+                while True:
+                    try:
+                        self.listener = await self._retry_subscribe()
+
+                        logger.debug(f"Workflow run listener connected.")
+
+                        # spawn an interrupter task
+                        asyncio.create_task(self._interrupter())
+
+                        while True:
+                            self.interrupt = Event_ts()
+                            t = asyncio.create_task(
+                                read_with_interrupt(self.listener, self.interrupt)
+                            )
+                            await self.interrupt.wait()
+
+                            if not t.done():
+                                # print a warning
+                                logger.warning(
+                                    "Interrupted read_with_interrupt task of workflow run listener"
+                                )
+
+                                t.cancel()
+                                self.listener.cancel()
+                                break
+
+                            workflow_event: WorkflowRunEvent = t.result()
+
+                            # get a list of subscriptions for this workflow
+                            subscriptions = self.workflowsToSubscriptions.get(
+                                workflow_event.workflowRunId, []
+                            )
+
+                            for subscription_id in subscriptions:
+                                await self.events[subscription_id].put(workflow_event)
+
+                    except grpc.RpcError as e:
+                        logger.error(f"grpc error in workflow run listener: {e}")
+                        continue
         except Exception as e:
             logger.error(f"Error in workflow run listener: {e}")
+
             self.listener = None

-            # signal all subscribers to stop
-            # FIXME this is a bit of a hack, ideally we re-establish the listener and re-subscribe
-            for key in self.events.keys():
-                self.events[key].put_nowait(False)
+            # close all subscriptions
+            for subscription_id in self.events:
+                await self.events[subscription_id].close()
+
+            raise e

     async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]:
+        # replay all existing subscriptions
+        workflow_run_set = set(self.subscriptionsToWorkflows.values())
+
+        for workflow_run_id in workflow_run_set:
+            yield SubscribeToWorkflowRunsRequest(
+                workflowRunId=workflow_run_id,
+            )
+
         while True:
             request = await self.requests.get()

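Taken together, _interrupter, the producer loop, and the replay at the top of _request form a reconnect cycle: the interrupter periodically forces the read to give up, the stale stream is cancelled, and on reconnect every known subscription is replayed to the server. A reduced skeleton of that cycle (illustrative only; FakeStream, the helper names, and the short interval are assumptions, not SDK code):

import asyncio

INTERRUPT_INTERVAL = 0.2  # seconds; the SDK default above is 1800 (30 minutes)


class FakeStream:
    """Hypothetical stream whose reads outlast the interrupt interval."""

    async def read(self) -> str:
        await asyncio.sleep(3600)
        return "event"


async def read_and_signal(stream: FakeStream, interrupt: asyncio.Event) -> str:
    # same shape as read_with_interrupt: wake the waiter when a read lands
    result = await stream.read()
    interrupt.set()
    return result


async def interrupter(interrupt: asyncio.Event) -> None:
    await asyncio.sleep(INTERRUPT_INTERVAL)
    interrupt.set()


async def producer_skeleton(max_cycles: int = 2) -> None:
    for cycle in range(max_cycles):
        stream = FakeStream()  # reconnect; _request would replay subscriptions here
        interrupt = asyncio.Event()
        asyncio.create_task(interrupter(interrupt))

        t = asyncio.create_task(read_and_signal(stream, interrupt))
        await interrupt.wait()

        if not t.done():
            # the interrupter won the race: drop the stale read and reconnect
            t.cancel()
            print(f"cycle {cycle}: interrupted, reconnecting")


asyncio.run(producer_skeleton())
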
@@ -69,44 +156,69 @@ async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]:
             yield request
             self.requests.task_done()

+    def cleanup_subscription(self, subscription_id: int):
+        workflow_run_id = self.subscriptionsToWorkflows[subscription_id]
+
+        if workflow_run_id in self.workflowsToSubscriptions:
+            self.workflowsToSubscriptions[workflow_run_id].remove(subscription_id)
+
+        del self.subscriptionsToWorkflows[subscription_id]
+        del self.events[subscription_id]
+
     async def subscribe(self, workflow_run_id: str):
-        self.events[workflow_run_id] = asyncio.Queue()
+        try:
+            # create a new subscription id, place a mutex on the counter
+            await self.subscription_counter_lock.acquire()
+            self.subscription_counter += 1
+            subscription_id = self.subscription_counter
+            self.subscription_counter_lock.release()

-        asyncio.create_task(self._init_producer())
+            self.subscriptionsToWorkflows[subscription_id] = workflow_run_id

-        await self.requests.put(
-            SubscribeToWorkflowRunsRequest(
-                workflowRunId=workflow_run_id,
+            if workflow_run_id not in self.workflowsToSubscriptions:
+                self.workflowsToSubscriptions[workflow_run_id] = [subscription_id]
+            else:
+                self.workflowsToSubscriptions[workflow_run_id].append(subscription_id)
+
+            self.events[subscription_id] = _Subscription(
+                subscription_id, workflow_run_id
             )
-        )

-        while True:
-            event = await self.events[workflow_run_id].get()
-            if event is False:
-                break
-            if event.workflowRunId == workflow_run_id:
-                yield event
-                break  # FIXME this should only break on terminal events... but we're not broadcasting event types
+            asyncio.create_task(self._init_producer())
+
+            await self.requests.put(
+                SubscribeToWorkflowRunsRequest(
+                    workflowRunId=workflow_run_id,
+                )
+            )
+
+            event = await self.events[subscription_id].get()

-        del self.events[workflow_run_id]
+            self.cleanup_subscription(subscription_id)
+
+            return event
+        except asyncio.CancelledError:
+            self.cleanup_subscription(subscription_id)
+            raise

     async def result(self, workflow_run_id: str):
-        async for event in self.subscribe(workflow_run_id):
-            errors = []
+        event = await self.subscribe(workflow_run_id)
+
+        errors = []

-            if event.results:
-                errors = [result.error for result in event.results if result.error]
+        if event.results:
+            errors = [result.error for result in event.results if result.error]

-            if errors:
-                raise Exception(f"Workflow Errors: {errors}")
+        if errors:
+            raise Exception(f"Workflow Errors: {errors}")

-            results = {
-                result.stepReadableId: json.loads(result.output)
-                for result in event.results
-                if result.output
-            }
+        results = {
+            result.stepReadableId: json.loads(result.output)
+            for result in event.results
+            if result.output
+        }

-            return results
+        return results

     async def _retry_subscribe(self):
         retries = 0
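
With this change, subscribe resolves to a single WorkflowRunEvent instead of yielding a stream, and result awaits it directly. A hedged usage sketch, assuming a listener already constructed with a valid ClientConfig (the run id is a placeholder):

import asyncio

from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener


async def watch(listener: PooledWorkflowRunListener) -> None:
    # two awaiters on the same run get separate subscription ids,
    # and both receive the run's step outputs when its event arrives
    first, second = await asyncio.gather(
        listener.result("example-workflow-run-id"),  # placeholder run id
        listener.result("example-workflow-run-id"),
    )
    print(first == second)

Because each call to result registers its own _Subscription, concurrent awaiters no longer race on a single per-run queue.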

pyproject.toml (+1 -1)
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "hatchet-sdk"
-version = "0.27.0"
+version = "0.27.1"
 description = ""
 authors = ["Alexander Belanger <[email protected]>"]
 readme = "README.md"
