1
1
import asyncio
2
2
import json
3
+ import time
3
4
from collections .abc import AsyncIterator
4
5
from typing import AsyncGenerator
5
6
6
7
import grpc
7
8
9
+ from hatchet_sdk .clients .event_ts import Event_ts , read_with_interrupt
8
10
from hatchet_sdk .connection import new_conn
9
11
10
12
from ..dispatcher_pb2 import SubscribeToWorkflowRunsRequest , WorkflowRunEvent
15
17
16
18
# Seconds to wait between attempts to re-establish the listener stream.
DEFAULT_WORKFLOW_LISTENER_RETRY_INTERVAL = 1  # seconds

# Maximum number of consecutive reconnect attempts before giving up.
DEFAULT_WORKFLOW_LISTENER_RETRY_COUNT = 5

# How long a single listener stream may live before it is deliberately
# interrupted and re-established (see _interrupter / _init_producer).
DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL = 1800  # 30 minutes
21
+
22
+
23
+ class _Subscription :
24
+ def __init__ (self , id : int , workflow_run_id : str ):
25
+ self .id = id
26
+ self .workflow_run_id = workflow_run_id
27
+ self .queue : asyncio .Queue [WorkflowRunEvent | None ] = asyncio .Queue ()
28
+
29
+ async def __aiter__ (self ):
30
+ return self
31
+
32
+ async def __anext__ (self ) -> WorkflowRunEvent :
33
+ return await self .queue .get ()
34
+
35
+ async def get (self ) -> WorkflowRunEvent :
36
+ event = await self .queue .get ()
37
+
38
+ if event is None :
39
+ raise StopAsyncIteration
40
+
41
+ return event
42
+
43
+ async def put (self , item : WorkflowRunEvent ):
44
+ await self .queue .put (item )
45
+
46
+ async def close (self ):
47
+ await self .queue .put (None )
18
48
19
49
20
50
class PooledWorkflowRunListener:
    """Shares a single workflow-run event stream among many subscribers.

    NOTE(review): all of the state below lives on the *class*, not on
    instances, so every PooledWorkflowRunListener in the process shares one
    pool, and the asyncio primitives are created at import time — confirm
    this is intentional and that a single event loop owns them.
    """

    # map: subscription id -> workflow run id (all active subscriptions)
    subscriptionsToWorkflows: dict[int, str] = {}

    # map: workflow run id -> list of subscription ids listening to it
    workflowsToSubscriptions: dict[str, list[int]] = {}

    # monotonically increasing id source, guarded by the lock below
    subscription_counter: int = 0
    subscription_counter_lock: asyncio.Lock = asyncio.Lock()

    # outbound subscription requests, consumed by _request()
    requests: asyncio.Queue[SubscribeToWorkflowRunsRequest] = asyncio.Queue()

    # the single underlying event stream (None until connected)
    listener: AsyncGenerator[WorkflowRunEvent, None] = None

    # map: subscription id -> its _Subscription event queue
    events: dict[int, _Subscription] = {}
26
66
27
67
def __init__ (self , config : ClientConfig ):
28
68
conn = new_conn (config , True )
@@ -35,30 +75,77 @@ def abort(self):
35
75
self .stop_signal = True
36
76
self .requests .put_nowait (False )
37
77
78
+ async def _interrupter (self ):
79
+ """
80
+ _interrupter runs in a separate thread and interrupts the listener according to a configurable duration.
81
+ """
82
+ await asyncio .sleep (DEFAULT_WORKFLOW_LISTENER_INTERRUPT_INTERVAL )
83
+
84
+ if self .interrupt is not None :
85
+ self .interrupt .set ()
86
+
38
87
    async def _init_producer(self):
        """Connect the shared stream and fan events out to subscribers.

        Runs as a background task. Connects (with retries) to the workflow
        run event stream, then reads one event at a time; each read can be
        interrupted via ``self.interrupt`` (set by _interrupter) so the
        stream is periodically torn down and re-established. Each event is
        routed to every subscription registered for its workflowRunId.
        Raises whatever non-gRPC error escapes, after closing all
        subscriptions.
        """
        try:
            # only one producer loop per stream; subscribe() may spawn this
            # task many times, the extra tasks exit here immediately
            if not self.listener:
                while True:
                    try:
                        self.listener = await self._retry_subscribe()

                        logger.debug(f"Workflow run listener connected.")

                        # spawn an interrupter task
                        asyncio.create_task(self._interrupter())

                        while True:
                            # NOTE(review): Event_ts / read_with_interrupt come
                            # from hatchet_sdk.clients.event_ts; presumably the
                            # event is set either by _interrupter or when a
                            # read completes — confirm their contract.
                            self.interrupt = Event_ts()
                            t = asyncio.create_task(
                                read_with_interrupt(self.listener, self.interrupt)
                            )
                            await self.interrupt.wait()

                            if not t.done():
                                # the read was interrupted before completing:
                                # drop the stream and reconnect via outer loop
                                logger.warning(
                                    "Interrupted read_with_interrupt task of workflow run listener"
                                )

                                t.cancel()
                                self.listener.cancel()
                                break

                            workflow_event: WorkflowRunEvent = t.result()

                            # get a list of subscriptions for this workflow
                            subscriptions = self.workflowsToSubscriptions.get(
                                workflow_event.workflowRunId, []
                            )

                            for subscription_id in subscriptions:
                                await self.events[subscription_id].put(workflow_event)

                    except grpc.RpcError as e:
                        # transient transport failure: log and reconnect
                        logger.error(f"grpc error in workflow run listener: {e}")
                        continue
        except Exception as e:
            logger.error(f"Error in workflow run listener: {e}")

            self.listener = None

            # close all subscriptions
            for subscription_id in self.events:
                await self.events[subscription_id].close()

            raise e
