From e6fe9aa42311bfba3283f6a2c7b9e7d8ed58aedb Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Thu, 28 May 2026 03:12:39 -0700 Subject: [PATCH] feat: refactor OpenTelemetry (OTel) instrumentation within the ADK core, moving from manual span management to structured helper classes ### Key Changes * **Structured Instrumentation:** Replaces manual `Tracing.trace` calls and explicit `Scope` management with `Flowable.using` and `Maybe.using` patterns. Telemetry logic is encapsulated in the `Instrumentation.AgentInvocation` and `Instrumentation.ToolExecution` helper classes, which now accept an explicit parent `Context`. * **Metrics Integration:** Adds support for tracking new metrics during agent execution: * `gen_ai.agent.invocation.duration` * `gen_ai.agent.request.size` / `gen_ai.agent.response.size` * `gen_ai.agent.workflow.steps` * **Reactive API Improvements:** Leverages `Tracing.withContext()` together with `doOnNext`/`doOnSuccess`/`doOnError` hooks that feed `AgentInvocation` and `ToolExecution`, so events and errors are captured automatically without polluting the core logic. * **Trace Hierarchy Refinement:** Updates how spans are nested. For example, in `ContextPropagationTest`, the child agent span is now correctly parented to the specific LLM call span that triggered it, rather than the parent agent span. * **Testing:** Significantly enhances `BaseAgentTest` to verify metric collection using `InMemoryMetricReader`. Updates various tests to match the new tool span name, which drops the brackets (`execute_tool echo_tool` instead of `execute_tool [echo_tool]`). ### Impact This refactor simplifies the core agent and flow logic by removing boilerplate telemetry code, making the instrumentation more robust and easier to maintain while expanding the observability of the ADK through new metrics. 
PiperOrigin-RevId: 922661681 --- .../java/com/google/adk/agents/BaseAgent.java | 27 +++--- .../adk/flows/llmflows/BaseLlmFlow.java | 20 ++--- .../google/adk/flows/llmflows/Functions.java | 36 ++++---- .../google/adk/telemetry/Instrumentation.java | 31 +++++-- .../com/google/adk/agents/BaseAgentTest.java | 84 ++++++++++++++++++- .../com/google/adk/agents/LlmAgentTest.java | 2 +- .../com/google/adk/runner/RunnerTest.java | 4 +- .../adk/telemetry/ContextPropagationTest.java | 19 +++-- .../com/google/adk/testing/TestCallback.java | 4 +- 9 files changed, 169 insertions(+), 58 deletions(-) diff --git a/core/src/main/java/com/google/adk/agents/BaseAgent.java b/core/src/main/java/com/google/adk/agents/BaseAgent.java index cbceceed2..5b154862e 100644 --- a/core/src/main/java/com/google/adk/agents/BaseAgent.java +++ b/core/src/main/java/com/google/adk/agents/BaseAgent.java @@ -24,7 +24,8 @@ import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; import com.google.adk.plugins.Plugin; -import com.google.adk.telemetry.Tracing; +import com.google.adk.telemetry.Instrumentation; +import com.google.adk.telemetry.Instrumentation.AgentInvocation; import com.google.adk.utils.AgentEnums.AgentOrigin; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -322,11 +323,13 @@ public Flowable runAsync(InvocationContext parentContext) { private Flowable run( InvocationContext parentContext, Function> runImplementation) { - Context parentSpanContext = Context.current(); - return Flowable.defer( - () -> { - InvocationContext invocationContext = createInvocationContext(parentContext); - + Context otelContext = Context.current(); + return Flowable.using( + () -> + Instrumentation.recordAgentInvocation( + createInvocationContext(parentContext), this, otelContext), + agentInvocation -> { + InvocationContext invocationContext = agentInvocation.getCtx(); Flowable mainAndAfterEvents = Flowable.defer(() 
-> runImplementation.apply(invocationContext)) .concatWith( @@ -350,14 +353,10 @@ private Flowable run( return Flowable.just(beforeEvent).concatWith(mainAndAfterEvents); }) .switchIfEmpty(mainAndAfterEvents) - .compose( - Tracing.trace("invoke_agent " + name()) - .setParent(parentSpanContext) - .configure( - span -> - Tracing.traceAgentInvocation( - span, name(), description(), invocationContext))); - }); + .doOnNext(agentInvocation::addEvent) + .doOnError(agentInvocation::setError); + }, + AgentInvocation::close); } /** diff --git a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java index ef7dce75a..dffba0e80 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java @@ -479,12 +479,10 @@ private Flowable runOneStep(Context spanContext, InvocationContext contex "Agent not found: " + agentToTransfer))); } return postProcessedEvents.concatWith( - Flowable.defer( - () -> { - try (Scope s = spanContext.makeCurrent()) { - return nextAgent.get().runAsync(context); - } - })); + nextAgent + .get() + .runAsync(context) + .compose(Tracing.withContext(spanContext))); } return postProcessedEvents; }); @@ -666,12 +664,10 @@ public void onError(Throwable e) { "Agent not found: " + event.actions().transferToAgent().get()); } Flowable nextAgentEvents = - Flowable.defer( - () -> { - try (Scope s = spanContext.makeCurrent()) { - return nextAgent.get().runLive(invocationContext); - } - }); + nextAgent + .get() + .runLive(invocationContext) + .compose(Tracing.withContext(spanContext)); events = Flowable.concat(events, nextAgentEvents); } return events; diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 4aa20798d..8c60ebf76 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ 
b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -29,6 +29,8 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import com.google.adk.events.ToolConfirmation; +import com.google.adk.telemetry.Instrumentation; +import com.google.adk.telemetry.Instrumentation.ToolExecution; import com.google.adk.telemetry.Tracing; import com.google.adk.tools.BaseTool; import com.google.adk.tools.FunctionTool; @@ -430,6 +432,25 @@ private static Maybe postProcessFunctionResult( ToolContext toolContext, boolean isLive, Context parentContext) { + return Maybe.using( + () -> + Instrumentation.recordToolExecution( + tool, invocationContext.agent(), functionArgs, parentContext), + toolExecution -> + processFunctionResult( + maybeFunctionResult, invocationContext, tool, functionArgs, toolContext, isLive) + .doOnSuccess(event -> toolExecution.context().setFunctionResponseEvent(event)) + .doOnError(toolExecution::setError), + ToolExecution::close); + } + + private static Maybe processFunctionResult( + Maybe> maybeFunctionResult, + InvocationContext invocationContext, + BaseTool tool, + Map functionArgs, + ToolContext toolContext, + boolean isLive) { return maybeFunctionResult .map(Optional::of) .defaultIfEmpty(Optional.empty()) @@ -467,20 +488,7 @@ private static Maybe postProcessFunctionResult( tool, finalFunctionResult, toolContext, invocationContext); return Maybe.just(event); }); - }) - .compose( - Tracing.trace("execute_tool [" + tool.name() + "]") - .setParent(parentContext) - .onSuccess( - (span, event) -> - Tracing.traceToolExecution( - span, - tool.name(), - tool.description(), - tool.getClass().getSimpleName(), - functionArgs, - event, - null))); + }); } private static Optional mergeParallelFunctionResponseEvents( diff --git a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java index a2c62ba12..fd27878c9 100644 --- 
a/core/src/main/java/com/google/adk/telemetry/Instrumentation.java +++ b/core/src/main/java/com/google/adk/telemetry/Instrumentation.java @@ -125,8 +125,12 @@ public static final class AgentInvocation extends ClosableTelemetryScope { private final InvocationContext ctx; private final List events = Collections.synchronizedList(new ArrayList<>()); - public AgentInvocation(InvocationContext ctx, BaseAgent agent) { - super(Tracing.getTracer().spanBuilder("invoke_agent " + agent.name()).startSpan()); + public AgentInvocation(InvocationContext ctx, BaseAgent agent, Context parentContext) { + super( + Tracing.getTracer() + .spanBuilder("invoke_agent " + agent.name()) + .setParent(parentContext) + .startSpan()); this.agent = agent; this.ctx = ctx; Tracing.traceAgentInvocation(span, agent.name(), agent.description(), ctx); @@ -160,8 +164,13 @@ public static final class ToolExecution extends ClosableTelemetryScope { private final BaseAgent agent; private final Map functionArgs; - public ToolExecution(BaseTool tool, BaseAgent agent, Map functionArgs) { - super(Tracing.getTracer().spanBuilder("execute_tool " + tool.name()).startSpan()); + public ToolExecution( + BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { + super( + Tracing.getTracer() + .spanBuilder("execute_tool " + tool.name()) + .setParent(parentContext) + .startSpan()); this.tool = tool; this.agent = agent; this.functionArgs = functionArgs; @@ -196,12 +205,22 @@ protected void handleMetricsError(RuntimeException e) { /** Creates an AgentInvocation context to record agent invocation telemetry. 
*/ public static AgentInvocation recordAgentInvocation(InvocationContext ctx, BaseAgent agent) { - return new AgentInvocation(ctx, agent); + return recordAgentInvocation(ctx, agent, Context.current()); + } + + public static AgentInvocation recordAgentInvocation( + InvocationContext ctx, BaseAgent agent, Context parentContext) { + return new AgentInvocation(ctx, agent, parentContext); } /** Creates a ToolExecution context to record tool execution telemetry. */ public static ToolExecution recordToolExecution( BaseTool tool, BaseAgent agent, Map functionArgs) { - return new ToolExecution(tool, agent, functionArgs); + return recordToolExecution(tool, agent, functionArgs, Context.current()); + } + + public static ToolExecution recordToolExecution( + BaseTool tool, BaseAgent agent, Map functionArgs, Context parentContext) { + return new ToolExecution(tool, agent, functionArgs, parentContext); } } diff --git a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java index 5e2fa5792..a3436e6cb 100644 --- a/core/src/test/java/com/google/adk/agents/BaseAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/BaseAgentTest.java @@ -22,26 +22,42 @@ import com.google.adk.agents.Callbacks.AfterAgentCallback; import com.google.adk.agents.Callbacks.BeforeAgentCallback; import com.google.adk.events.Event; +import com.google.adk.telemetry.Metrics; import com.google.adk.testing.TestBaseAgent; import com.google.adk.testing.TestCallback; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.Part; +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.sdk.OpenTelemetrySdk; +import io.opentelemetry.sdk.metrics.SdkMeterProvider; +import io.opentelemetry.sdk.metrics.data.HistogramPointData; +import 
io.opentelemetry.sdk.metrics.data.MetricData; +import io.opentelemetry.sdk.testing.exporter.InMemoryMetricReader; +import io.opentelemetry.sdk.testing.time.TestClock; +import io.opentelemetry.sdk.trace.SdkTracerProvider; import io.reactivex.rxjava3.core.Completable; import io.reactivex.rxjava3.core.Maybe; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; +import org.junit.After; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public final class BaseAgentTest { - private static final String TEST_AGENT_NAME = "testAgent"; private static final String TEST_AGENT_DESCRIPTION = "A test agent"; + private InMemoryMetricReader inMemoryMetricReader; + private TestClock testClock; + private Meter originalMeter; + private static class ClosableTestAgent extends TestBaseAgent { final AtomicBoolean closed = new AtomicBoolean(false); @@ -56,6 +72,35 @@ public Completable close() { } } + @Before + public void setUp() { + GlobalOpenTelemetry.resetForTest(); + testClock = TestClock.create(); + inMemoryMetricReader = InMemoryMetricReader.create(); + SdkMeterProvider sdkMeterProvider = + SdkMeterProvider.builder() + .registerMetricReader(inMemoryMetricReader) + .setClock(testClock) + .build(); + + OpenTelemetrySdk openTelemetrySdk = + OpenTelemetrySdk.builder() + .setTracerProvider(SdkTracerProvider.builder().build()) + .setMeterProvider(sdkMeterProvider) + .build(); + + GlobalOpenTelemetry.set(openTelemetrySdk); + originalMeter = GlobalOpenTelemetry.getMeter("gcp.vertex.agent"); + Metrics.setMeterForTesting(openTelemetrySdk.getMeter("gcp.vertex.agent")); + } + + @After + public void tearDown() { + if (originalMeter != null) { + Metrics.setMeterForTesting(originalMeter); + } + } + @Test public void constructor_setsNameAndDescription() { String name = "testName"; @@ -173,6 +218,36 @@ public void runAsync_noCallbacks_invokesRunAsyncImpl() { assertThat(results).hasSize(1); 
assertThat(results.get(0).content()).hasValue(runAsyncImplContent); assertThat(runAsyncImpl.wasCalled()).isTrue(); + MetricData durationMetric = findMetricByName("gen_ai.agent.invocation.duration"); + assertThat(durationMetric.getUnit()).isEqualTo("ms"); + HistogramPointData durationPoint = + durationMetric.getHistogramData().getPoints().iterator().next(); + assertThat(durationPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); + + MetricData reqSizeMetric = findMetricByName("gen_ai.agent.request.size"); + assertThat(reqSizeMetric.getUnit()).isEqualTo("By"); + HistogramPointData reqSizePoint = + reqSizeMetric.getHistogramData().getPoints().iterator().next(); + assertThat(reqSizePoint.getSum()).isEqualTo(12.0); + assertThat(reqSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); + + MetricData respSizeMetric = findMetricByName("gen_ai.agent.response.size"); + assertThat(respSizeMetric.getUnit()).isEqualTo("By"); + HistogramPointData respSizePoint = + respSizeMetric.getHistogramData().getPoints().iterator().next(); + assertThat(respSizePoint.getSum()).isEqualTo(11.0); + assertThat(respSizePoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); + + MetricData workflowStepsMetric = findMetricByName("gen_ai.agent.workflow.steps"); + assertThat(workflowStepsMetric.getUnit()).isEqualTo("1"); + HistogramPointData workflowStepsPoint = + workflowStepsMetric.getHistogramData().getPoints().iterator().next(); + assertThat(workflowStepsPoint.getSum()).isEqualTo(1.0); + assertThat(workflowStepsPoint.getAttributes().get(AttributeKey.stringKey("gen_ai.agent.name"))) + .isEqualTo("testAgent"); } @Test @@ -627,4 +702,11 @@ public void close_twoLevelsSubAgents_closesAllSubAgents() { assertThat(subAgent.closed.get()).isTrue(); assertThat(subSubAgent.closed.get()).isTrue(); } + + private MetricData findMetricByName(String name) { + return 
inMemoryMetricReader.collectAllMetrics().stream() + .filter(m -> m.getName().equals(name)) + .findFirst() + .orElseThrow(() -> new AssertionError("Metric not found: " + name)); + } } diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index 35cf12f6f..26843bb56 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -494,7 +494,7 @@ public void runAsync_withTools_createsToolSpans() throws InterruptedException { List spans = openTelemetryRule.getSpans(); SpanData agentSpan = findSpanByName(spans, "invoke_agent test agent"); List llmSpans = findSpansByName(spans, "call_llm"); - List toolSpans = findSpansByName(spans, "execute_tool [echo_tool]"); + List toolSpans = findSpansByName(spans, "execute_tool echo_tool"); assertThat(llmSpans).hasSize(2); assertThat(toolSpans).hasSize(1); diff --git a/core/src/test/java/com/google/adk/runner/RunnerTest.java b/core/src/test/java/com/google/adk/runner/RunnerTest.java index 3abfbdc20..00d5d63bf 100644 --- a/core/src/test/java/com/google/adk/runner/RunnerTest.java +++ b/core/src/test/java/com/google/adk/runner/RunnerTest.java @@ -1366,7 +1366,7 @@ public void runAsync_createsToolSpansWithCorrectParent() { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); List toolSpans = - spans.stream().filter(s -> s.getName().equals("execute_tool [echo_tool]")).toList(); + spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList(); assertThat(llmSpans).hasSize(2); assertThat(toolSpans).hasSize(1); @@ -1401,7 +1401,7 @@ public void runLive_createsToolSpansWithCorrectParent() throws Exception { List spans = openTelemetryRule.getSpans(); List llmSpans = spans.stream().filter(s -> s.getName().equals("call_llm")).toList(); List toolSpans = - spans.stream().filter(s -> 
s.getName().equals("execute_tool [echo_tool]")).toList(); + spans.stream().filter(s -> s.getName().equals("execute_tool echo_tool")).toList(); // In runLive, there is one call_llm span for the execution assertThat(llmSpans).hasSize(1); diff --git a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java index 44877e972..331ae77b2 100644 --- a/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java +++ b/core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java @@ -471,7 +471,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent test_agent // ├── call_llm - // │ └── execute_tool [search_flights] + // │ └── execute_tool search_flights // └── call_llm SearchFlightsTool searchFlightsTool = new SearchFlightsTool(); @@ -499,7 +499,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData invokeAgent = findSpanByName("invoke_agent test_agent"); - SpanData toolResponse = findSpanByName("execute_tool [search_flights]"); + SpanData toolResponse = findSpanByName("execute_tool search_flights"); List callLlmSpans = openTelemetryRule.getSpans().stream() .filter(s -> s.getName().equals("call_llm")) @@ -515,7 +515,7 @@ public void testAgentWithToolCallTraceHierarchy() throws InterruptedException { assertParent(invocation, invokeAgent); // ├── call_llm 1 assertParent(invokeAgent, callLlm1); - // │ └── execute_tool [search_flights] + // │ └── execute_tool search_flights assertParent(callLlm1, toolResponse); // └── call_llm 2 assertParent(invokeAgent, callLlm2); @@ -546,7 +546,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { // invocation // └── invoke_agent AgentA // ├── call_llm - // │ └── execute_tool [transfer_to_agent] + // │ └── execute_tool transfer_to_agent // └── invoke_agent 
AgentB // └── call_llm TestLlm llm = @@ -573,7 +573,7 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData invocation = findSpanByName("invocation"); SpanData agentASpan = findSpanByName("invoke_agent AgentA"); - SpanData executeTool = findSpanByName("execute_tool [transfer_to_agent]"); + SpanData executeTool = findSpanByName("execute_tool transfer_to_agent"); SpanData agentBSpan = findSpanByName("invoke_agent AgentB"); List callLlmSpans = @@ -586,10 +586,17 @@ public void testNestedAgentTraceHierarchy() throws InterruptedException { SpanData agentACallLlm1 = callLlmSpans.get(0); SpanData agentBCallLlm = callLlmSpans.get(1); + // Assert hierarchy: + // invocation + // └── invoke_agent AgentA assertParent(invocation, agentASpan); + // └── call_llm 1 assertParent(agentASpan, agentACallLlm1); + // ├── execute_tool transfer_to_agent assertParent(agentACallLlm1, executeTool); - assertParent(agentASpan, agentBSpan); + // └── invoke_agent AgentB + assertParent(agentACallLlm1, agentBSpan); + // └── call_llm 2 assertParent(agentBSpan, agentBCallLlm); } diff --git a/core/src/test/java/com/google/adk/testing/TestCallback.java b/core/src/test/java/com/google/adk/testing/TestCallback.java index 6f35f5a3c..403e3874a 100644 --- a/core/src/test/java/com/google/adk/testing/TestCallback.java +++ b/core/src/test/java/com/google/adk/testing/TestCallback.java @@ -91,7 +91,7 @@ public Supplier> asRunAsyncImplSupplier(Content content) { Flowable.defer( () -> { markAsCalled(); - return Flowable.just(Event.builder().content(content).build()); + return Flowable.just(Event.builder().author("testAgent").content(content).build()); }); } @@ -111,7 +111,7 @@ public Supplier> asRunLiveImplSupplier(Content content) { Flowable.defer( () -> { markAsCalled(); - return Flowable.just(Event.builder().content(content).build()); + return Flowable.just(Event.builder().author("testAgent").content(content).build()); }); }