diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java index 16e38e06..346f2079 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java @@ -81,4 +81,19 @@ AIJudgeConfig judgeConfig( LDContext context, AIJudgeConfigDefault defaultValue, Map variables); + + /** + * Reconstructs a tracker from a resumption token produced by + * {@link LDAIConfigTracker#getResumptionToken()}. + *

+ * The reconstructed tracker shares the original run's {@code runId}, so events it emits (for + * example deferred user feedback recorded in another process) correlate with the original AI run. + * Model and provider names are not carried in the token and are reported as empty strings. + * + * @param resumptionToken the token to reconstruct from + * @param context the context the tracker's events will be attributed to + * @return a tracker sharing the original run's identity + * @throws IllegalArgumentException if the token is malformed + */ + LDAIConfigTracker createTracker(String resumptionToken, LDContext context); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java index 3269621b..2e28ad65 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java @@ -8,13 +8,15 @@ import com.launchdarkly.sdk.LDContext; import com.launchdarkly.sdk.LDValue; import com.launchdarkly.sdk.LDValueType; -import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Message; +import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; +import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model; +import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Provider; import com.launchdarkly.sdk.server.ai.internal.AIConfigFlagValue; import com.launchdarkly.sdk.server.ai.internal.AIConfigParser; import com.launchdarkly.sdk.server.ai.internal.AISdkInfo; import com.launchdarkly.sdk.server.ai.internal.Interpolator; -import com.launchdarkly.sdk.server.ai.internal.NoOpAIConfigTracker; +import com.launchdarkly.sdk.server.ai.internal.LDAIConfigTrackerImpl; import com.launchdarkly.sdk.server.interfaces.LDClientInterface; import java.util.ArrayList; @@ -22,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.UUID; import java.util.function.Supplier; /** @@ -51,9 +54,6 @@ public final class LDAIClientImpl implements LDAIClient { .anonymous(true) .build(); - // Tracking is implemented in a later step; until then every config hands out the no-op tracker. - private static final Supplier TRACKER_FACTORY = () -> NoOpAIConfigTracker.INSTANCE; - private final LDClientInterface client; private final LDLogger logger; private final Interpolator interpolator; @@ -146,6 +146,11 @@ public AIJudgeConfig judgeConfig( return (AIJudgeConfig) evaluate(key, context, effectiveDefault, Mode.JUDGE, variables); } + @Override + public LDAIConfigTracker createTracker(String resumptionToken, LDContext context) { + return LDAIConfigTrackerImpl.fromResumptionToken(resumptionToken, client, context, logger); + } + private AIAgentConfig evaluateAgent( String key, LDContext context, AIAgentConfigDefault defaultValue, Map variables) { AIAgentConfigDefault effectiveDefault = @@ -180,7 +185,7 @@ private AIConfig evaluate( logger.warn( "AI Config mode mismatch for {}: expected {}, got {}. Returning disabled config.", key, mode.getWireValue(), flagMode.getWireValue()); - return disabledConfig(key, mode); + return disabledConfig(key, mode, context); } return buildConfig(key, mode, parsed, context, variables); @@ -192,6 +197,8 @@ private AIConfig buildConfig( AIConfigFlagValue parsed, LDContext context, Map variables) { + Supplier trackerFactory = trackerFactory( + key, parsed.getVariationKey(), parsed.getVersion(), parsed.getModel(), parsed.getProvider(), context); switch (mode) { case AGENT: return new AIAgentConfig( @@ -202,7 +209,7 @@ private AIConfig buildConfig( interpolate(parsed.getInstructions(), variables, context), parsed.getJudgeConfiguration(), parsed.getTools(), - TRACKER_FACTORY); + trackerFactory); case JUDGE: return new AIJudgeConfig( key, @@ -211,7 +218,7 @@ private AIConfig buildConfig( parsed.getProvider(), interpolateMessages(parsed.getMessages(), variables, context), parsed.getEvaluationMetricKey(), - TRACKER_FACTORY); + trackerFactory); case COMPLETION: default: return new AICompletionConfig( @@ -222,7 +229,7 @@ private AIConfig buildConfig( interpolateMessages(parsed.getMessages(), variables, context), parsed.getJudgeConfiguration(), parsed.getTools(), - TRACKER_FACTORY); + trackerFactory); } } @@ -247,7 +254,7 @@ private AIConfig buildConfigFromDefault( interpolate(agent.getInstructions(), variables, context), agent.getJudgeConfiguration(), agent.getTools(), - TRACKER_FACTORY); + trackerFactory(key, null, null, agent.getModel(), agent.getProvider(), context)); } case JUDGE: { AIJudgeConfigDefault judge = (AIJudgeConfigDefault) defaultValue; @@ -258,7 +265,7 @@ private AIConfig buildConfigFromDefault( judge.getProvider(), interpolateMessages(judge.getMessages(), variables, context), judge.getEvaluationMetricKey(), - TRACKER_FACTORY); + trackerFactory(key, null, null, judge.getModel(), judge.getProvider(), context)); } case COMPLETION: default: { @@ -271,23 +278,38 @@ private AIConfig buildConfigFromDefault( interpolateMessages(completion.getMessages(), variables, context), completion.getJudgeConfiguration(), completion.getTools(), - TRACKER_FACTORY); + trackerFactory(key, null, null, completion.getModel(), completion.getProvider(), context)); } } } - private AIConfig disabledConfig(String key, Mode mode) { + private AIConfig disabledConfig(String key, Mode mode, LDContext context) { + Supplier trackerFactory = trackerFactory(key, null, null, null, null, context); switch (mode) { case AGENT: - return new AIAgentConfig(key, false, null, null, null, null, null, TRACKER_FACTORY); + return new AIAgentConfig(key, false, null, null, null, null, null, trackerFactory); case JUDGE: - return new AIJudgeConfig(key, false, null, null, null, null, TRACKER_FACTORY); + return new AIJudgeConfig(key, false, null, null, null, null, trackerFactory); case COMPLETION: default: - return new AICompletionConfig(key, false, null, null, null, null, null, TRACKER_FACTORY); + return new AICompletionConfig(key, false, null, null, null, null, null, trackerFactory); } } + /** + * Builds a factory that produces a fresh tracker, with a new {@code runId}, on each call. The + * factory captures the config's correlation data and the evaluation context. + */ + private Supplier trackerFactory( + String key, String variationKey, Integer version, Model model, Provider provider, LDContext context) { + String varKey = variationKey == null ? "" : variationKey; + int ver = version == null ? 0 : version; + String modelName = model != null && model.getName() != null ? model.getName() : ""; + String providerName = provider != null && provider.getName() != null ? provider.getName() : ""; + return () -> new LDAIConfigTrackerImpl( + client, UUID.randomUUID().toString(), key, varKey, ver, modelName, providerName, context, null, logger); + } + private List interpolateMessages( List messages, Map variables, LDContext context) { if (messages == null) { diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java index a298e33b..b97b6f25 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java @@ -1,16 +1,167 @@ package com.launchdarkly.sdk.server.ai; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.Metrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.function.Function; + /** - * Reports events related to a single AI run of an {@link AIConfig}. + * Reports metrics related to a single AI run of an {@link AIConfig}. *

- * A tracker is obtained from a retrieved config via {@link AIConfig#createTracker()}. Each tracker - * corresponds to one AI run and is used to record metrics such as model usage, duration, and - * feedback against the AI Config it was created from. + * A tracker is obtained from a retrieved config via {@link AIConfig#createTracker()}, or + * reconstructed across process boundaries via + * {@link LDAIClient#createTracker(String, com.launchdarkly.sdk.LDContext)}. Each tracker corresponds + * to one AI run; every event it emits shares a {@code runId} (a UUIDv4) so LaunchDarkly can + * correlate them in metrics views. Start a new run by calling {@link AIConfig#createTracker()} again. *

- * This interface is an intentional placeholder. The metric- and feedback-reporting - * methods (and resumption-token support) are introduced in a later step of the AI SDK build-out; it - * is defined here so that the public config types expose a stable {@code createTracker()} surface. - * The only implementation in this release is an internal no-op. + * Thread-safety. Implementations are safe to share across threads. The + * "record-once" metrics ({@link #trackDuration}, {@link #trackTimeToFirstToken}, + * {@link #trackSuccess}/{@link #trackError}, {@link #trackFeedback}, {@link #trackTokens}) each emit + * at most once per tracker even under concurrent calls; later calls are ignored and logged. + * {@link #trackToolCall}/{@link #trackToolCalls} and {@link #trackJudgeResult} may be called any + * number of times and emit on every call. */ public interface LDAIConfigTracker { + /** + * Returns the correlation data attached to every event this tracker emits. + * + * @return the track data, never {@code null} + */ + TrackData getTrackData(); + + /** + * Returns a URL-safe Base64 token that encodes this tracker's {@code runId}, {@code configKey}, + * {@code variationKey}, and {@code version}. + *

+ * Pass it to {@link LDAIClient#createTracker(String, com.launchdarkly.sdk.LDContext)} to + * reconstruct a tracker in another process so deferred events (for example user feedback) still + * correlate with the original run. + * + * @return the resumption token, never {@code null} + */ + String getResumptionToken(); + + /** + * Records the duration of the generation. + *

+ * Records at most once per tracker; later calls are ignored. Negative durations (for example from + * clock skew) are clamped to zero. + * + * @param duration the generation duration; must not be {@code null} + */ + void trackDuration(Duration duration); + + /** + * Runs the given operation, recording its duration even if it throws. + *

+ * This does not record success or error; use {@link #trackMetricsOf} for that. Because + * {@link #trackDuration} records at most once, calling this twice on the same tracker re-runs the + * operation but emits no second duration event. + * + * @param operation the operation to time + * @param the operation's result type + * @return the operation's result + * @throws Exception if the operation throws + */ + T trackDurationOf(Callable operation) throws Exception; + + /** + * Records the time to first token for a streaming generation. + *

+ * Records at most once per tracker; later calls are ignored. Negative values are clamped to zero. + * + * @param duration the time to first token; must not be {@code null} + */ + void trackTimeToFirstToken(Duration duration); + + /** + * Records that the generation succeeded. + *

+ * Success and error share state: only the first of {@link #trackSuccess}/{@link #trackError} + * recorded on a tracker takes effect; later calls are ignored. + */ + void trackSuccess(); + + /** + * Records that the generation failed. + *

+ * Success and error share state: only the first of {@link #trackSuccess}/{@link #trackError} + * recorded on a tracker takes effect; later calls are ignored. + */ + void trackError(); + + /** + * Records end-user feedback about the generation. + *

+ * Records at most once per tracker; later calls are ignored. + * + * @param kind the feedback sentiment; must not be {@code null} + */ + void trackFeedback(FeedbackKind kind); + + /** + * Records token usage for the generation. + *

+ * Records at most once per tracker; later calls are ignored. Negative counts are clamped to zero, + * and an individual count is only emitted when it is greater than zero. + * + * @param tokens the token usage; must not be {@code null} + */ + void trackTokens(TokenUsage tokens); + + /** + * Records a single tool invocation. May be called any number of times. + * + * @param toolKey the identifier of the invoked tool; must not be {@code null} + */ + void trackToolCall(String toolKey); + + /** + * Records several tool invocations. May be called any number of times. + * + * @param toolKeys the identifiers of the invoked tools; must not be {@code null} + */ + void trackToolCalls(List toolKeys); + + /** + * Records a judge evaluation result. May be called any number of times. + *

+ * No event is emitted when the result was not sampled, did not succeed, or carries no metric key + * or score. A {@code null} score is treated as "no score" and is distinct from {@code 0.0}. + * + * @param result the judge result; must not be {@code null} + */ + void trackJudgeResult(JudgeResult result); + + /** + * Runs the given operation, recording its duration and then its outcome and metrics. + *

+ * The operation is timed via {@link #trackDurationOf}. If it throws, an error is recorded and the + * exception is rethrown. Otherwise the extractor is applied to the result; if the extractor + * throws, an error is recorded and the exception is rethrown. On success the extracted metrics + * drive {@link #trackSuccess}/{@link #trackError}, {@link #trackTokens}, and + * {@link #trackToolCalls}. + * + * @param metricsExtractor extracts {@link Metrics} from the operation's result + * @param operation the AI operation to run + * @param the operation's result type + * @return the operation's result + * @throws Exception if the operation or the extractor throws + */ + T trackMetricsOf(Function metricsExtractor, Callable operation) + throws Exception; + + /** + * Returns an immutable snapshot of the metrics recorded on this tracker so far. + * + * @return the metric summary, never {@code null} + */ + MetricSummary getSummary(); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java new file mode 100644 index 00000000..6f22c05a --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java @@ -0,0 +1,766 @@ +package com.launchdarkly.sdk.server.ai.datamodel; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Container for the value types used when reporting metrics about an AI run through an + * {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker}. + *

+ * These are simple, immutable data carriers (token usage, feedback sentiment, the metrics a caller + * extracts from a model response, a judge result, the per-run summary, and the data attached to + * every tracking event). They are grouped under one type to keep the public surface compact and to + * avoid claiming generic top-level names such as {@code TokenUsage} or {@code Metrics}. + *

+ * This class cannot be instantiated. + */ +public final class LDAITrackingTypes { + private LDAITrackingTypes() { + } + + /** + * Sentiment of end-user feedback about a generation. + */ + public enum FeedbackKind { + /** + * Positive sentiment. + */ + POSITIVE("positive"), + /** + * Negative sentiment. + */ + NEGATIVE("negative"); + + private final String value; + + FeedbackKind(String value) { + this.value = value; + } + + /** + * Returns the wire value used in event names for this sentiment. + * + * @return the wire value, either {@code "positive"} or {@code "negative"} + */ + public String getValue() { + return value; + } + } + + /** + * Token usage reported for a single generation. + *

+ * Counts are non-negative; negative inputs are clamped to zero when recorded. + */ + public static final class TokenUsage { + private final long total; + private final long input; + private final long output; + + /** + * Creates a token-usage record. + * + * @param total the combined token count + * @param input the number of input (prompt) tokens + * @param output the number of output (completion) tokens + */ + public TokenUsage(long total, long input, long output) { + this.total = total; + this.input = input; + this.output = output; + } + + /** + * Returns the combined token count. + * + * @return the total token count + */ + public long getTotal() { + return total; + } + + /** + * Returns the number of input (prompt) tokens. + * + * @return the input token count + */ + public long getInput() { + return input; + } + + /** + * Returns the number of output (completion) tokens. + * + * @return the output token count + */ + public long getOutput() { + return output; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TokenUsage)) { + return false; + } + TokenUsage that = (TokenUsage) o; + return total == that.total && input == that.input && output == that.output; + } + + @Override + public int hashCode() { + return Objects.hash(total, input, output); + } + + @Override + public String toString() { + return "TokenUsage{total=" + total + ", input=" + input + ", output=" + output + '}'; + } + } + + /** + * Metrics a caller extracts from an AI run, supplied to + * {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker#trackMetricsOf}. + */ + public static final class Metrics { + private final boolean success; + private final TokenUsage tokens; + private final List toolCalls; + private final Long durationMs; + + private Metrics(Builder b) { + this.success = b.success; + this.tokens = b.tokens; + this.toolCalls = b.toolCalls == null + ? null : Collections.unmodifiableList(new ArrayList<>(b.toolCalls)); + this.durationMs = b.durationMs; + } + + /** + * Returns whether the AI run was successful. + * + * @return {@code true} if the run succeeded + */ + public boolean isSuccess() { + return success; + } + + /** + * Returns the token usage for the run. + * + * @return the token usage, or {@code null} if not available + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns the identifiers of tools invoked during the run. + * + * @return an unmodifiable list of tool keys, or {@code null} if none were recorded + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Returns the measured duration of the run in milliseconds. + * + * @return the duration in milliseconds, or {@code null} if not measured + */ + public Long getDurationMs() { + return durationMs; + } + + /** + * Creates a builder for a metrics result. + * + * @param success whether the run was successful + * @return a new {@link Builder} + */ + public static Builder builder(boolean success) { + return new Builder(success); + } + + /** + * Builder for {@link Metrics}. + */ + public static final class Builder { + private final boolean success; + private TokenUsage tokens; + private List toolCalls; + private Long durationMs; + + private Builder(boolean success) { + this.success = success; + } + + /** + * Sets the token usage. + * + * @param v the token usage + * @return this builder + */ + public Builder tokens(TokenUsage v) { + this.tokens = v; + return this; + } + + /** + * Sets the tool-call identifiers. + * + * @param v the tool keys + * @return this builder + */ + public Builder toolCalls(List v) { + this.toolCalls = v; + return this; + } + + /** + * Sets the measured duration in milliseconds. + * + * @param v the duration in milliseconds + * @return this builder + */ + public Builder durationMs(Long v) { + this.durationMs = v; + return this; + } + + /** + * Builds the immutable {@link Metrics}. + * + * @return a new {@link Metrics} + */ + public Metrics build() { + return new Metrics(this); + } + } + } + + /** + * The outcome of a judge evaluation, supplied to + * {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker#trackJudgeResult}. + *

+ * A {@code null} {@link #getScore() score} means "no score was produced" and is distinct from a + * legitimate score of {@code 0.0}. + */ + public static final class JudgeResult { + private final String judgeConfigKey; + private final boolean success; + private final String errorMessage; + private final boolean sampled; + private final String metricKey; + private final Double score; + private final String reasoning; + + private JudgeResult(Builder b) { + this.judgeConfigKey = b.judgeConfigKey; + this.success = b.success; + this.errorMessage = b.errorMessage; + this.sampled = b.sampled; + this.metricKey = b.metricKey; + this.score = b.score; + this.reasoning = b.reasoning; + } + + /** + * Returns the key of the judge configuration that produced this result. + * + * @return the judge config key, or {@code null} if not set + */ + public String getJudgeConfigKey() { + return judgeConfigKey; + } + + /** + * Returns whether the evaluation completed successfully. + * + * @return {@code true} if the evaluation succeeded + */ + public boolean isSuccess() { + return success; + } + + /** + * Returns the error message when the evaluation failed. + * + * @return the error message, or {@code null} if there was none + */ + public String getErrorMessage() { + return errorMessage; + } + + /** + * Returns whether the evaluation was sampled (actually run) rather than skipped. + * + * @return {@code true} if the evaluation was sampled + */ + public boolean isSampled() { + return sampled; + } + + /** + * Returns the metric key the score is reported against. + * + * @return the metric key, or {@code null} if not set + */ + public String getMetricKey() { + return metricKey; + } + + /** + * Returns the score, between {@code 0.0} and {@code 1.0}. + * + * @return the score, or {@code null} if no score was produced + */ + public Double getScore() { + return score; + } + + /** + * Returns the reasoning behind the score. + * + * @return the reasoning, or {@code null} if not provided + */ + public String getReasoning() { + return reasoning; + } + + /** + * Creates a builder for a judge result. + * + * @param sampled whether the evaluation was sampled + * @param success whether the evaluation succeeded + * @return a new {@link Builder} + */ + public static Builder builder(boolean sampled, boolean success) { + return new Builder(sampled, success); + } + + /** + * Builder for {@link JudgeResult}. + */ + public static final class Builder { + private final boolean sampled; + private final boolean success; + private String judgeConfigKey; + private String errorMessage; + private String metricKey; + private Double score; + private String reasoning; + + private Builder(boolean sampled, boolean success) { + this.sampled = sampled; + this.success = success; + } + + /** + * Sets the judge configuration key. + * + * @param v the judge config key + * @return this builder + */ + public Builder judgeConfigKey(String v) { + this.judgeConfigKey = v; + return this; + } + + /** + * Sets the error message. + * + * @param v the error message + * @return this builder + */ + public Builder errorMessage(String v) { + this.errorMessage = v; + return this; + } + + /** + * Sets the metric key. + * + * @param v the metric key + * @return this builder + */ + public Builder metricKey(String v) { + this.metricKey = v; + return this; + } + + /** + * Sets the score. + * + * @param v the score + * @return this builder + */ + public Builder score(Double v) { + this.score = v; + return this; + } + + /** + * Sets the reasoning. + * + * @param v the reasoning + * @return this builder + */ + public Builder reasoning(String v) { + this.reasoning = v; + return this; + } + + /** + * Builds the immutable {@link JudgeResult}. + * + * @return a new {@link JudgeResult} + */ + public JudgeResult build() { + return new JudgeResult(this); + } + } + } + + /** + * A snapshot summary of the metrics recorded on a tracker, returned by + * {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker#getSummary}. + *

+ * Fields are {@code null} when the corresponding metric was never recorded. + */ + public static final class MetricSummary { + private final Boolean success; + private final TokenUsage tokens; + private final List toolCalls; + private final Long durationMs; + private final Long timeToFirstTokenMs; + private final FeedbackKind feedback; + private final String resumptionToken; + + private MetricSummary(Builder b) { + this.success = b.success; + this.tokens = b.tokens; + this.toolCalls = b.toolCalls == null + ? null : Collections.unmodifiableList(new ArrayList<>(b.toolCalls)); + this.durationMs = b.durationMs; + this.timeToFirstTokenMs = b.timeToFirstTokenMs; + this.feedback = b.feedback; + this.resumptionToken = b.resumptionToken; + } + + /** + * Returns whether the run was recorded as successful. + * + * @return {@code true}/{@code false} once recorded, or {@code null} if neither success nor + * error was recorded + */ + public Boolean getSuccess() { + return success; + } + + /** + * Returns the recorded token usage. + * + * @return the token usage, or {@code null} if not recorded + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns an immutable snapshot of the tool calls recorded so far. + * + * @return an unmodifiable list of tool keys, or {@code null} if none were recorded + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Returns the recorded generation duration in milliseconds. + * + * @return the duration in milliseconds, or {@code null} if not recorded + */ + public Long getDurationMs() { + return durationMs; + } + + /** + * Returns the recorded time-to-first-token in milliseconds. + * + * @return the time to first token in milliseconds, or {@code null} if not recorded + */ + public Long getTimeToFirstTokenMs() { + return timeToFirstTokenMs; + } + + /** + * Returns the recorded feedback sentiment. + * + * @return the feedback, or {@code null} if not recorded + */ + public FeedbackKind getFeedback() { + return feedback; + } + + /** + * Returns the resumption token for the run this summary belongs to. + * + * @return the resumption token + */ + public String getResumptionToken() { + return resumptionToken; + } + + /** + * Creates a new builder. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link MetricSummary}. + */ + public static final class Builder { + private Boolean success; + private TokenUsage tokens; + private List toolCalls; + private Long durationMs; + private Long timeToFirstTokenMs; + private FeedbackKind feedback; + private String resumptionToken; + + private Builder() { + } + + /** + * Sets the success flag. + * + * @param v the success flag + * @return this builder + */ + public Builder success(Boolean v) { + this.success = v; + return this; + } + + /** + * Sets the token usage. + * + * @param v the token usage + * @return this builder + */ + public Builder tokens(TokenUsage v) { + this.tokens = v; + return this; + } + + /** + * Sets the tool-call identifiers. + * + * @param v the tool keys + * @return this builder + */ + public Builder toolCalls(List v) { + this.toolCalls = v; + return this; + } + + /** + * Sets the generation duration in milliseconds. + * + * @param v the duration in milliseconds + * @return this builder + */ + public Builder durationMs(Long v) { + this.durationMs = v; + return this; + } + + /** + * Sets the time-to-first-token in milliseconds. + * + * @param v the time to first token in milliseconds + * @return this builder + */ + public Builder timeToFirstTokenMs(Long v) { + this.timeToFirstTokenMs = v; + return this; + } + + /** + * Sets the feedback sentiment. + * + * @param v the feedback + * @return this builder + */ + public Builder feedback(FeedbackKind v) { + this.feedback = v; + return this; + } + + /** + * Sets the resumption token. + * + * @param v the resumption token + * @return this builder + */ + public Builder resumptionToken(String v) { + this.resumptionToken = v; + return this; + } + + /** + * Builds the immutable {@link MetricSummary}. + * + * @return a new {@link MetricSummary} + */ + public MetricSummary build() { + return new MetricSummary(this); + } + } + } + + /** + * The correlation data attached to every event a tracker emits. + *

+ * All events for one AI run share a {@code runId} so LaunchDarkly can correlate them. + */ + public static final class TrackData { + private final String runId; + private final String configKey; + private final String variationKey; + private final int version; + private final String modelName; + private final String providerName; + private final String graphKey; + + /** + * Creates a track-data record. + * + * @param runId the per-run UUID shared by all of the run's events + * @param configKey the AI Config key + * @param variationKey the variation key, or empty string when unknown + * @param version the AI Config version + * @param modelName the model name, or empty string when unknown + * @param providerName the provider name, or empty string when unknown + * @param graphKey the graph key, or {@code null} when not part of a graph + */ + public TrackData( + String runId, + String configKey, + String variationKey, + int version, + String modelName, + String providerName, + String graphKey) { + this.runId = runId; + this.configKey = configKey; + this.variationKey = variationKey; + this.version = version; + this.modelName = modelName; + this.providerName = providerName; + this.graphKey = graphKey; + } + + /** + * Returns the per-run UUID shared by all of the run's events. + * + * @return the run id + */ + public String getRunId() { + return runId; + } + + /** + * Returns the AI Config key. + * + * @return the config key + */ + public String getConfigKey() { + return configKey; + } + + /** + * Returns the variation key. + * + * @return the variation key, or empty string when unknown + */ + public String getVariationKey() { + return variationKey; + } + + /** + * Returns the AI Config version. + * + * @return the version + */ + public int getVersion() { + return version; + } + + /** + * Returns the model name. + * + * @return the model name, or empty string when unknown + */ + public String getModelName() { + return modelName; + } + + /** + * Returns the provider name. + * + * @return the provider name, or empty string when unknown + */ + public String getProviderName() { + return providerName; + } + + /** + * Returns the graph key. + * + * @return the graph key, or {@code null} when not part of a graph + */ + public String getGraphKey() { + return graphKey; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TrackData)) { + return false; + } + TrackData that = (TrackData) o; + return version == that.version + && Objects.equals(runId, that.runId) + && Objects.equals(configKey, that.configKey) + && Objects.equals(variationKey, that.variationKey) + && Objects.equals(modelName, that.modelName) + && Objects.equals(providerName, that.providerName) + && Objects.equals(graphKey, that.graphKey); + } + + @Override + public int hashCode() { + return Objects.hash(runId, configKey, variationKey, version, modelName, providerName, graphKey); + } + + @Override + public String toString() { + return "TrackData{runId=" + runId + ", configKey=" + configKey + ", variationKey=" + variationKey + + ", version=" + version + ", modelName=" + modelName + ", providerName=" + providerName + + ", graphKey=" + graphKey + '}'; + } + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java new file mode 100644 index 00000000..9cdfd34d --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -0,0 +1,359 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; +import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.Metrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Function; + +/** + * The production {@link LDAIConfigTracker}. Emits AI metric events through the base + * {@code LDClient} using a shared {@code runId} and the config's correlation data. + *

+ * Thread-safety. The class is safe for concurrent use. Each record-once metric is + * guarded by an {@link AtomicBoolean}: a writer claims the guard with {@code compareAndSet} before + * it stores the value and emits, so exactly one event is produced no matter how many threads call + * concurrently. {@code trackSuccess} and {@code trackError} share a single guard. Tool calls are + * accumulated in a {@link CopyOnWriteArrayList}; tool-call and judge-result events are not + * record-once and emit on every call. Summary fields are written only by the thread that wins the + * guard and are declared {@code volatile} so {@link #getSummary()} observes a consistent snapshot. + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +public final class LDAIConfigTrackerImpl implements LDAIConfigTracker { + private static final String DURATION_TOTAL = "$ld:ai:duration:total"; + private static final String TIME_TO_FIRST_TOKEN = "$ld:ai:tokens:ttf"; + private static final String TOKENS_TOTAL = "$ld:ai:tokens:total"; + private static final String TOKENS_INPUT = "$ld:ai:tokens:input"; + private static final String TOKENS_OUTPUT = "$ld:ai:tokens:output"; + private static final String GENERATION_SUCCESS = "$ld:ai:generation:success"; + private static final String GENERATION_ERROR = "$ld:ai:generation:error"; + private static final String FEEDBACK_POSITIVE = "$ld:ai:feedback:user:positive"; + private static final String FEEDBACK_NEGATIVE = "$ld:ai:feedback:user:negative"; + private static final String TOOL_CALL = "$ld:ai:tool_call"; + + private final LDClientInterface client; + private final LDLogger logger; + private final LDContext context; + private final String runId; + private final String configKey; + private final String variationKey; + private final int version; + private final String modelName; + private final String providerName; + private final String graphKey; + private final String resumptionToken; + + private final AtomicBoolean durationRecorded = new AtomicBoolean(false); + private final AtomicBoolean timeToFirstTokenRecorded = new AtomicBoolean(false); + private final AtomicBoolean outcomeRecorded = new AtomicBoolean(false); + private final AtomicBoolean feedbackRecorded = new AtomicBoolean(false); + private final AtomicBoolean tokensRecorded = new AtomicBoolean(false); + private final CopyOnWriteArrayList toolCalls = new CopyOnWriteArrayList<>(); + + private volatile Long durationMs; + private volatile Long timeToFirstTokenMs; + private volatile Boolean success; + private volatile FeedbackKind feedback; + private volatile TokenUsage tokens; + + /** + * Creates a tracker for a single AI run. + * + * @param client the base client used to emit events; must not be {@code null} + * @param runId the per-run UUID shared by all of the run's events + * @param configKey the AI Config key + * @param variationKey the variation key, or empty string when unknown + * @param version the AI Config version + * @param modelName the model name, or empty string when unknown + * @param providerName the provider name, or empty string when unknown + * @param context the evaluation context the events are attributed to + * @param graphKey the graph key, or {@code null} when not part of a graph + * @param logger the logger used for skip warnings; must not be {@code null} + */ + public LDAIConfigTrackerImpl( + LDClientInterface client, + String runId, + String configKey, + String variationKey, + int version, + String modelName, + String providerName, + LDContext context, + String graphKey, + LDLogger logger) { + this.client = client; + this.runId = runId; + this.configKey = configKey; + this.variationKey = variationKey == null ? "" : variationKey; + this.version = version; + this.modelName = modelName == null ? "" : modelName; + this.providerName = providerName == null ? "" : providerName; + this.context = context; + this.graphKey = graphKey; + this.logger = logger; + this.resumptionToken = + ResumptionTokens.encode(this.runId, this.configKey, this.variationKey, this.version, this.graphKey); + } + + /** + * Reconstructs a tracker from a resumption token so deferred events correlate with the original + * run. The restored tracker shares the original {@code runId} but reports empty model and provider + * names, which are not carried in the token. + * + * @param token the resumption token + * @param client the base client used to emit events + * @param context the evaluation context the events are attributed to + * @param logger the logger used for skip warnings + * @return a tracker sharing the original run's identity + * @throws IllegalArgumentException if the token is malformed + */ + public static LDAIConfigTrackerImpl fromResumptionToken( + String token, LDClientInterface client, LDContext context, LDLogger logger) { + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + return new LDAIConfigTrackerImpl( + client, d.getRunId(), d.getConfigKey(), d.getVariationKey(), d.getVersion(), + "", "", context, d.getGraphKey(), logger); + } + + @Override + public TrackData getTrackData() { + return new TrackData(runId, configKey, variationKey, version, modelName, providerName, graphKey); + } + + @Override + public String getResumptionToken() { + return resumptionToken; + } + + @Override + public void trackDuration(Duration duration) { + if (duration == null) { + logger.warn("Skipping trackDuration: duration was null."); + return; + } + if (!durationRecorded.compareAndSet(false, true)) { + logger.warn("Skipping trackDuration: duration already recorded on this tracker."); + return; + } + long ms = Math.max(0L, duration.toMillis()); + this.durationMs = ms; + client.trackMetric(DURATION_TOTAL, context, baseData().build(), ms); + } + + @Override + public T trackDurationOf(Callable operation) throws Exception { + if (operation == null) { + throw new NullPointerException("operation must not be null"); + } + long start = System.nanoTime(); + try { + return operation.call(); + } finally { + long elapsedMs = (System.nanoTime() - start) / 1_000_000L; + trackDuration(Duration.ofMillis(elapsedMs)); + } + } + + @Override + public void trackTimeToFirstToken(Duration duration) { + if (duration == null) { + logger.warn("Skipping trackTimeToFirstToken: duration was null."); + return; + } + if (!timeToFirstTokenRecorded.compareAndSet(false, true)) { + logger.warn("Skipping trackTimeToFirstToken: time to first token already recorded on this tracker."); + return; + } + long ms = Math.max(0L, duration.toMillis()); + this.timeToFirstTokenMs = ms; + client.trackMetric(TIME_TO_FIRST_TOKEN, context, baseData().build(), ms); + } + + @Override + public void trackSuccess() { + if (!outcomeRecorded.compareAndSet(false, true)) { + logger.warn("Skipping trackSuccess: success or error already recorded on this tracker."); + return; + } + this.success = Boolean.TRUE; + client.trackMetric(GENERATION_SUCCESS, context, baseData().build(), 1); + } + + @Override + public void trackError() { + if (!outcomeRecorded.compareAndSet(false, true)) { + logger.warn("Skipping trackError: success or error already recorded on this tracker."); + return; + } + this.success = Boolean.FALSE; + client.trackMetric(GENERATION_ERROR, context, baseData().build(), 1); + } + + @Override + public void trackFeedback(FeedbackKind kind) { + if (kind == null) { + logger.warn("Skipping trackFeedback: feedback kind was null."); + return; + } + if (!feedbackRecorded.compareAndSet(false, true)) { + logger.warn("Skipping trackFeedback: feedback already recorded on this tracker."); + return; + } + this.feedback = kind; + String event = kind == FeedbackKind.POSITIVE ? FEEDBACK_POSITIVE : FEEDBACK_NEGATIVE; + client.trackMetric(event, context, baseData().build(), 1); + } + + @Override + public void trackTokens(TokenUsage rawTokens) { + if (rawTokens == null) { + logger.warn("Skipping trackTokens: token usage was null."); + return; + } + if (!tokensRecorded.compareAndSet(false, true)) { + logger.warn("Skipping trackTokens: token usage already recorded on this tracker."); + return; + } + long total = Math.max(0L, rawTokens.getTotal()); + long input = Math.max(0L, rawTokens.getInput()); + long output = Math.max(0L, rawTokens.getOutput()); + this.tokens = new TokenUsage(total, input, output); + if (total > 0L) { + client.trackMetric(TOKENS_TOTAL, context, baseData().build(), total); + } + if (input > 0L) { + client.trackMetric(TOKENS_INPUT, context, baseData().build(), input); + } + if (output > 0L) { + client.trackMetric(TOKENS_OUTPUT, context, baseData().build(), output); + } + } + + @Override + public void trackToolCall(String toolKey) { + if (toolKey == null) { + logger.warn("Skipping trackToolCall: tool key was null."); + return; + } + toolCalls.add(toolKey); + client.trackMetric(TOOL_CALL, context, baseData().put("toolKey", toolKey).build(), 1); + } + + @Override + public void trackToolCalls(List toolKeys) { + if (toolKeys == null) { + logger.warn("Skipping trackToolCalls: tool keys were null."); + return; + } + for (String toolKey : toolKeys) { + trackToolCall(toolKey); + } + } + + @Override + public void trackJudgeResult(JudgeResult result) { + if (result == null) { + logger.warn("Skipping trackJudgeResult: result was null."); + return; + } + if (!result.isSampled() || !result.isSuccess()) { + return; + } + if (result.getMetricKey() == null || result.getScore() == null) { + return; + } + ObjectBuilder data = baseData(); + if (result.getJudgeConfigKey() != null) { + data.put("judgeConfigKey", result.getJudgeConfigKey()); + } + client.trackMetric(result.getMetricKey(), context, data.build(), result.getScore()); + } + + @Override + public T trackMetricsOf(Function metricsExtractor, Callable operation) + throws Exception { + if (metricsExtractor == null) { + throw new NullPointerException("metricsExtractor must not be null"); + } + if (operation == null) { + throw new NullPointerException("operation must not be null"); + } + + T result; + try { + result = trackDurationOf(operation); + } catch (Exception e) { + trackError(); + throw e; + } + + Metrics metrics; + try { + metrics = metricsExtractor.apply(result); + } catch (RuntimeException e) { + trackError(); + throw e; + } + + if (metrics != null) { + if (metrics.isSuccess()) { + trackSuccess(); + } else { + trackError(); + } + if (metrics.getTokens() != null) { + trackTokens(metrics.getTokens()); + } + List calls = metrics.getToolCalls(); + if (calls != null && !calls.isEmpty()) { + trackToolCalls(calls); + } + } + return result; + } + + @Override + public MetricSummary getSummary() { + MetricSummary.Builder b = MetricSummary.builder() + .success(success) + .tokens(tokens) + .durationMs(durationMs) + .timeToFirstTokenMs(timeToFirstTokenMs) + .feedback(feedback) + .resumptionToken(resumptionToken); + if (!toolCalls.isEmpty()) { + b.toolCalls(new ArrayList<>(toolCalls)); + } + return b.build(); + } + + private ObjectBuilder baseData() { + ObjectBuilder b = LDValue.buildObject() + .put("runId", runId) + .put("configKey", configKey) + .put("variationKey", variationKey) + .put("version", version) + .put("modelName", modelName) + .put("providerName", providerName); + if (graphKey != null) { + b.put("graphKey", graphKey); + } + return b; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java deleted file mode 100644 index 1cbc3c51..00000000 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java +++ /dev/null @@ -1,19 +0,0 @@ -package com.launchdarkly.sdk.server.ai.internal; - -import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; - -/** - * The no-op {@link LDAIConfigTracker} used until metric reporting is implemented in a later step of - * the AI SDK. It is immutable and stateless, so a single shared instance is safe to reuse. - *

- * This class is an internal implementation detail and is not part of the supported API. - */ -public final class NoOpAIConfigTracker implements LDAIConfigTracker { - /** - * The shared instance. - */ - public static final NoOpAIConfigTracker INSTANCE = new NoOpAIConfigTracker(); - - private NoOpAIConfigTracker() { - } -} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java new file mode 100644 index 00000000..bfb5ac47 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -0,0 +1,236 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.LDValueType; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +/** + * Encodes and decodes AI run resumption tokens. + *

+ * A token is the URL-safe Base64 (no padding) of a canonical JSON object whose keys appear in a + * fixed order: {@code runId, configKey, variationKey, version}, followed by {@code graphKey} only + * when it is set. {@code variationKey} is always present (empty string when unknown) so the encoding + * is byte-compatible with the other LaunchDarkly SDKs. {@code modelName} and {@code providerName} + * are intentionally not carried; a tracker reconstructed from a token reports them as empty strings. + *

+ * Decoding is strict: each field is type-validated, and malformed or oversized tokens are rejected + * with an {@link IllegalArgumentException}. + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +public final class ResumptionTokens { + /** + * Maximum size, in bytes, of a token's decoded JSON payload. Anything larger is rejected to bound + * the work done parsing untrusted input. + */ + static final int MAX_PAYLOAD_BYTES = 4096; + + private ResumptionTokens() { + } + + /** + * Encodes a resumption token. + * + * @param runId the per-run UUID + * @param configKey the AI Config key + * @param variationKey the variation key; encoded as empty string when {@code null} + * @param version the AI Config version + * @param graphKey the graph key, or {@code null} to omit it + * @return the URL-safe Base64 token + */ + public static String encode( + String runId, String configKey, String variationKey, int version, String graphKey) { + StringBuilder json = new StringBuilder(96); + json.append("{\"runId\":").append(jsonString(runId)) + .append(",\"configKey\":").append(jsonString(configKey)) + .append(",\"variationKey\":").append(jsonString(variationKey == null ? "" : variationKey)) + .append(",\"version\":").append(version); + if (graphKey != null) { + json.append(",\"graphKey\":").append(jsonString(graphKey)); + } + json.append('}'); + return Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.toString().getBytes(StandardCharsets.UTF_8)); + } + + /** + * Decodes and validates a resumption token. + * + * @param token the token to decode + * @return the decoded fields + * @throws IllegalArgumentException if the token is {@code null}, oversized, not valid Base64, not + * a JSON object, or missing/mistyped required fields + */ + public static Decoded decode(String token) { + if (token == null) { + throw new IllegalArgumentException("resumption token must not be null"); + } + byte[] bytes; + try { + bytes = Base64.getUrlDecoder().decode(token); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("malformed resumption token: invalid Base64", e); + } + if (bytes.length > MAX_PAYLOAD_BYTES) { + throw new IllegalArgumentException( + "malformed resumption token: payload exceeds " + MAX_PAYLOAD_BYTES + " bytes"); + } + + LDValue payload; + try { + payload = LDValue.parse(new String(bytes, StandardCharsets.UTF_8)); + } catch (RuntimeException e) { + throw new IllegalArgumentException("malformed resumption token: invalid JSON", e); + } + if (payload.getType() != LDValueType.OBJECT) { + throw new IllegalArgumentException("malformed resumption token: expected a JSON object"); + } + + String runId = requireString(payload, "runId"); + String configKey = requireString(payload, "configKey"); + String variationKey = optionalString(payload, "variationKey", ""); + int version = requireInt(payload, "version"); + String graphKey = optionalString(payload, "graphKey", null); + return new Decoded(runId, configKey, variationKey, version, graphKey); + } + + private static String requireString(LDValue object, String field) { + LDValue v = object.get(field); + if (!v.isString()) { + throw new IllegalArgumentException( + "malformed resumption token: field '" + field + "' must be a string"); + } + return v.stringValue(); + } + + private static String optionalString(LDValue object, String field, String defaultValue) { + LDValue v = object.get(field); + if (v.getType() == LDValueType.NULL) { + return defaultValue; + } + if (!v.isString()) { + throw new IllegalArgumentException( + "malformed resumption token: field '" + field + "' must be a string"); + } + return v.stringValue(); + } + + private static int requireInt(LDValue object, String field) { + LDValue v = object.get(field); + if (v.getType() != LDValueType.NUMBER || !v.isInt()) { + throw new IllegalArgumentException( + "malformed resumption token: field '" + field + "' must be an integer"); + } + return v.intValue(); + } + + /** + * Escapes a Java string as a JSON string literal using the standard JSON escapes (matching + * {@code JSON.stringify}), so the encoding stays byte-compatible across SDKs. + */ + private static String jsonString(String s) { + StringBuilder sb = new StringBuilder(s.length() + 2); + sb.append('"'); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + switch (c) { + case '"': + sb.append("\\\""); + break; + case '\\': + sb.append("\\\\"); + break; + case '\b': + sb.append("\\b"); + break; + case '\f': + sb.append("\\f"); + break; + case '\n': + sb.append("\\n"); + break; + case '\r': + sb.append("\\r"); + break; + case '\t': + sb.append("\\t"); + break; + default: + if (c < 0x20) { + sb.append(String.format("\\u%04x", (int) c)); + } else { + sb.append(c); + } + break; + } + } + sb.append('"'); + return sb.toString(); + } + + /** + * The validated fields decoded from a resumption token. + */ + public static final class Decoded { + private final String runId; + private final String configKey; + private final String variationKey; + private final int version; + private final String graphKey; + + Decoded(String runId, String configKey, String variationKey, int version, String graphKey) { + this.runId = runId; + this.configKey = configKey; + this.variationKey = variationKey; + this.version = version; + this.graphKey = graphKey; + } + + /** + * Returns the per-run UUID. + * + * @return the run id + */ + public String getRunId() { + return runId; + } + + /** + * Returns the AI Config key. + * + * @return the config key + */ + public String getConfigKey() { + return configKey; + } + + /** + * Returns the variation key. + * + * @return the variation key, never {@code null} + */ + public String getVariationKey() { + return variationKey; + } + + /** + * Returns the AI Config version. + * + * @return the version + */ + public int getVersion() { + return version; + } + + /** + * Returns the graph key. + * + * @return the graph key, or {@code null} when absent + */ + public String getGraphKey() { + return graphKey; + } + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java index 11253f45..c6d78206 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java @@ -278,4 +278,46 @@ public void agentConfigsHandlesEmptyList() { private static Map variables() { return new HashMap<>(); } + + // ---- tracker wiring ------------------------------------------------------- + + @Test + public void configTrackerCarriesVariationAndModelMetadata() { + String json = "{\"_ldMeta\":{\"enabled\":true,\"mode\":\"completion\",\"variationKey\":\"var-7\",\"version\":11}," + + "\"model\":{\"name\":\"gpt-4\"},\"provider\":{\"name\":\"openai\"}}"; + when(client.jsonValueVariation(anyString(), any(), any())).thenReturn(LDValue.parse(json)); + + AICompletionConfig config = ai.completionConfig("key", context, null, null); + LDAIConfigTracker tracker = config.createTracker(); + + assertThat(tracker.getTrackData().getConfigKey(), is("key")); + assertThat(tracker.getTrackData().getVariationKey(), is("var-7")); + assertThat(tracker.getTrackData().getVersion(), is(11)); + assertThat(tracker.getTrackData().getModelName(), is("gpt-4")); + assertThat(tracker.getTrackData().getProviderName(), is("openai")); + } + + @Test + public void createTrackerFromTokenSharesRunIdAndConfig() { + String json = "{\"_ldMeta\":{\"enabled\":true,\"mode\":\"completion\",\"variationKey\":\"var-7\",\"version\":11}}"; + when(client.jsonValueVariation(anyString(), any(), any())).thenReturn(LDValue.parse(json)); + + LDAIConfigTracker original = ai.completionConfig("key", context, null, null).createTracker(); + LDAIConfigTracker restored = ai.createTracker(original.getResumptionToken(), context); + + assertThat(restored.getTrackData().getRunId(), is(original.getTrackData().getRunId())); + assertThat(restored.getTrackData().getConfigKey(), is("key")); + assertThat(restored.getTrackData().getVariationKey(), is("var-7")); + assertThat(restored.getTrackData().getVersion(), is(11)); + } + + @Test + public void eachCreateTrackerCallStartsANewRun() { + when(client.jsonValueVariation(anyString(), any(), any())).thenReturn(LDValue.ofNull()); + AICompletionConfig config = ai.completionConfig("key", context, null, null); + String runA = config.createTracker().getTrackData().getRunId(); + String runB = config.createTracker().getTrackData().getRunId(); + assertThat(runA, is(notNullValue())); + assertThat(runA.equals(runB), is(false)); + } } diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java new file mode 100644 index 00000000..59b49c5f --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java @@ -0,0 +1,385 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.Metrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import org.junit.Before; +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class LDAIConfigTrackerImplTest { + private LDClientInterface client; + private LDLogger logger; + private final LDContext context = LDContext.create("user-key"); + private final List events = new CopyOnWriteArrayList<>(); + + private static final class Event { + final String name; + final LDValue data; + final double metric; + + Event(String name, LDValue data, double metric) { + this.name = name; + this.data = data; + this.metric = metric; + } + } + + @Before + public void setUp() { + client = mock(LDClientInterface.class); + logger = LDLogger.withAdapter(Logs.capture(), "test"); + doAnswer(inv -> { + events.add(new Event(inv.getArgument(0), inv.getArgument(2), inv.getArgument(3))); + return null; + }).when(client).trackMetric(anyString(), any(), any(), anyDouble()); + } + + private LDAIConfigTrackerImpl tracker() { + return new LDAIConfigTrackerImpl( + client, "run-1", "cfg", "v1", 3, "gpt-4", "openai", context, null, logger); + } + + private List named(String name) { + return events.stream().filter(e -> e.name.equals(name)).collect(Collectors.toList()); + } + + // ---- duration ------------------------------------------------------------- + + @Test + public void trackDurationEmitsOnce() { + LDAIConfigTrackerImpl t = tracker(); + t.trackDuration(Duration.ofMillis(120)); + t.trackDuration(Duration.ofMillis(999)); + List e = named("$ld:ai:duration:total"); + assertThat(e, hasSize(1)); + assertThat(e.get(0).metric, is(120.0)); + assertThat(e.get(0).data.get("runId").stringValue(), is("run-1")); + } + + @Test + public void trackDurationClampsNegative() { + LDAIConfigTrackerImpl t = tracker(); + t.trackDuration(Duration.ofMillis(-50)); + assertThat(named("$ld:ai:duration:total").get(0).metric, is(0.0)); + } + + @Test + public void trackDurationOfReturnsResultAndRecordsDuration() throws Exception { + LDAIConfigTrackerImpl t = tracker(); + String result = t.trackDurationOf(() -> "ok"); + assertThat(result, is("ok")); + assertThat(named("$ld:ai:duration:total"), hasSize(1)); + } + + @Test + public void trackDurationOfRecordsDurationEvenWhenOperationThrows() { + LDAIConfigTrackerImpl t = tracker(); + assertThrows(IllegalStateException.class, () -> t.trackDurationOf(() -> { + throw new IllegalStateException("boom"); + })); + assertThat(named("$ld:ai:duration:total"), hasSize(1)); + } + + // ---- time to first token -------------------------------------------------- + + @Test + public void trackTimeToFirstTokenEmitsOnce() { + LDAIConfigTrackerImpl t = tracker(); + t.trackTimeToFirstToken(Duration.ofMillis(40)); + t.trackTimeToFirstToken(Duration.ofMillis(80)); + List e = named("$ld:ai:tokens:ttf"); + assertThat(e, hasSize(1)); + assertThat(e.get(0).metric, is(40.0)); + } + + // ---- success / error ------------------------------------------------------ + + @Test + public void successAndErrorShareAtMostOnce() { + LDAIConfigTrackerImpl t = tracker(); + t.trackSuccess(); + t.trackError(); + assertThat(named("$ld:ai:generation:success"), hasSize(1)); + assertThat(named("$ld:ai:generation:error"), is(empty())); + } + + // ---- feedback ------------------------------------------------------------- + + @Test + public void trackFeedbackEmitsOnceForKind() { + LDAIConfigTrackerImpl t = tracker(); + t.trackFeedback(FeedbackKind.POSITIVE); + t.trackFeedback(FeedbackKind.NEGATIVE); + assertThat(named("$ld:ai:feedback:user:positive"), hasSize(1)); + assertThat(named("$ld:ai:feedback:user:negative"), is(empty())); + } + + @Test + public void trackFeedbackNullIsIgnored() { + LDAIConfigTrackerImpl t = tracker(); + t.trackFeedback(null); + assertThat(events, is(empty())); + } + + // ---- tokens --------------------------------------------------------------- + + @Test + public void trackTokensEmitsEachPositiveComponentOnce() { + LDAIConfigTrackerImpl t = tracker(); + t.trackTokens(new TokenUsage(30, 10, 20)); + t.trackTokens(new TokenUsage(99, 99, 99)); + assertThat(named("$ld:ai:tokens:total"), hasSize(1)); + assertThat(named("$ld:ai:tokens:total").get(0).metric, is(30.0)); + assertThat(named("$ld:ai:tokens:input").get(0).metric, is(10.0)); + assertThat(named("$ld:ai:tokens:output").get(0).metric, is(20.0)); + } + + @Test + public void trackTokensSkipsZeroAndClampsNegative() { + LDAIConfigTrackerImpl t = tracker(); + t.trackTokens(new TokenUsage(0, -5, 7)); + assertThat(named("$ld:ai:tokens:total"), is(empty())); + assertThat(named("$ld:ai:tokens:input"), is(empty())); + assertThat(named("$ld:ai:tokens:output").get(0).metric, is(7.0)); + } + + // ---- tool calls ----------------------------------------------------------- + + @Test + public void trackToolCallEmitsEachTimeWithToolKey() { + LDAIConfigTrackerImpl t = tracker(); + t.trackToolCall("search"); + t.trackToolCall("search"); + List e = named("$ld:ai:tool_call"); + assertThat(e, hasSize(2)); + assertThat(e.get(0).data.get("toolKey").stringValue(), is("search")); + } + + @Test + public void trackToolCallsRecordsAll() { + LDAIConfigTrackerImpl t = tracker(); + t.trackToolCalls(Arrays.asList("a", "b", "c")); + assertThat(named("$ld:ai:tool_call"), hasSize(3)); + assertThat(t.getSummary().getToolCalls(), contains("a", "b", "c")); + } + + @Test + public void trackToolCallNullIsIgnored() { + LDAIConfigTrackerImpl t = tracker(); + t.trackToolCall(null); + assertThat(events, is(empty())); + } + + // ---- judge result --------------------------------------------------------- + + @Test + public void trackJudgeResultEmitsScoreAgainstMetricKey() { + LDAIConfigTrackerImpl t = tracker(); + t.trackJudgeResult(JudgeResult.builder(true, true) + .metricKey("relevance").score(0.75).judgeConfigKey("judge-1").build()); + List e = named("relevance"); + assertThat(e, hasSize(1)); + assertThat(e.get(0).metric, is(0.75)); + assertThat(e.get(0).data.get("judgeConfigKey").stringValue(), is("judge-1")); + } + + @Test + public void trackJudgeResultEmitsForLegitimateZeroScore() { + LDAIConfigTrackerImpl t = tracker(); + t.trackJudgeResult(JudgeResult.builder(true, true).metricKey("relevance").score(0.0).build()); + assertThat(named("relevance"), hasSize(1)); + assertThat(named("relevance").get(0).metric, is(0.0)); + } + + @Test + public void trackJudgeResultSkippedWhenNotSampledOrNoScore() { + LDAIConfigTrackerImpl t = tracker(); + t.trackJudgeResult(JudgeResult.builder(false, true).metricKey("relevance").score(0.9).build()); + t.trackJudgeResult(JudgeResult.builder(true, true).metricKey("relevance").build()); + assertThat(events, is(empty())); + } + + // ---- trackMetricsOf ------------------------------------------------------- + + @Test + public void trackMetricsOfRecordsOutcomeTokensAndToolCalls() throws Exception { + LDAIConfigTrackerImpl t = tracker(); + String result = t.trackMetricsOf( + r -> Metrics.builder(true) + .tokens(new TokenUsage(15, 5, 10)) + .toolCalls(Arrays.asList("x")) + .build(), + () -> "answer"); + assertThat(result, is("answer")); + assertThat(named("$ld:ai:duration:total"), hasSize(1)); + assertThat(named("$ld:ai:generation:success"), hasSize(1)); + assertThat(named("$ld:ai:tokens:total"), hasSize(1)); + assertThat(named("$ld:ai:tool_call"), hasSize(1)); + } + + @Test + public void trackMetricsOfRecordsErrorAndRethrowsWhenOperationThrows() { + LDAIConfigTrackerImpl t = tracker(); + assertThrows(IllegalStateException.class, () -> t.trackMetricsOf( + r -> Metrics.builder(true).build(), + () -> { + throw new IllegalStateException("op failed"); + })); + assertThat(named("$ld:ai:generation:error"), hasSize(1)); + assertThat(named("$ld:ai:duration:total"), hasSize(1)); + } + + @Test + public void trackMetricsOfRecordsErrorAndRethrowsWhenExtractorThrows() { + LDAIConfigTrackerImpl t = tracker(); + assertThrows(RuntimeException.class, () -> t.trackMetricsOf( + r -> { + throw new RuntimeException("extractor failed"); + }, + () -> "answer")); + assertThat(named("$ld:ai:generation:error"), hasSize(1)); + } + + // ---- data / summary ------------------------------------------------------- + + @Test + public void getTrackDataExposesCorrelationFields() { + TrackData d = tracker().getTrackData(); + assertThat(d.getRunId(), is("run-1")); + assertThat(d.getConfigKey(), is("cfg")); + assertThat(d.getVariationKey(), is("v1")); + assertThat(d.getVersion(), is(3)); + assertThat(d.getModelName(), is("gpt-4")); + assertThat(d.getProviderName(), is("openai")); + } + + @Test + public void resumptionTokenRoundTrips() { + LDAIConfigTrackerImpl t = tracker(); + LDAIConfigTrackerImpl restored = + LDAIConfigTrackerImpl.fromResumptionToken(t.getResumptionToken(), client, context, logger); + assertThat(restored.getTrackData().getRunId(), is("run-1")); + assertThat(restored.getTrackData().getConfigKey(), is("cfg")); + assertThat(restored.getTrackData().getVariationKey(), is("v1")); + assertThat(restored.getTrackData().getVersion(), is(3)); + // Model and provider names are not carried in the token. + assertThat(restored.getTrackData().getModelName(), is("")); + assertThat(restored.getTrackData().getProviderName(), is("")); + } + + @Test + public void summaryReflectsRecordedMetrics() { + LDAIConfigTrackerImpl t = tracker(); + t.trackDuration(Duration.ofMillis(100)); + t.trackSuccess(); + t.trackTokens(new TokenUsage(20, 8, 12)); + t.trackFeedback(FeedbackKind.POSITIVE); + MetricSummary s = t.getSummary(); + assertThat(s.getDurationMs(), is(100L)); + assertThat(s.getSuccess(), is(true)); + assertThat(s.getTokens(), is(new TokenUsage(20, 8, 12))); + assertThat(s.getFeedback(), is(FeedbackKind.POSITIVE)); + assertThat(s.getResumptionToken(), is(t.getResumptionToken())); + } + + // ---- concurrency ---------------------------------------------------------- + + @Test + public void concurrentOutcomeRecordsExactlyOneEvent() throws Exception { + LDAIConfigTrackerImpl t = tracker(); + runConcurrently(32, i -> { + if (i % 2 == 0) { + t.trackSuccess(); + } else { + t.trackError(); + } + }); + int outcomes = named("$ld:ai:generation:success").size() + named("$ld:ai:generation:error").size(); + assertThat(outcomes, is(1)); + } + + @Test + public void concurrentDurationRecordsExactlyOnce() throws Exception { + LDAIConfigTrackerImpl t = tracker(); + runConcurrently(32, i -> t.trackDuration(Duration.ofMillis(10 + i))); + assertThat(named("$ld:ai:duration:total"), hasSize(1)); + } + + @Test + public void concurrentToolCallsRecordAllWithIntactList() throws Exception { + LDAIConfigTrackerImpl t = tracker(); + runConcurrently(50, i -> t.trackToolCall("tool-" + i)); + assertThat(named("$ld:ai:tool_call"), hasSize(50)); + List expected = new ArrayList<>(); + for (int i = 0; i < 50; i++) { + expected.add("tool-" + i); + } + assertThat(t.getSummary().getToolCalls(), containsInAnyOrder(expected.toArray())); + } + + private interface IndexedTask { + void run(int index); + } + + private static void runConcurrently(int threads, IndexedTask task) throws InterruptedException { + ExecutorService pool = Executors.newFixedThreadPool(threads); + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(threads); + AtomicInteger failures = new AtomicInteger(0); + for (int i = 0; i < threads; i++) { + final int index = i; + pool.execute(() -> { + ready.countDown(); + try { + go.await(); + task.run(index); + } catch (Throwable t) { + failures.incrementAndGet(); + } finally { + done.countDown(); + } + }); + } + ready.await(); + go.countDown(); + done.await(10, TimeUnit.SECONDS); + pool.shutdownNow(); + assertThat(failures.get(), is(0)); + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java new file mode 100644 index 00000000..b499177d --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java @@ -0,0 +1,135 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.Assert.assertThrows; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class ResumptionTokensTest { + // Byte-for-byte fixtures produced by base64url-encoding the canonical JSON, matching the other + // LaunchDarkly SDKs. If these change, cross-SDK resumption is broken. + private static final String FIXTURE_WITH_VARIATION = + "eyJydW5JZCI6InJ1bi0xIiwiY29uZmlnS2V5IjoiY2ZnIiwidmFyaWF0aW9uS2V5IjoidjEiLCJ2ZXJzaW9uIjozfQ"; + private static final String FIXTURE_EMPTY_VARIATION = + "eyJydW5JZCI6ImFiYyIsImNvbmZpZ0tleSI6Im15LWNvbmZpZyIsInZhcmlhdGlvbktleSI6IiIsInZlcnNpb24iOjB9"; + + @Test + public void encodeIsByteCompatibleWithFixture() { + assertThat(ResumptionTokens.encode("run-1", "cfg", "v1", 3, null), is(FIXTURE_WITH_VARIATION)); + } + + @Test + public void encodeAlwaysEmitsVariationKeyEvenWhenEmpty() { + assertThat(ResumptionTokens.encode("abc", "my-config", "", 0, null), is(FIXTURE_EMPTY_VARIATION)); + } + + @Test + public void encodeTreatsNullVariationKeyAsEmpty() { + assertThat(ResumptionTokens.encode("abc", "my-config", null, 0, null), is(FIXTURE_EMPTY_VARIATION)); + } + + @Test + public void canonicalKeyOrderIsFixed() { + String token = ResumptionTokens.encode("r", "c", "v", 7, "g"); + String json = new String(Base64.getUrlDecoder().decode(token), StandardCharsets.UTF_8); + assertThat(json, is("{\"runId\":\"r\",\"configKey\":\"c\",\"variationKey\":\"v\",\"version\":7,\"graphKey\":\"g\"}")); + } + + @Test + public void graphKeyOmittedWhenNull() { + String token = ResumptionTokens.encode("r", "c", "v", 7, null); + String json = new String(Base64.getUrlDecoder().decode(token), StandardCharsets.UTF_8); + assertThat(json, is("{\"runId\":\"r\",\"configKey\":\"c\",\"variationKey\":\"v\",\"version\":7}")); + } + + @Test + public void roundTripsAllFields() { + String token = ResumptionTokens.encode("run-9", "cfg-9", "var-9", 42, "graph-9"); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is("run-9")); + assertThat(d.getConfigKey(), is("cfg-9")); + assertThat(d.getVariationKey(), is("var-9")); + assertThat(d.getVersion(), is(42)); + assertThat(d.getGraphKey(), is("graph-9")); + } + + @Test + public void decodeFixtureWithVariation() { + ResumptionTokens.Decoded d = ResumptionTokens.decode(FIXTURE_WITH_VARIATION); + assertThat(d.getRunId(), is("run-1")); + assertThat(d.getConfigKey(), is("cfg")); + assertThat(d.getVariationKey(), is("v1")); + assertThat(d.getVersion(), is(3)); + assertThat(d.getGraphKey(), is(nullValue())); + } + + @Test + public void decodeDefaultsAbsentVariationKeyToEmpty() { + String token = base64Url("{\"runId\":\"r\",\"configKey\":\"c\",\"version\":1}"); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getVariationKey(), is("")); + } + + @Test + public void escapesSpecialCharactersInValues() { + String token = ResumptionTokens.encode("a\"b\\c\nd", "cfg", "v", 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is("a\"b\\c\nd")); + } + + @Test + public void rejectsNullToken() { + assertThrows(IllegalArgumentException.class, () -> ResumptionTokens.decode(null)); + } + + @Test + public void rejectsInvalidBase64() { + assertThrows(IllegalArgumentException.class, () -> ResumptionTokens.decode("!!!not base64!!!")); + } + + @Test + public void rejectsNonObjectJson() { + assertThrows(IllegalArgumentException.class, () -> ResumptionTokens.decode(base64Url("[1,2,3]"))); + } + + @Test + public void rejectsMissingRunId() { + assertThrows(IllegalArgumentException.class, + () -> ResumptionTokens.decode(base64Url("{\"configKey\":\"c\",\"version\":1}"))); + } + + @Test + public void rejectsMistypedVersion() { + assertThrows(IllegalArgumentException.class, + () -> ResumptionTokens.decode(base64Url("{\"runId\":\"r\",\"configKey\":\"c\",\"version\":\"x\"}"))); + } + + @Test + public void rejectsNonIntegerVersion() { + assertThrows(IllegalArgumentException.class, + () -> ResumptionTokens.decode(base64Url("{\"runId\":\"r\",\"configKey\":\"c\",\"version\":1.5}"))); + } + + @Test + public void rejectsOversizedPayload() { + StringBuilder big = new StringBuilder("{\"runId\":\""); + for (int i = 0; i < ResumptionTokens.MAX_PAYLOAD_BYTES + 100; i++) { + big.append('a'); + } + big.append("\",\"configKey\":\"c\",\"version\":1}"); + String token = base64Url(big.toString()); + assertThat(token, is(notNullValue())); + assertThrows(IllegalArgumentException.class, () -> ResumptionTokens.decode(token)); + } + + private static String base64Url(String json) { + return Base64.getUrlEncoder().withoutPadding().encodeToString(json.getBytes(StandardCharsets.UTF_8)); + } +}