diff --git a/core/src/main/java/io/substrait/expression/FieldReference.java b/core/src/main/java/io/substrait/expression/FieldReference.java index 4d525bcbc..173112942 100644 --- a/core/src/main/java/io/substrait/expression/FieldReference.java +++ b/core/src/main/java/io/substrait/expression/FieldReference.java @@ -40,10 +40,31 @@ public abstract class FieldReference implements Expression { /** * Returns the number of subquery levels stepped out of for an outer reference, if applicable. * + *

<p>This offset-based mechanism is used for tree-shaped plans. For plans where a relation is + * shared via a {@code ReferenceRel} (making the reference target ambiguous), use {@link + * #outerReferenceRelReference()} instead. When both forms are set, conversion to protobuf prefers + * the id-based {@link #outerReferenceRelReference()}. + * * @return the optional number of steps out */ public abstract Optional outerReferenceStepsOut(); + /** + * Returns the plan-wide unique {@code relAnchor} of the relation this outer reference is rooted + * on, if applicable. + * + *

<p>This id-based mechanism resolves outer references unambiguously in DAG-shaped plans where a + * relation is shared via a {@code ReferenceRel} and the offset-based {@link + * #outerReferenceStepsOut()} would be ambiguous. The value must match a {@link + * io.substrait.relation.Rel#getRelAnchor()} defined elsewhere in the plan. + * + *

This is the preferred outer-reference form: when it is set, conversion to protobuf emits it + * in favour of {@link #outerReferenceStepsOut()}. + * + * @return the optional referenced {@code relAnchor} + */ + public abstract Optional outerReferenceRelReference(); + /** * Returns the number of lambda nesting levels stepped out of for a lambda parameter reference, if * applicable. @@ -73,16 +94,28 @@ public R accept( } /** - * Validates that a field reference is not simultaneously an outer reference and a lambda - * parameter reference. + * Validates that this reference is not simultaneously an outer reference and a lambda parameter + * reference. + * + *

The two outer-reference forms — the offset-based {@link #outerReferenceStepsOut()} and the + * id-based {@link #outerReferenceRelReference()} — may both be set at once. During the + * transition towards id-based resolution a producer can dual-populate both forms; conversion to + * protobuf then prefers the id-based form (see {@code ExpressionProtoConverter}). This mirrors + * the Substrait + * breaking change policy, under which consumers prefer the new representation while the old + * one remains readable. * - * @throws IllegalArgumentException if both step-out values are set + * @throws IllegalArgumentException if this is both an outer reference and a lambda parameter + * reference */ @Value.Check protected void check() { - if (outerReferenceStepsOut().isPresent() && lambdaParameterReferenceStepsOut().isPresent()) { + boolean isOuterReference = + outerReferenceStepsOut().isPresent() || outerReferenceRelReference().isPresent(); + if (isOuterReference && lambdaParameterReferenceStepsOut().isPresent()) { throw new IllegalArgumentException( - "FieldReference cannot have both outerReferenceStepsOut and lambdaParameterReferenceStepsOut set"); + "FieldReference cannot be both an outer reference and a lambda parameter reference"); } } @@ -95,16 +128,19 @@ public boolean isSimpleRootReference() { return segments().size() == 1 && !inputExpression().isPresent() && !outerReferenceStepsOut().isPresent() + && !outerReferenceRelReference().isPresent() && !lambdaParameterReferenceStepsOut().isPresent(); } /** - * Returns whether this reference steps out into an enclosing (outer) query. + * Returns whether this reference steps out into an enclosing (outer) query, via either the + * offset-based ({@link #outerReferenceStepsOut()}) or id-based ({@link + * #outerReferenceRelReference()}) mechanism. 
* * @return {@code true} if this is an outer reference */ public boolean isOuterReference() { - return outerReferenceStepsOut().orElse(0) > 0; + return outerReferenceStepsOut().orElse(0) > 0 || outerReferenceRelReference().isPresent(); } /** @@ -234,6 +270,24 @@ public static FieldReference newRootStructOuterReference( .build(); } + /** + * Creates an id-based reference to a field of an enclosing (outer) query's root struct, resolved + * via the referenced relation's {@link Rel#getRelAnchor()} rather than a subquery-level offset. + * + * @param index the struct field index + * @param knownType the known type of the referenced field + * @param relReference the {@code relAnchor} of the relation this field reference is rooted on + * @return the field reference + */ + public static FieldReference newRootStructOuterReferenceByRelReference( + int index, Type knownType, int relReference) { + return ImmutableFieldReference.builder() + .addSegments(StructField.of(index)) + .type(knownType) + .outerReferenceRelReference(relReference) + .build(); + } + /** * Creates a reference to a field of a single input relation by overall field index. * diff --git a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java index 0ebf340dc..8e31fad80 100644 --- a/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ExpressionProtoConverter.java @@ -657,6 +657,15 @@ public Expression visit(FieldReference expr, EmptyVisitationContext context) { if (expr.inputExpression().isPresent()) { out.setExpression(toProto(expr.inputExpression().get())); + } else if (expr.outerReferenceRelReference().isPresent()) { + // Prefer the id-based outer reference when set: it resolves unambiguously in DAG-shaped plans + // and is the direction we are migrating towards. 
steps_out and rel_reference share a protobuf + // oneof, so at most one can be emitted; existing offset-only producers keep emitting + // steps_out + // and remain readable by older consumers. + out.setOuterReference( + io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder() + .setRelReference(expr.outerReferenceRelReference().get())); } else if (expr.outerReferenceStepsOut().isPresent()) { out.setOuterReference( io.substrait.proto.Expression.FieldReference.OuterReference.newBuilder() diff --git a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java index 76a65a803..b6384a648 100644 --- a/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java +++ b/core/src/main/java/io/substrait/expression/proto/ProtoExpressionConverter.java @@ -88,10 +88,23 @@ public FieldReference from(io.substrait.proto.Expression.FieldReference referenc return FieldReference.ofRoot( rootType, getDirectReferenceSegments(reference.getDirectReference())); case OUTER_REFERENCE: - return FieldReference.newRootStructOuterReference( - reference.getDirectReference().getStructField().getField(), - rootType, - reference.getOuterReference().getStepsOut()); + { + io.substrait.proto.Expression.FieldReference.OuterReference outerReference = + reference.getOuterReference(); + int field = reference.getDirectReference().getStructField().getField(); + switch (outerReference.getOuterReferenceTypeCase()) { + case STEPS_OUT: + return FieldReference.newRootStructOuterReference( + field, rootType, outerReference.getStepsOut()); + case REL_REFERENCE: + return FieldReference.newRootStructOuterReferenceByRelReference( + field, rootType, outerReference.getRelReference()); + case OUTERREFERENCETYPE_NOT_SET: + default: + throw new IllegalArgumentException( + "Unhandled outer reference type: " + outerReference.getOuterReferenceTypeCase()); + } + } case 
LAMBDA_PARAMETER_REFERENCE: { io.substrait.proto.Expression.FieldReference.LambdaParameterReference lambdaParamRef = diff --git a/core/src/main/java/io/substrait/relation/OuterReferenceConverter.java b/core/src/main/java/io/substrait/relation/OuterReferenceConverter.java new file mode 100644 index 000000000..5cd6c7d9c --- /dev/null +++ b/core/src/main/java/io/substrait/relation/OuterReferenceConverter.java @@ -0,0 +1,368 @@ +package io.substrait.relation; + +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.plan.Plan; +import io.substrait.util.EmptyVisitationContext; +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * Converts a Substrait plan between the two outer-reference resolution encodings: + * + *

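<ul> + *   <li>offset-based: {@code steps_out}, the number of subquery levels the reference steps out of; + *   <li>id-based: {@code rel_reference}, the {@code rel_anchor} of the relation the reference is + *       rooted on. + * </ul> + * + * <p>The two forms are semantically equivalent for tree-shaped plans; this class translates one + * into the other so integrations can rely on a single encoding while the wire format keeps both + * (see the Substrait breaking + * change policy). Anchor assignment requires plan-wide context, which is why this lives in core + * rather than being duplicated per integration. + * + *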

<p>Supported scope: outer references whose binding relation is the input of a single-input + * expression host ({@link Filter}, {@link Project}). This covers correlated scalar/IN/EXISTS + * subqueries as produced by the SQL integrations. Multi-input scopes (a correlated reference into a + * join/set condition) and shared subtrees ({@code ReferenceRel}) are not representable as a single + * binding relation and cause an {@link UnsupportedOperationException}. + * + *

Anchor allocation starts at {@code 1}; callers must not pass {@code toIdBased} a plan + * that already carries {@code rel_anchor}s. + */ +public final class OuterReferenceConverter { + + private OuterReferenceConverter() {} + + private enum Direction { + TO_ID, + TO_STEPS_OUT + } + + /** + * Rewrites offset-based outer references ({@code steps_out}) as id-based references ({@code + * rel_reference}), assigning a {@link Rel#getRelAnchor() rel anchor} to each referenced relation. + * + * @param root the relation tree to convert + * @return an equivalent tree using id-based outer references + */ + public static Rel toIdBased(Rel root) { + return convert(root, Direction.TO_ID); + } + + /** + * Rewrites id-based outer references ({@code rel_reference}) as offset-based references ({@code + * steps_out}), removing the {@link Rel#getRelAnchor() rel anchors} they resolved to. + * + * @param root the relation tree to convert + * @return an equivalent tree using offset-based outer references + */ + public static Rel toStepsOut(Rel root) { + return convert(root, Direction.TO_STEPS_OUT); + } + + /** + * Applies {@link #toIdBased(Rel)} to every root of the given plan. + * + * @param plan the plan to convert + * @return an equivalent plan using id-based outer references + */ + public static Plan toIdBased(Plan plan) { + return convert(plan, Direction.TO_ID); + } + + /** + * Applies {@link #toStepsOut(Rel)} to every root of the given plan. 
+ * + * @param plan the plan to convert + * @return an equivalent plan using offset-based outer references + */ + public static Plan toStepsOut(Plan plan) { + return convert(plan, Direction.TO_STEPS_OUT); + } + + private static Plan convert(Plan plan, Direction direction) { + List roots = new ArrayList<>(plan.getRoots().size()); + for (Plan.Root root : plan.getRoots()) { + roots.add(Plan.Root.builder().from(root).input(convert(root.getInput(), direction)).build()); + } + return Plan.builder().from(plan).roots(roots).build(); + } + + private static Rel convert(Rel root, Direction direction) { + State state = new State(direction); + RelRewriter relRewriter = new RelRewriter(state); + return root.accept(relRewriter, EmptyVisitationContext.INSTANCE).orElse(root); + } + + /** Mutable state shared between the paired relation and expression rewriters. */ + private static final class State { + final Direction direction; + + /** + * Stack of enclosing scope relations, one entry per subquery boundary. A {@code null} entry + * marks a multi-input (unsupported) scope. {@code steps_out=N} resolves to the entry {@code N} + * from the top. + */ + final List outerScopes = new ArrayList<>(); + + /** + * Anchors assigned to binding relations while converting to the id-based form (by identity). + */ + final Map anchorByRel = new IdentityHashMap<>(); + + /** Anchors that were resolved to a {@code steps_out} value and should be stripped from rels. */ + final java.util.Set resolvedAnchors = new java.util.HashSet<>(); + + /** The input relation whose expressions are currently being rewritten (RootReference scope). 
*/ + Rel currentInput; + + int nextAnchor = 1; + + State(Direction direction) { + this.direction = direction; + } + + int allocateAnchor(Rel binding) { + Integer existing = anchorByRel.get(binding); + if (existing != null) { + return existing; + } + int anchor = nextAnchor++; + anchorByRel.put(binding, anchor); + return anchor; + } + } + + /** + * Relation rewriter that maintains the scope stack and stamps/strips {@code rel_anchor}s. Only + * single-input expression hosts are handled explicitly; every other relation is traversed by the + * copy-on-write base class. + */ + private static final class RelRewriter extends RelCopyOnWriteVisitor { + private final State state; + + RelRewriter(State state) { + super(relVisitor -> new ExprRewriter(relVisitor, state)); + this.state = state; + } + + @Override + public Optional visit(Filter filter, EmptyVisitationContext context) { + // Expressions first, so a reference binding to this filter's input is discovered before the + // input is (re)built and (for the id-based direction) stamped with its anchor. 
+ Optional condition = + rewriteInScope( + filter.getInput(), + () -> filter.getCondition().accept(getExpressionCopyOnWriteVisitor(), context)); + Rel newInput = rewriteInput(filter.getInput()); + + if (!condition.isPresent() && newInput == filter.getInput()) { + return Optional.empty(); + } + return Optional.of( + Filter.builder() + .from(filter) + .input(newInput) + .condition(condition.orElse(filter.getCondition())) + .build()); + } + + @Override + public Optional visit(Project project, EmptyVisitationContext context) { + Optional> expressions = + rewriteInScope( + project.getInput(), () -> visitExprList(project.getExpressions(), context)); + Rel newInput = rewriteInput(project.getInput()); + + if (!expressions.isPresent() && newInput == project.getInput()) { + return Optional.empty(); + } + return Optional.of( + Project.builder() + .from(project) + .input(newInput) + .expressions(expressions.orElse(project.getExpressions())) + .build()); + } + + @Override + public Optional visit(NamedDdl ddl, EmptyVisitationContext context) { + // The copy-on-write base throws for DDL; a view definition is an independent (top-level) + // query that may itself contain correlated subqueries, so traverse it rather than throw. + if (!ddl.getViewDefinition().isPresent()) { + return Optional.empty(); + } + return ddl.getViewDefinition() + .get() + .accept(this, context) + .map(definition -> NamedDdl.builder().from(ddl).viewDefinition(definition).build()); + } + + /** + * Runs the given expression rewrite with {@code currentInput} temporarily set to {@code input}, + * so that subqueries entered while rewriting push the correct scope. + */ + private T rewriteInScope(Rel input, java.util.function.Supplier rewrite) { + Rel saved = state.currentInput; + state.currentInput = input; + try { + return rewrite.get(); + } finally { + state.currentInput = saved; + } + } + + /** + * Rebuilds the input and applies the anchor stamp/strip decided while rewriting expressions. 
+ */ + private Rel rewriteInput(Rel originalInput) { + Rel rebuilt = + originalInput.accept(this, EmptyVisitationContext.INSTANCE).orElse(originalInput); + if (state.direction == Direction.TO_ID) { + Integer anchor = state.anchorByRel.get(originalInput); + if (anchor != null) { + return rebuilt.withRelAnchor(anchor); + } + } else { + Optional anchor = originalInput.getRelAnchor(); + if (anchor.isPresent() && state.resolvedAnchors.contains(anchor.get())) { + return rebuilt.withRelAnchor(Optional.empty()); + } + } + return rebuilt; + } + } + + /** Expression rewriter that pushes/pops subquery scopes and rewrites outer field references. */ + private static final class ExprRewriter extends ExpressionCopyOnWriteVisitor { + private final State state; + + ExprRewriter(RelCopyOnWriteVisitor relVisitor, State state) { + super(relVisitor); + this.state = state; + } + + @Override + public Optional visit(FieldReference fieldReference, EmptyVisitationContext context) + throws RuntimeException { + if (state.direction == Direction.TO_ID + && fieldReference.outerReferenceStepsOut().isPresent()) { + int anchor = anchorForStepsOut(fieldReference.outerReferenceStepsOut().get()); + return Optional.of( + ImmutableFieldReference.builder() + .from(fieldReference) + .outerReferenceStepsOut(Optional.empty()) + .outerReferenceRelReference(anchor) + .build()); + } + if (state.direction == Direction.TO_STEPS_OUT + && fieldReference.outerReferenceRelReference().isPresent()) { + int stepsOut = stepsOutForAnchor(fieldReference.outerReferenceRelReference().get()); + return Optional.of( + ImmutableFieldReference.builder() + .from(fieldReference) + .outerReferenceRelReference(Optional.empty()) + .outerReferenceStepsOut(stepsOut) + .build()); + } + + // Non-outer references: preserve all attributes, rewriting only a nested input expression. 
+ if (fieldReference.inputExpression().isPresent()) { + Optional newInput = + fieldReference.inputExpression().get().accept(this, context); + if (newInput.isPresent()) { + return Optional.of( + ImmutableFieldReference.builder() + .from(fieldReference) + .inputExpression(newInput) + .build()); + } + } + return Optional.empty(); + } + + private int anchorForStepsOut(int stepsOut) { + int index = state.outerScopes.size() - stepsOut; + if (index < 0 || index >= state.outerScopes.size()) { + throw new IllegalArgumentException( + "Outer reference steps_out=" + stepsOut + " exceeds the enclosing subquery depth"); + } + Rel binding = state.outerScopes.get(index); + if (binding == null) { + throw new UnsupportedOperationException( + "Cannot assign a rel_anchor for an outer reference that resolves to a multi-input " + + "(e.g. join) scope; only single-input scopes are supported"); + } + return state.allocateAnchor(binding); + } + + private int stepsOutForAnchor(int relReference) { + for (int i = state.outerScopes.size() - 1; i >= 0; i--) { + Rel scope = state.outerScopes.get(i); + if (scope != null + && scope.getRelAnchor().isPresent() + && scope.getRelAnchor().get() == relReference) { + state.resolvedAnchors.add(relReference); + return state.outerScopes.size() - i; + } + } + throw new UnsupportedOperationException( + "Cannot resolve outer reference rel_reference=" + + relReference + + " to an enclosing scope; the referenced relation is not an ancestor (shared " + + "subtrees / ReferenceRel are not supported)"); + } + + @Override + public Optional visit( + Expression.ScalarSubquery scalarSubquery, EmptyVisitationContext context) { + return withSubqueryScope(() -> super.visit(scalarSubquery, context)); + } + + @Override + public Optional visit( + Expression.SetPredicate setPredicate, EmptyVisitationContext context) { + return withSubqueryScope(() -> super.visit(setPredicate, context)); + } + + @Override + public Optional visit( + Expression.InPredicate inPredicate, 
EmptyVisitationContext context) { + // needles are evaluated in the current (outer) scope; the haystack is the subquery boundary. + Optional> needles = visitExprList(inPredicate.needles(), context); + Optional haystack = + withSubqueryScope( + () -> inPredicate.haystack().accept(getRelCopyOnWriteVisitor(), context)); + + if (!needles.isPresent() && !haystack.isPresent()) { + return Optional.empty(); + } + return Optional.of( + Expression.InPredicate.builder() + .from(inPredicate) + .haystack(haystack.orElse(inPredicate.haystack())) + .needles(needles.orElse(inPredicate.needles())) + .build()); + } + + /** Pushes {@code currentInput} as a new subquery scope, runs the action, then pops it. */ + private T withSubqueryScope(java.util.function.Supplier action) { + Rel savedInput = state.currentInput; + state.outerScopes.add(state.currentInput); + try { + return action.get(); + } finally { + state.outerScopes.remove(state.outerScopes.size() - 1); + state.currentInput = savedInput; + } + } + } +} diff --git a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java index a58e2e3b4..68403a4eb 100644 --- a/core/src/main/java/io/substrait/relation/ProtoRelConverter.java +++ b/core/src/main/java/io/substrait/relation/ProtoRelConverter.java @@ -266,7 +266,8 @@ protected NamedWrite newNamedWrite(final WriteRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); @@ -296,7 +297,8 @@ protected Rel newExtensionWrite(final WriteRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + 
.hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); @@ -340,7 +342,8 @@ protected NamedDdl newNamedDdl(final DdlRel rel) { .viewDefinition(optionalViewDefinition(rel)) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); @@ -369,7 +372,8 @@ protected ExtensionDdl newExtensionDdl(final DdlRel rel) { .viewDefinition(optionalViewDefinition(rel)) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); @@ -471,7 +475,8 @@ protected Filter newFilter(FilterRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -521,7 +526,8 @@ protected ExtensionLeaf newExtensionLeaf(ExtensionLeafRel rel) { ExtensionLeaf.from(detail) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); return builder.build(); } @@ -538,7 +544,8 @@ protected ExtensionSingle 
newExtensionSingle(ExtensionSingleRel rel) { ExtensionSingle.from(detail, input) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); return builder.build(); } @@ -555,7 +562,8 @@ protected ExtensionMulti newExtensionMulti(ExtensionMultiRel rel) { ExtensionMulti.from(detail, inputs) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasDetail()) { builder.detail(detailFromExtensionMultiRel(rel.getDetail())); } @@ -593,7 +601,8 @@ protected NamedScan newNamedScan(ReadRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -617,7 +626,8 @@ protected ExtensionTable newExtensionTable(final ReadRel rel) { .projection(optionalMaskExpression(rel)) .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -659,7 +669,8 @@ protected LocalFiles newLocalFiles(ReadRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if 
(rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -779,7 +790,8 @@ protected VirtualTableScan newVirtualTable(ReadRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -804,7 +816,8 @@ protected Fetch newFetch(FetchRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -832,7 +845,8 @@ protected Project newProject(ProjectRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -878,7 +892,8 @@ protected Expand newExpand(ExpandRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); return builder.build(); } @@ -929,7 +944,8 @@ protected Aggregate newAggregate(AggregateRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + 
.relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -962,7 +978,8 @@ protected Sort newSort(SortRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -996,7 +1013,8 @@ protected Join newJoin(JoinRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1016,7 +1034,8 @@ protected Rel newCross(CrossRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) - .remap(optionalRelmap(rel.getCommon())); + .remap(optionalRelmap(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1040,7 +1059,8 @@ protected Set newSet(SetRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1084,7 +1104,8 @@ protected Rel newHashJoin(HashJoinRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - 
.hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1129,7 +1150,8 @@ protected Rel newMergeJoin(MergeJoinRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1221,7 +1243,8 @@ protected NestedLoopJoin newNestedLoopJoin(NestedLoopJoinRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1264,7 +1287,8 @@ protected ConsistentPartitionWindow newConsistentPartitionWindow( builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1323,7 +1347,8 @@ protected ScatterExchange newScatterExchange(ExchangeRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); 
} @@ -1353,7 +1378,8 @@ protected SingleBucketExchange newSingleBucketExchange(ExchangeRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1384,7 +1410,8 @@ protected MultiBucketExchange newMultiBucketExchange(ExchangeRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1412,7 +1439,8 @@ protected RoundRobinExchange newRoundRobinExchange(ExchangeRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1439,7 +1467,8 @@ protected BroadcastExchange newBroadcastExchange(ExchangeRel rel) { builder .commonExtension(optionalAdvancedExtension(rel.getCommon())) .remap(optionalRelmap(rel.getCommon())) - .hint(optionalHint(rel.getCommon())); + .hint(optionalHint(rel.getCommon())) + .relAnchor(optionalRelAnchor(rel.getCommon())); if (rel.hasAdvancedExtension()) { builder.extension(protoExtensionConverter.fromProto(rel.getAdvancedExtension())); } @@ -1481,6 +1510,17 @@ protected static Optional optionalRelmap(io.substrait.proto.RelCommon relCommon.hasEmit() ? 
Rel.Remap.of(relCommon.getEmit().getOutputMappingList()) : null); } + /** + * Converts the {@link io.substrait.proto.RelCommon#getRelAnchor()} field to its POJO + * representation. + * + * @param relCommon the protobuf value to convert + * @return the converted result + */ + protected static Optional optionalRelAnchor(io.substrait.proto.RelCommon relCommon) { + return relCommon.hasRelAnchor() ? Optional.of(relCommon.getRelAnchor()) : Optional.empty(); + } + /** * Converts the corresponding protobuf message to its POJO representation. * diff --git a/core/src/main/java/io/substrait/relation/Rel.java b/core/src/main/java/io/substrait/relation/Rel.java index 74915e83b..ca4a9b5e6 100644 --- a/core/src/main/java/io/substrait/relation/Rel.java +++ b/core/src/main/java/io/substrait/relation/Rel.java @@ -25,6 +25,54 @@ public interface Rel { */ Optional getCommonExtension(); + /** + * Returns the plan-wide unique anchor identifying this relation, if set. + * + *

This corresponds to {@link io.substrait.proto.RelCommon#getRelAnchor()} and is required when + * this relation is the binding point for an id-based outer reference (see {@link + * io.substrait.expression.FieldReference#outerReferenceRelReference()}). When set it must be + * unique across all relations within a plan and {@code >= 1}. + * + * @return the optional relation anchor + */ + Optional getRelAnchor(); + + /** + * Returns a copy of this relation with its {@link #getRelAnchor() relation anchor} set to the + * given value. + * + *

Overridden by the generated Immutables {@code withRelAnchor(int)} on every concrete + * relation, this provides a type-agnostic way to stamp an anchor onto an arbitrary relation (used + * by {@link OuterReferenceConverter} when assigning anchors during {@code steps_out → + * rel_reference} conversion). Custom {@link Rel} implementations that are not Immutables-backed + * inherit this throwing default. + * + * @param relAnchor the plan-wide unique anchor to set (must be {@code >= 1}) + * @return a copy of this relation carrying the given anchor + */ + default Rel withRelAnchor(int relAnchor) { + throw new UnsupportedOperationException( + getClass() + " does not support setting a relation anchor"); + } + + /** + * Returns a copy of this relation with its {@link #getRelAnchor() relation anchor} set to the + * given optional value ({@link Optional#empty()} clears it). + * + *

Like {@link #withRelAnchor(int)}, this is overridden by the generated Immutables {@code + * withRelAnchor(Optional)} and provides a type-agnostic way to set or clear the anchor (used by + * {@link OuterReferenceConverter} when stripping anchors during {@code rel_reference → steps_out} + * conversion). Custom {@link Rel} implementations that are not Immutables-backed inherit this + * throwing default. + * + * @param relAnchor the anchor to set, or empty to clear it + * @return a copy of this relation carrying the given anchor + */ + default Rel withRelAnchor(Optional relAnchor) { + throw new UnsupportedOperationException( + getClass() + " does not support setting a relation anchor"); + } + /** * Returns the record type (schema) produced by this relation. * diff --git a/core/src/main/java/io/substrait/relation/RelProtoConverter.java b/core/src/main/java/io/substrait/relation/RelProtoConverter.java index 391c8e66b..aa1af0bee 100644 --- a/core/src/main/java/io/substrait/relation/RelProtoConverter.java +++ b/core/src/main/java/io/substrait/relation/RelProtoConverter.java @@ -942,6 +942,8 @@ private RelCommon common(io.substrait.relation.Rel rel) { .ifPresent( extension -> builder.setAdvancedExtension(extensionProtoConverter.toProto(extension))); + rel.getRelAnchor().ifPresent(builder::setRelAnchor); + io.substrait.relation.Rel.Remap remap = rel.getRemap().orElse(null); if (remap != null) { builder.setEmit(RelCommon.Emit.newBuilder().addAllOutputMapping(remap.indices())); diff --git a/core/src/test/java/io/substrait/relation/OuterReferenceConverterTest.java b/core/src/test/java/io/substrait/relation/OuterReferenceConverterTest.java new file mode 100644 index 000000000..47fdad4b0 --- /dev/null +++ b/core/src/test/java/io/substrait/relation/OuterReferenceConverterTest.java @@ -0,0 +1,224 @@ +package io.substrait.relation; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static 
org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.relation.Rel.Remap; +import io.substrait.type.TypeCreator; +import java.util.List; +import org.junit.jupiter.api.Test; + +/** + * Tests {@link OuterReferenceConverter}, which translates outer references between the offset-based + * ({@code steps_out}) and id-based ({@code rel_anchor}/{@code rel_reference}) encodings. + */ +class OuterReferenceConverterTest extends TestBase { + + private final Rel customerTableScan = + sb.namedScan( + List.of("customer"), + List.of("c_custkey", "c_nationkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); + + private final Rel orderTableScan = + sb.namedScan( + List.of("orders"), + List.of("o_orderkey", "o_custkey"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.I64)); + + private final Rel nationTableScan = + sb.namedScan( + List.of("nation"), + List.of("n_nationkey", "n_name"), + List.of(TypeCreator.REQUIRED.I64, TypeCreator.REQUIRED.STRING)); + + /** SELECT o_orderkey, (SELECT c_nationkey FROM customer WHERE c_custkey = orders.o_custkey). */ + private Rel oneStepPlan() { + return sb.project( + input -> + List.of( + sb.fieldReference(input, 0), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), + Remap.of(List.of(1)), + sb.filter( + input2 -> + sb.equal( + sb.fieldReference(input2, 0), + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 1)), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2, 3)), + orderTableScan); + } + + /** + * Nested subquery whose innermost outer reference steps out two boundaries to the orders scan. 
+ */ + private Rel twoStepPlan() { + return sb.project( + input -> + List.of( + sb.fieldReference(input, 0), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), + Remap.of(List.of(2)), + sb.filter( + input2 -> + sb.equal( + sb.fieldReference(input2, 0), + sb.scalarSubquery( + sb.project( + input3 -> List.of(sb.fieldReference(input3, 1)), + Remap.of(List.of(1)), + sb.filter( + input3 -> + sb.equal( + sb.fieldReference(input3, 0), + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 2)), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + nationTableScan)), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2, 3)), + orderTableScan); + } + + @Test + void oneStepToIdBasedStampsAnchorAndReference() { + Rel stepsOut = oneStepPlan(); + Rel idBased = OuterReferenceConverter.toIdBased(stepsOut); + + assertNotEquals(stepsOut, idBased); + + // The outer project's input (the orders scan) is the binding relation and gets anchor 1. + Project outerProject = (Project) idBased; + assertEquals(1, outerProject.getInput().getRelAnchor().orElseThrow(AssertionError::new)); + + // The correlated reference now uses rel_reference instead of steps_out. + FieldReference outerRef = outerReferenceOf(outerProject); + assertEquals(1, outerRef.outerReferenceRelReference().orElseThrow(AssertionError::new)); + assertFalse(outerRef.outerReferenceStepsOut().isPresent()); + } + + @Test + void oneStepRoundTripsBackToStepsOut() { + Rel stepsOut = oneStepPlan(); + assertEquals( + stepsOut, OuterReferenceConverter.toStepsOut(OuterReferenceConverter.toIdBased(stepsOut))); + } + + @Test + void twoStepRoundTripsThroughIdBased() { + Rel stepsOut = twoStepPlan(); + Rel idBased = OuterReferenceConverter.toIdBased(stepsOut); + + assertNotEquals(stepsOut, idBased); + // The innermost reference steps out two levels to the orders scan, which gets the anchor. 
+ assertEquals(1, ((Project) idBased).getInput().getRelAnchor().orElseThrow(AssertionError::new)); + assertEquals(stepsOut, OuterReferenceConverter.toStepsOut(idBased)); + } + + @Test + void sameScopeReferencesShareAnchor() { + // Two correlated references at the same subquery boundary must resolve to the same anchor. + Rel stepsOut = + sb.project( + input -> + List.of( + sb.fieldReference(input, 0), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), + Remap.of(List.of(1)), + sb.filter( + input2 -> + sb.and( + sb.equal( + sb.fieldReference(input2, 0), + FieldReference.newRootStructOuterReference( + 0, TypeCreator.REQUIRED.I64, 1)), + sb.equal( + sb.fieldReference(input2, 1), + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 1))), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2, 3)), + orderTableScan); + + Rel idBased = OuterReferenceConverter.toIdBased(stepsOut); + assertEquals(1, ((Project) idBased).getInput().getRelAnchor().orElseThrow(AssertionError::new)); + assertEquals(stepsOut, OuterReferenceConverter.toStepsOut(idBased)); + } + + @Test + void stepsOutBeyondDepthIsRejected() { + Rel invalid = + sb.project( + input -> + List.of( + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), + Remap.of(List.of(1)), + sb.filter( + input2 -> + sb.equal( + sb.fieldReference(input2, 0), + // steps_out=5 exceeds the single enclosing boundary + FieldReference.newRootStructOuterReference( + 1, TypeCreator.REQUIRED.I64, 5)), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2)), + orderTableScan); + + assertThrows(IllegalArgumentException.class, () -> OuterReferenceConverter.toIdBased(invalid)); + } + + @Test + void unresolvableRelReferenceIsRejected() { + // An id-based reference to an anchor that is not an enclosing scope cannot become steps_out. 
+ Rel dangling = + sb.project( + input -> + List.of( + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), + Remap.of(List.of(1)), + sb.filter( + input2 -> + sb.equal( + sb.fieldReference(input2, 0), + FieldReference.newRootStructOuterReferenceByRelReference( + 1, TypeCreator.REQUIRED.I64, 999)), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2)), + orderTableScan); + + assertThrows( + UnsupportedOperationException.class, () -> OuterReferenceConverter.toStepsOut(dangling)); + } + + /** Extracts the outer field reference from the one-step plan's scalar-subquery filter. */ + private FieldReference outerReferenceOf(Project outerProject) { + Expression.ScalarSubquery subquery = + (Expression.ScalarSubquery) outerProject.getExpressions().get(1); + Filter filter = (Filter) ((Project) subquery.input()).getInput(); + Expression.ScalarFunctionInvocation equal = + (Expression.ScalarFunctionInvocation) filter.getCondition(); + return (FieldReference) equal.arguments().get(1); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/OuterReferenceRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/OuterReferenceRoundtripTest.java new file mode 100644 index 000000000..1eff3dbda --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/OuterReferenceRoundtripTest.java @@ -0,0 +1,85 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.substrait.TestBase; +import io.substrait.expression.Expression; +import io.substrait.expression.FieldReference; +import io.substrait.expression.ImmutableFieldReference; +import io.substrait.expression.proto.ProtoExpressionConverter; +import io.substrait.proto.Expression.FieldReference.OuterReference.OuterReferenceTypeCase; +import io.substrait.type.Type; +import org.junit.jupiter.api.Test; + 
+/** + * Round-trip tests for the two outer-reference resolution mechanisms introduced with Substrait + * v0.89.0: the offset-based {@link FieldReference#outerReferenceStepsOut()} and the id-based {@link + * FieldReference#outerReferenceRelReference()}. + * + *

Outer references derive their type from the converter's root type on the way back from proto, + * so a converter configured with a matching root type is used instead of {@link + * TestBase#protoExpressionConverter}. + */ +class OuterReferenceRoundtripTest extends TestBase { + + final Type.Struct outerSchema = + Type.Struct.builder().nullable(false).addFields(R.I64, R.STRING).build(); + + final ProtoExpressionConverter protoExpressionConverterWithRoot = + new ProtoExpressionConverter(functionCollector, extensions, outerSchema, protoRelConverter); + + private void verifyOuterReferenceRoundTrip(FieldReference reference) { + io.substrait.proto.Expression proto = expressionProtoConverter.toProto(reference); + Expression returned = protoExpressionConverterWithRoot.from(proto); + assertEquals(reference, returned); + } + + @Test + void offsetBasedOuterReference() { + FieldReference reference = FieldReference.newRootStructOuterReference(0, outerSchema, 1); + + assertTrue(reference.isOuterReference()); + assertFalse(reference.isSimpleRootReference()); + verifyOuterReferenceRoundTrip(reference); + } + + @Test + void idBasedOuterReference() { + FieldReference reference = + FieldReference.newRootStructOuterReferenceByRelReference(1, outerSchema, 42); + + assertTrue(reference.isOuterReference()); + assertFalse(reference.isSimpleRootReference()); + verifyOuterReferenceRoundTrip(reference); + } + + /** + * A reference may carry both outer-reference forms during the transition towards id-based + * resolution. Producing prefers the id-based {@code rel_reference}; since the two share a + * protobuf oneof, the offset-based {@code steps_out} is not emitted and is therefore absent after + * a round-trip. 
+ */ + @Test + void bothFormsSetPrefersIdBased() { + FieldReference reference = + ImmutableFieldReference.builder() + .from(FieldReference.newRootStructOuterReference(0, outerSchema, 3)) + .outerReferenceRelReference(42) + .build(); + + assertTrue(reference.outerReferenceStepsOut().isPresent()); + assertTrue(reference.outerReferenceRelReference().isPresent()); + + io.substrait.proto.Expression proto = expressionProtoConverter.toProto(reference); + io.substrait.proto.Expression.FieldReference.OuterReference outerReference = + proto.getSelection().getOuterReference(); + assertEquals(OuterReferenceTypeCase.REL_REFERENCE, outerReference.getOuterReferenceTypeCase()); + assertEquals(42, outerReference.getRelReference()); + + FieldReference returned = (FieldReference) protoExpressionConverterWithRoot.from(proto); + assertEquals(42, returned.outerReferenceRelReference().orElseThrow(AssertionError::new)); + assertFalse(returned.outerReferenceStepsOut().isPresent()); + } +} diff --git a/core/src/test/java/io/substrait/type/proto/RelAnchorRoundtripTest.java b/core/src/test/java/io/substrait/type/proto/RelAnchorRoundtripTest.java new file mode 100644 index 000000000..8377c2ea8 --- /dev/null +++ b/core/src/test/java/io/substrait/type/proto/RelAnchorRoundtripTest.java @@ -0,0 +1,59 @@ +package io.substrait.type.proto; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.TestBase; +import io.substrait.relation.Filter; +import io.substrait.relation.Project; +import io.substrait.relation.Rel; +import java.util.Arrays; +import java.util.Collections; +import org.junit.jupiter.api.Test; + +/** + * Round-trip tests for the {@link Rel#getRelAnchor()} field (Substrait v0.89.0 {@code + * RelCommon.rel_anchor}), the plan-wide unique identifier used as the binding point for id-based + * outer references. 
+ */ +class RelAnchorRoundtripTest extends TestBase { + + final Rel baseTable = + sb.namedScan( + Collections.singletonList("test_table"), + Arrays.asList("id", "name"), + Arrays.asList(R.I64, R.STRING)); + + @Test + void relAnchorOnProject() { + Rel projection = + Project.builder() + .input(baseTable) + .relAnchor(7) + .addExpressions(sb.fieldReference(baseTable, 0)) + .build(); + + assertEquals(7, projection.getRelAnchor().orElseThrow(AssertionError::new)); + verifyRoundTrip(projection); + } + + @Test + void relAnchorOnFilter() { + Rel filter = + Filter.builder() + .input(baseTable) + .relAnchor(1) + .condition(sb.equal(sb.fieldReference(baseTable, 0), sb.fieldReference(baseTable, 0))) + .build(); + + verifyRoundTrip(filter); + } + + @Test + void relAnchorAbsentByDefault() { + Rel projection = + Project.builder().input(baseTable).addExpressions(sb.fieldReference(baseTable, 0)).build(); + + assertEquals(false, projection.getRelAnchor().isPresent()); + verifyRoundTrip(projection); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 0aef61857..9eff03249 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -28,6 +28,7 @@ import io.substrait.relation.NamedScan; import io.substrait.relation.NamedUpdate; import io.substrait.relation.NamedWrite; +import io.substrait.relation.OuterReferenceConverter; import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.relation.Rel.Remap; @@ -164,8 +165,11 @@ public static RelNode convert( .typeSystem(converterProvider.getTypeSystem()) .programs() .build()); - return relRoot.accept( - converterProvider.getSubstraitRelNodeConverter(relBuilder), Context.newContext()); + // Normalize id-based outer references (rel_reference) back to offset-based ones (steps_out) so + 
// the depth-based conversion below handles both encodings uniformly. Offset-based plans are + // left unchanged. + return OuterReferenceConverter.toStepsOut(relRoot) + .accept(converterProvider.getSubstraitRelNodeConverter(relBuilder), Context.newContext()); } @Override diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index bf5eedb75..684fa7095 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -29,6 +29,7 @@ import io.substrait.relation.NamedScan; import io.substrait.relation.NamedUpdate; import io.substrait.relation.NamedWrite; +import io.substrait.relation.OuterReferenceConverter; import io.substrait.relation.Project; import io.substrait.relation.Rel; import io.substrait.relation.Rel.Remap; @@ -841,7 +842,9 @@ public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollec public static Plan.Root convert(RelRoot relRoot, ConverterProvider converterProvider) { SubstraitRelVisitor visitor = converterProvider.getSubstraitRelVisitor(); visitor.popFieldAccessDepthMap(relRoot.rel); - Rel rel = visitor.apply(relRoot.project()); + // Convert offset-based outer references (steps_out) into id-based ones (rel_anchor / + // rel_reference), the encoding Substrait is migrating towards. 
+ Rel rel = OuterReferenceConverter.toIdBased(visitor.apply(relRoot.project())); // Avoid using the names from relRoot.validatedRowType because if there are // nested types (i.e ROW, MAP, etc) the typeConverter will pad names correctly @@ -875,6 +878,6 @@ public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection e public static Rel convert(RelNode relNode, ConverterProvider converterProvider) { SubstraitRelVisitor visitor = converterProvider.getSubstraitRelVisitor(); visitor.popFieldAccessDepthMap(relNode); - return visitor.apply(relNode); + return OuterReferenceConverter.toIdBased(visitor.apply(relNode)); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index bec19102c..d1ccdf990 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -2,6 +2,7 @@ import io.substrait.isthmus.SubstraitRelNodeConverter.Context; import io.substrait.plan.Plan; +import io.substrait.relation.OuterReferenceConverter; import io.substrait.relation.Rel; import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; @@ -63,6 +64,10 @@ public SubstraitToCalcite( * @return {@link RelNode} */ public RelNode convert(Rel rel) { + // Normalize id-based outer references (rel_reference) into offset-based ones (steps_out) so the + // depth-based conversion handles both encodings uniformly. Offset-based plans are unchanged. 
+ rel = OuterReferenceConverter.toStepsOut(rel); + RelBuilder relBuilder; if (catalogReader != null) { relBuilder = converterProvider.getRelBuilder(catalogReader.getRootSchema()); diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java index 9e9aad657..2cbcb7773 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java @@ -1,6 +1,7 @@ package io.substrait.isthmus; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -54,7 +55,9 @@ void existsCorrelatedSubquery() throws SqlParseException { .getSelection(); // l_orderkey assertEquals(0, correlatedCol.getDirectReference().getStructField().getField()); - assertEquals(1, correlatedCol.getOuterReference().getStepsOut()); + // Correlated references are now emitted in the id-based form (rel_reference) rather than + // steps_out; see OuterReferenceConverter. + assertTrue(correlatedCol.getOuterReference().hasRelReference()); } @Test @@ -97,7 +100,9 @@ void uniqueCorrelatedSubquery() throws IOException, SqlParseException { .getSelection(); // l_orderkey assertEquals(0, correlatedCol.getDirectReference().getStructField().getField()); - assertEquals(1, correlatedCol.getOuterReference().getStepsOut()); + // Correlated references are now emitted in the id-based form (rel_reference) rather than + // steps_out; see OuterReferenceConverter. 
+ assertTrue(correlatedCol.getOuterReference().hasRelReference()); } @Test @@ -135,7 +140,9 @@ void inPredicateCorrelatedSubQuery() throws IOException, SqlParseException { .getSelection(); // l_partkey assertEquals(1, correlatedCol.getDirectReference().getStructField().getField()); - assertEquals(1, correlatedCol.getOuterReference().getStepsOut()); + // Correlated references are now emitted in the id-based form (rel_reference) rather than + // steps_out; see OuterReferenceConverter. + assertTrue(correlatedCol.getOuterReference().hasRelReference()); } @Test @@ -175,7 +182,9 @@ void notInPredicateCorrelatedSubquery() throws IOException, SqlParseException { .getSelection(); // l_partkey assertEquals(1, correlatedCol.getDirectReference().getStructField().getField()); - assertEquals(1, correlatedCol.getOuterReference().getStepsOut()); + // Correlated references are now emitted in the id-based form (rel_reference) rather than + // steps_out; see OuterReferenceConverter. + assertTrue(correlatedCol.getOuterReference().hasRelReference()); } @Test @@ -247,7 +256,7 @@ void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { .getValue() .getSelection(); // p.p_partkey assertEquals(0, correlatedCol1.getDirectReference().getStructField().getField()); - assertEquals(2, correlatedCol1.getOuterReference().getStepsOut()); + assertTrue(correlatedCol1.getOuterReference().hasRelReference()); Expression.FieldReference correlatedCol2 = inner_subquery_cond2 @@ -256,7 +265,11 @@ void existsNestedCorrelatedSubquery() throws IOException, SqlParseException { .getValue() .getSelection(); // l.l_suppkey assertEquals(2, correlatedCol2.getDirectReference().getStructField().getField()); - assertEquals(1, correlatedCol2.getOuterReference().getStepsOut()); + // The two correlated columns reference different outer relations, so distinct rel anchors. 
+ assertTrue(correlatedCol2.getOuterReference().hasRelReference()); + assertNotEquals( + correlatedCol1.getOuterReference().getRelReference(), + correlatedCol2.getOuterReference().getRelReference()); } @Test @@ -322,7 +335,7 @@ void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { .getValue() .getSelection(); // p.p_partkey assertEquals(0, correlatedCol1.getDirectReference().getStructField().getField()); - assertEquals(2, correlatedCol1.getOuterReference().getStepsOut()); + assertTrue(correlatedCol1.getOuterReference().hasRelReference()); Expression.FieldReference correlatedCol2 = inner_subquery_cond2 @@ -331,7 +344,11 @@ void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { .getValue() .getSelection(); // l.l_suppkey assertEquals(2, correlatedCol2.getDirectReference().getStructField().getField()); - assertEquals(1, correlatedCol2.getOuterReference().getStepsOut()); + // The two correlated columns reference different outer relations, so distinct rel anchors. 
+ assertTrue(correlatedCol2.getOuterReference().hasRelReference()); + assertNotEquals( + correlatedCol1.getOuterReference().getRelReference(), + correlatedCol2.getOuterReference().getRelReference()); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java index a3f234c47..ea51c3453 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubtraitRelVisitorExtensionTest.java @@ -110,6 +110,11 @@ public Optional getHint() { return input.getHint(); } + @Override + public Optional getRelAnchor() { + return input.getRelAnchor(); + } + @Override public O accept( final RelVisitor visitor, final C context) throws E { diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java index 9d089825c..69211cfc5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java @@ -87,6 +87,46 @@ void testOuterFieldReferenceOneStep() { SubstraitSqlDialect.toSql(calciteRel).getSql()); } + @Test + void testOuterFieldReferenceOneStepIdBased() { + /* + * Same correlated scalar subquery as testOuterFieldReferenceOneStep, but the outer reference is + * expressed in the id-based form (rel_reference) binding to the orders scan's rel_anchor. The + * consumer must resolve it to the same Calcite $cor0 correlation as the steps_out form. 
+ */ + final Rel root = + sb.project( + input -> + List.of( + sb.fieldReference(input, 0), + sb.scalarSubquery( + sb.project( + input2 -> List.of(sb.fieldReference(input2, 1)), + Remap.of(List.of(1)), + sb.filter( + input2 -> + sb.equal( + sb.fieldReference(input2, 0), + FieldReference.newRootStructOuterReferenceByRelReference( + 1, TypeCreator.REQUIRED.I64, 1)), + customerTableScan)), + TypeCreator.NULLABLE.I64)), + Remap.of(List.of(2, 3)), + // The orders scan is the binding relation for the id-based outer reference. + orderTableScan.withRelAnchor(1)); + + final RelNode calciteRel = substraitToCalcite.convert(root); + + assertEquals( + "LogicalProject(variablesSet=[[$cor0]], o_orderkey0=[$0], $f3=[$SCALAR_QUERY({\n" + + "LogicalProject(c_nationkey=[$1])\n" + + " LogicalFilter(condition=[=($0, $cor0.o_custkey)])\n" + + " LogicalTableScan(table=[[customer]])\n" + + "})])\n" + + " LogicalTableScan(table=[[orders]])\n", + calciteRel.explain()); + } + @Test void testOuterFieldReferenceTwoSteps() { /* diff --git a/substrait b/substrait index f09071699..d61a9403d 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit f09071699edcb41d90e1080625e9ee25fce2b81c +Subproject commit d61a9403dc6fe0006b8b712b97b9d6a3bada1acf