1919
2020import graphviz
2121
22- from ..agents import BaseAgent
22+ from ..agents import BaseAgent , SequentialAgent , LoopAgent , ParallelAgent
2323from ..agents .llm_agent import LlmAgent
2424from ..tools .agent_tool import AgentTool
2525from ..tools .base_tool import BaseTool
3535 retrieval_tool_module_loaded = True
3636
3737
38- async def build_graph (graph , agent : BaseAgent , highlight_pairs ):
38+ async def build_graph (graph : graphviz .Digraph , agent : BaseAgent , highlight_pairs , parent_agent = None ):
39+ """
40+ Build a graph of the agent and its sub-agents.
41+ Args:
42+ graph: The graph to build on.
43+ agent: The agent to build the graph for.
44+ highlight_pairs: A list of pairs of nodes to highlight.
45+ parent_agent: The parent agent of the current agent. This is specifically used when building Workflow Agents to directly connect a node to nodes inside a Workflow Agent.
46+
47+ Returns:
48+ None
49+ """
3950 dark_green = '#0F5223'
4051 light_green = '#69CB87'
4152 light_gray = '#cccccc'
53+ white = '#ffffff'
4254
4355 def get_node_name (tool_or_agent : Union [BaseAgent , BaseTool ]):
4456 if isinstance (tool_or_agent , BaseAgent ):
45- return tool_or_agent .name
57+ # Added Workflow Agent checks for different agent types
58+ if isinstance (tool_or_agent , SequentialAgent ):
59+ return tool_or_agent .name + f" (Sequential Agent)"
60+ elif isinstance (tool_or_agent , LoopAgent ):
61+ return tool_or_agent .name + f" (Loop Agent)"
62+ elif isinstance (tool_or_agent , ParallelAgent ):
63+ return tool_or_agent .name + f" (Parallel Agent)"
64+ else :
65+ return tool_or_agent .name
4666 elif isinstance (tool_or_agent , BaseTool ):
4767 return tool_or_agent .name
4868 else :
@@ -73,6 +93,7 @@ def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
7393 def get_node_shape (tool_or_agent : Union [BaseAgent , BaseTool ]):
7494 if isinstance (tool_or_agent , BaseAgent ):
7595 return 'ellipse'
96+
7697 elif retrieval_tool_module_loaded and isinstance (
7798 tool_or_agent , BaseRetrievalTool
7899 ):
@@ -88,33 +109,120 @@ def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
88109 tool_or_agent ,
89110 )
90111 return 'cylinder'
112+
113+ def should_build_agent_cluster (tool_or_agent : Union [BaseAgent , BaseTool ]):
114+ if isinstance (tool_or_agent , BaseAgent ):
115+ if isinstance (tool_or_agent , SequentialAgent ):
116+ return True
117+ elif isinstance (tool_or_agent , LoopAgent ):
118+ return True
119+ elif isinstance (tool_or_agent , ParallelAgent ):
120+ return True
121+ else :
122+ return False
123+ elif retrieval_tool_module_loaded and isinstance (
124+ tool_or_agent , BaseRetrievalTool
125+ ):
126+ return False
127+ elif isinstance (tool_or_agent , FunctionTool ):
128+ return False
129+ elif isinstance (tool_or_agent , BaseTool ):
130+ return False
131+ else :
132+ logger .warning (
133+ 'Unsupported tool, type: %s, obj: %s' ,
134+ type (tool_or_agent ),
135+ tool_or_agent ,
136+ )
137+ return False
138+
139+ def build_cluster (child : graphviz .Digraph , agent : BaseAgent , name : str ):
140+ if isinstance (agent , LoopAgent ):
141+ # Draw the edge from the parent agent to the first sub-agent
142+ draw_edge (parent_agent .name , agent .sub_agents [0 ].name )
143+ length = len (agent .sub_agents )
144+ currLength = 0
145+ # Draw the edges between the sub-agents
146+ for sub_agent_int_sequential in agent .sub_agents :
147+ build_graph (child , sub_agent_int_sequential , highlight_pairs )
148+ # Draw the edge between the current sub-agent and the next one
149+ # If it's the last sub-agent, draw an edge to the first one to indicating a loop
150+ draw_edge (agent .sub_agents [currLength ].name , agent .sub_agents [0 if currLength == length - 1 else currLength + 1 ].name )
151+ currLength += 1
152+ elif isinstance (agent , SequentialAgent ):
153+ # Draw the edge from the parent agent to the first sub-agent
154+ draw_edge (parent_agent .name , agent .sub_agents [0 ].name )
155+ length = len (agent .sub_agents )
156+ currLength = 0
157+
158+ # Draw the edges between the sub-agents
159+ for sub_agent_int_sequential in agent .sub_agents :
160+ build_graph (child , sub_agent_int_sequential , highlight_pairs )
161+ # Draw the edge between the current sub-agent and the next one
162+ # If it's the last sub-agent, don't draw an edge to avoid a loop
163+ draw_edge (agent .sub_agents [currLength ].name , agent .sub_agents [currLength + 1 ].name ) if currLength != length - 1 else None
164+ currLength += 1
165+
166+ elif isinstance (agent , ParallelAgent ):
167+ # Draw the edge from the parent agent to every sub-agent
168+ for sub_agent in agent .sub_agents :
169+ build_graph (child , sub_agent , highlight_pairs )
170+ draw_edge (parent_agent .name , sub_agent .name )
171+ else :
172+ for sub_agent in agent .sub_agents :
173+ build_graph (child , sub_agent , highlight_pairs )
174+ draw_edge (agent .name , sub_agent .name )
175+
176+ child .attr (
177+ label = name ,
178+ style = 'rounded' ,
179+ color = white ,
180+ fontcolor = light_gray ,
181+ )
91182
92183 def draw_node (tool_or_agent : Union [BaseAgent , BaseTool ]):
93184 name = get_node_name (tool_or_agent )
94185 shape = get_node_shape (tool_or_agent )
95186 caption = get_node_caption (tool_or_agent )
187+ asCluster = should_build_agent_cluster (tool_or_agent )
188+ child = None
96189 if highlight_pairs :
97190 for highlight_tuple in highlight_pairs :
98191 if name in highlight_tuple :
99- graph .node (
100- name ,
101- caption ,
102- style = 'filled,rounded' ,
103- fillcolor = dark_green ,
104- color = dark_green ,
105- shape = shape ,
106- fontcolor = light_gray ,
107- )
192+ # if in highlight, draw highlight node
193+ if asCluster :
194+ cluster = graphviz .Digraph (name = 'cluster_' + name ) # adding "cluster_" to the name makes the graph render as a cluster subgraph
195+ build_cluster (cluster , agent , name )
196+ graph .subgraph (cluster )
197+ else :
198+ graph .node (
199+ name ,
200+ caption ,
201+ style = 'filled,rounded' ,
202+ fillcolor = dark_green ,
203+ color = dark_green ,
204+ shape = shape ,
205+ fontcolor = light_gray ,
206+ )
108207 return
109- # if not in highlight, draw non-highliht node
110- graph .node (
208+ # if not in highlight, draw non-highlight node
209+ if asCluster :
210+
211+ cluster = graphviz .Digraph (name = 'cluster_' + name ) # adding "cluster_" to the name makes the graph render as a cluster subgraph
212+ build_cluster (cluster , agent , name )
213+ graph .subgraph (cluster )
214+
215+ else :
216+ graph .node (
111217 name ,
112218 caption ,
113219 shape = shape ,
114220 style = 'rounded' ,
115221 color = light_gray ,
116222 fontcolor = light_gray ,
117- )
223+ )
224+
225+ return
118226
119227 def draw_edge (from_name , to_name ):
120228 if highlight_pairs :
@@ -126,12 +234,18 @@ def draw_edge(from_name, to_name):
126234 graph .edge (from_name , to_name , color = light_green , dir = 'back' )
127235 return
128236 # if no need to highlight, color gray
129- graph .edge (from_name , to_name , arrowhead = 'none' , color = light_gray )
237+ if (should_build_agent_cluster (agent )):
238+
239+ graph .edge (from_name , to_name , color = light_gray , )
240+ else :
241+ graph .edge (from_name , to_name , arrowhead = 'none' , color = light_gray )
130242
131243 draw_node (agent )
132244 for sub_agent in agent .sub_agents :
133- await build_graph (graph , sub_agent , highlight_pairs )
134- draw_edge (agent .name , sub_agent .name )
245+
246+ build_graph (graph , sub_agent , highlight_pairs , agent )
247+ if (not should_build_agent_cluster (sub_agent ) and not should_build_agent_cluster (agent )): # This is to avoid making a node for a Workflow Agent
248+ draw_edge (agent .name , sub_agent .name )
135249 if isinstance (agent , LlmAgent ):
136250 for tool in await agent .canonical_tools ():
137251 draw_node (tool )
0 commit comments