Skip to content

Commit da4bc0e

Browse files
BhadrakshBcopybara-github
authored andcommitted
feat: New Agent Visualization
Merge #482 Added new agent visualzation that accounts for the internal architecture of the Workflow Agents and shows them inside of a cluster with a label as the name of the Workflow Agent. The sub agents are conencted as per the Workflow Agents' working and architecture COPYBARA_INTEGRATE_REVIEW=#482 from BhadrakshB:new_agent_visualization 994a9e2 PiperOrigin-RevId: 766345311
1 parent 2735942 commit da4bc0e

File tree

1 file changed

+132
-18
lines changed

1 file changed

+132
-18
lines changed

src/google/adk/cli/agent_graph.py

Lines changed: 132 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
import graphviz
2121

22-
from ..agents import BaseAgent
22+
from ..agents import BaseAgent, SequentialAgent, LoopAgent, ParallelAgent
2323
from ..agents.llm_agent import LlmAgent
2424
from ..tools.agent_tool import AgentTool
2525
from ..tools.base_tool import BaseTool
@@ -35,14 +35,34 @@
3535
retrieval_tool_module_loaded = True
3636

3737

38-
async def build_graph(graph, agent: BaseAgent, highlight_pairs):
38+
async def build_graph(graph: graphviz.Digraph, agent: BaseAgent, highlight_pairs, parent_agent=None):
39+
"""
40+
Build a graph of the agent and its sub-agents.
41+
Args:
42+
graph: The graph to build on.
43+
agent: The agent to build the graph for.
44+
highlight_pairs: A list of pairs of nodes to highlight.
45+
parent_agent: The parent agent of the current agent. This is specifically used when building Workflow Agents to directly connect a node to nodes inside a Workflow Agent.
46+
47+
Returns:
48+
None
49+
"""
3950
dark_green = '#0F5223'
4051
light_green = '#69CB87'
4152
light_gray = '#cccccc'
53+
white = '#ffffff'
4254

4355
def get_node_name(tool_or_agent: Union[BaseAgent, BaseTool]):
4456
if isinstance(tool_or_agent, BaseAgent):
45-
return tool_or_agent.name
57+
# Added Workflow Agent checks for different agent types
58+
if isinstance(tool_or_agent, SequentialAgent):
59+
return tool_or_agent.name + f" (Sequential Agent)"
60+
elif isinstance(tool_or_agent, LoopAgent):
61+
return tool_or_agent.name + f" (Loop Agent)"
62+
elif isinstance(tool_or_agent, ParallelAgent):
63+
return tool_or_agent.name + f" (Parallel Agent)"
64+
else:
65+
return tool_or_agent.name
4666
elif isinstance(tool_or_agent, BaseTool):
4767
return tool_or_agent.name
4868
else:
@@ -73,6 +93,7 @@ def get_node_caption(tool_or_agent: Union[BaseAgent, BaseTool]):
7393
def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
7494
if isinstance(tool_or_agent, BaseAgent):
7595
return 'ellipse'
96+
7697
elif retrieval_tool_module_loaded and isinstance(
7798
tool_or_agent, BaseRetrievalTool
7899
):
@@ -88,33 +109,120 @@ def get_node_shape(tool_or_agent: Union[BaseAgent, BaseTool]):
88109
tool_or_agent,
89110
)
90111
return 'cylinder'
112+
113+
def should_build_agent_cluster(tool_or_agent: Union[BaseAgent, BaseTool]):
114+
if isinstance(tool_or_agent, BaseAgent):
115+
if isinstance(tool_or_agent, SequentialAgent):
116+
return True
117+
elif isinstance(tool_or_agent, LoopAgent):
118+
return True
119+
elif isinstance(tool_or_agent, ParallelAgent):
120+
return True
121+
else:
122+
return False
123+
elif retrieval_tool_module_loaded and isinstance(
124+
tool_or_agent, BaseRetrievalTool
125+
):
126+
return False
127+
elif isinstance(tool_or_agent, FunctionTool):
128+
return False
129+
elif isinstance(tool_or_agent, BaseTool):
130+
return False
131+
else:
132+
logger.warning(
133+
'Unsupported tool, type: %s, obj: %s',
134+
type(tool_or_agent),
135+
tool_or_agent,
136+
)
137+
return False
138+
139+
def build_cluster(child: graphviz.Digraph, agent: BaseAgent, name: str):
140+
if isinstance(agent, LoopAgent):
141+
# Draw the edge from the parent agent to the first sub-agent
142+
draw_edge(parent_agent.name, agent.sub_agents[0].name)
143+
length = len(agent.sub_agents)
144+
currLength = 0
145+
# Draw the edges between the sub-agents
146+
for sub_agent_int_sequential in agent.sub_agents:
147+
build_graph(child, sub_agent_int_sequential, highlight_pairs)
148+
# Draw the edge between the current sub-agent and the next one
149+
# If it's the last sub-agent, draw an edge to the first one to indicating a loop
150+
draw_edge(agent.sub_agents[currLength].name, agent.sub_agents[0 if currLength == length - 1 else currLength+1 ].name)
151+
currLength += 1
152+
elif isinstance(agent, SequentialAgent):
153+
# Draw the edge from the parent agent to the first sub-agent
154+
draw_edge(parent_agent.name, agent.sub_agents[0].name)
155+
length = len(agent.sub_agents)
156+
currLength = 0
157+
158+
# Draw the edges between the sub-agents
159+
for sub_agent_int_sequential in agent.sub_agents:
160+
build_graph(child, sub_agent_int_sequential, highlight_pairs)
161+
# Draw the edge between the current sub-agent and the next one
162+
# If it's the last sub-agent, don't draw an edge to avoid a loop
163+
draw_edge(agent.sub_agents[currLength].name, agent.sub_agents[currLength+1].name) if currLength != length - 1 else None
164+
currLength += 1
165+
166+
elif isinstance(agent, ParallelAgent):
167+
# Draw the edge from the parent agent to every sub-agent
168+
for sub_agent in agent.sub_agents:
169+
build_graph(child, sub_agent, highlight_pairs)
170+
draw_edge(parent_agent.name, sub_agent.name)
171+
else:
172+
for sub_agent in agent.sub_agents:
173+
build_graph(child, sub_agent, highlight_pairs)
174+
draw_edge(agent.name, sub_agent.name)
175+
176+
child.attr(
177+
label=name,
178+
style='rounded',
179+
color=white,
180+
fontcolor=light_gray,
181+
)
91182