60
139
61
140
async def _request (self ) -> AsyncIterator [SubscribeToWorkflowRunsRequest ]:
141
+ # replay all existing subscriptions
142
+ workflow_run_set = set (self .subscriptionsToWorkflows .values ())
143
+
144
+ for workflow_run_id in workflow_run_set :
145
+ yield SubscribeToWorkflowRunsRequest (
146
+ workflowRunId = workflow_run_id ,
147
+ )
148
+
62
149
while True :
63
150
request = await self .requests .get ()
64
151
@@ -69,44 +156,69 @@ async def _request(self) -> AsyncIterator[SubscribeToWorkflowRunsRequest]:
69
156
yield request
70
157
self .requests .task_done ()
71
158
159
+ def cleanup_subscription (self , subscription_id : int ):
160
+ workflow_run_id = self .subscriptionsToWorkflows [subscription_id ]
161
+
162
+ if workflow_run_id in self .workflowsToSubscriptions :
163
+ self .workflowsToSubscriptions [workflow_run_id ].remove (subscription_id )
164
+
165
+ del self .subscriptionsToWorkflows [subscription_id ]
166
+ del self .events [subscription_id ]
167
+
72
168
async def subscribe (self , workflow_run_id : str ):
73
- self .events [workflow_run_id ] = asyncio .Queue ()
169
+ try :
170
+ # create a new subscription id, place a mutex on the counter
171
+ await self .subscription_counter_lock .acquire ()
172
+ self .subscription_counter += 1
173
+ subscription_id = self .subscription_counter
174
+ self .subscription_counter_lock .release ()
74
175
75
- asyncio . create_task ( self ._init_producer ())
176
+ self .subscriptionsToWorkflows [ subscription_id ] = workflow_run_id
76
177
77
- await self .requests .put (
78
- SubscribeToWorkflowRunsRequest (
79
- workflowRunId = workflow_run_id ,
178
+ if workflow_run_id not in self .workflowsToSubscriptions :
179
+ self .workflowsToSubscriptions [workflow_run_id ] = [subscription_id ]
180
+ else :
181
+ self .workflowsToSubscriptions [workflow_run_id ].append (subscription_id )
182
+
183
+ self .events [subscription_id ] = _Subscription (
184
+ subscription_id , workflow_run_id
80
185
)
81
- )
82
186
83
- while True :
84
- event = await self .events [workflow_run_id ].get ()
85
- if event is False :
86
- break
87
- if event .workflowRunId == workflow_run_id :
88
- yield event
89
- break # FIXME this should only break on terminal events... but we're not broadcasting event types
187
+ asyncio .create_task (self ._init_producer ())
188
+
189
+ await self .requests .put (
190
+ SubscribeToWorkflowRunsRequest (
191
+ workflowRunId = workflow_run_id ,
192
+ )
193
+ )
194
+
195
+ event = await self .events [subscription_id ].get ()
90
196
91
- del self .events [workflow_run_id ]
197
+ self .cleanup_subscription (subscription_id )
198
+
199
+ return event
200
+ except asyncio .CancelledError :
201
+ self .cleanup_subscription (subscription_id )
202
+ raise
92
203
93
204
async def result (self , workflow_run_id : str ):
94
- async for event in self .subscribe (workflow_run_id ):
95
- errors = []
205
+ event = await self .subscribe (workflow_run_id )
206
+
207
+ errors = []
96
208
97
- if event .results :
98
- errors = [result .error for result in event .results if result .error ]
209
+ if event .results :
210
+ errors = [result .error for result in event .results if result .error ]
99
211
100
- if errors :
101
- raise Exception (f"Workflow Errors: { errors } " )
212
+ if errors :
213
+ raise Exception (f"Workflow Errors: { errors } " )
102
214
103
- results = {
104
- result .stepReadableId : json .loads (result .output )
105
- for result in event .results
106
- if result .output
107
- }
215
+ results = {
216
+ result .stepReadableId : json .loads (result .output )
217
+ for result in event .results
218
+ if result .output
219
+ }
108
220
109
- return results
221
+ return results
110
222
111
223
async def _retry_subscribe (self ):
112
224
retries = 0
0 commit comments