92183
def draw_node(tool_or_agent: Union[BaseAgent, BaseTool]):
93184
name = get_node_name(tool_or_agent)
94185
shape = get_node_shape(tool_or_agent)
95186
caption = get_node_caption(tool_or_agent)
187+
asCluster = should_build_agent_cluster(tool_or_agent)
188+
child = None
96189
if highlight_pairs:
97190
for highlight_tuple in highlight_pairs:
98191
if name in highlight_tuple:
99-
graph.node(
100-
name,
101-
caption,
102-
style='filled,rounded',
103-
fillcolor=dark_green,
104-
color=dark_green,
105-
shape=shape,
106-
fontcolor=light_gray,
107-
)
192+
# if in highlight, draw highlight node
193+
if asCluster:
194+
cluster = graphviz.Digraph(name='cluster_' + name) # adding "cluster_" to the name makes the graph render as a cluster subgraph
195+
build_cluster(cluster, agent, name)
196+
graph.subgraph(cluster)
197+
else:
198+
graph.node(
199+
name,
200+
caption,
201+
style='filled,rounded',
202+
fillcolor=dark_green,
203+
color=dark_green,
204+
shape=shape,
205+
fontcolor=light_gray,
206+
)
108207
return
109-
# if not in highlight, draw non-highliht node
110-
graph.node(
208+
# if not in highlight, draw non-highlight node
209+
if asCluster:
210+
211+
cluster = graphviz.Digraph(name='cluster_' + name) # adding "cluster_" to the name makes the graph render as a cluster subgraph
212+
build_cluster(cluster, agent, name)
213+
graph.subgraph(cluster)
214+
215+
else:
216+
graph.node(
111217
name,
112218
caption,
113219
shape=shape,
114220
style='rounded',
115221
color=light_gray,
116222
fontcolor=light_gray,
117-
)
223+
)
224+
225+
return
118226

119227
def draw_edge(from_name, to_name):
120228
if highlight_pairs:
@@ -126,12 +234,18 @@ def draw_edge(from_name, to_name):
126234
graph.edge(from_name, to_name, color=light_green, dir='back')
127235
return
128236
# if no need to highlight, color gray
129-
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
237+
if (should_build_agent_cluster(agent)):
238+
239+
graph.edge(from_name, to_name, color=light_gray, )
240+
else:
241+
graph.edge(from_name, to_name, arrowhead='none', color=light_gray)
130242

131243
draw_node(agent)
132244
for sub_agent in agent.sub_agents:
133-
await build_graph(graph, sub_agent, highlight_pairs)
134-
draw_edge(agent.name, sub_agent.name)
245+
246+
build_graph(graph, sub_agent, highlight_pairs, agent)
247+
if (not should_build_agent_cluster(sub_agent) and not should_build_agent_cluster(agent)): # This is to avoid making a node for a Workflow Agent
248+
draw_edge(agent.name, sub_agent.name)
135249
if isinstance(agent, LlmAgent):
136250
for tool in await agent.canonical_tools():
137251
draw_node(tool)

0 commit comments

Comments
 (0)