diff --git a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel index 73bab08c9..b25fdf16d 100644 --- a/extensions/src/main/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/main/java/dev/cel/extensions/BUILD.bazel @@ -329,6 +329,7 @@ java_library( deps = [ "//checker:checker_builder", "//common/exceptions:attribute_not_found", + "//common/exceptions:invalid_argument", "//common/internal:reflection_util", "//common/types", "//common/types:type_providers", diff --git a/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java b/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java index ae9483f7c..5ba97c6a6 100644 --- a/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java +++ b/extensions/src/main/java/dev/cel/extensions/CelNativeTypesExtensions.java @@ -28,6 +28,7 @@ import com.google.errorprone.annotations.Immutable; import dev.cel.checker.CelCheckerBuilder; import dev.cel.common.exceptions.CelAttributeNotFoundException; +import dev.cel.common.exceptions.CelInvalidArgumentException; import dev.cel.common.internal.ReflectionUtil; import dev.cel.common.types.CelType; import dev.cel.common.types.CelTypeProvider; @@ -47,6 +48,7 @@ import dev.cel.runtime.CelRuntimeLibrary; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; +import java.lang.reflect.Array; import java.lang.reflect.Constructor; import java.lang.reflect.Field; import java.lang.reflect.Method; @@ -285,6 +287,15 @@ private static CelType mapJavaTypeToCelType( return celType; } + if (type.isArray()) { + TypeToken token = TypeToken.of(genericType); + TypeToken componentToken = + Preconditions.checkNotNull( + token.getComponentType(), "Array component type cannot be null"); + return ListType.create( + mapJavaTypeToCelType(componentToken.getRawType(), componentToken.getType(), classMap)); + } + if (type.isInterface() && !List.class.isAssignableFrom(type) && !Map.class.isAssignableFrom(type)) { @@ -412,6 +423,14 @@ private void discover(Type type) { TypeToken token = TypeToken.of(type); Class rawType = token.getRawType(); + if (rawType.isArray()) { + TypeToken componentToken = + Preconditions.checkNotNull( + token.getComponentType(), "Array component type cannot be null"); + discover(componentToken.getType()); + return; + } + if (List.class.isAssignableFrom(rawType)) { discover(ReflectionUtil.resolveGenericParameter(token, List.class, 0)); return; @@ -767,6 +786,9 @@ private static Object getDefaultValue(Class targetType) { if (Map.class.isAssignableFrom(targetType)) { return ImmutableMap.of(); } + if (targetType.isArray()) { + return Array.newInstance(targetType.getComponentType(), 0); + } try { Constructor constructor = targetType.getDeclaredConstructor(); @@ -814,6 +836,10 @@ public Object toRuntimeValue(Object value) { return new PojoStructValue(value, accessors, registry.classToTypeMap.get(clazz)); } + if (clazz.isArray() && clazz != byte[].class) { + return convertArrayToList(value); + } + return super.toRuntimeValue(value); } @@ -836,8 +862,14 @@ Object toNative(Object value, Class targetType, Type genericType) { return ((CelByteString) value).toByteArray(); } - if (List.class.isAssignableFrom(targetType) && value instanceof List) { - return convertListToNative((List) value, targetType, genericType); + if (value instanceof List) { + List listValue = (List) value; + if (List.class.isAssignableFrom(targetType)) { + return convertListToNative(listValue, targetType, genericType); + } + if (targetType.isArray()) { + return convertListToArray(listValue, targetType, genericType); + } } if (Map.class.isAssignableFrom(targetType) && value instanceof Map) { @@ -849,7 +881,7 @@ Object toNative(Object value, Class targetType, Type genericType) { // Safe reflection collection cast. @SuppressWarnings("unchecked") - private Object convertListToNative(List list, Class targetType, Type genericType) { + private List convertListToNative(List list, Class targetType, Type genericType) { TypeToken token = TypeToken.of(genericType); Type elementType = ReflectionUtil.resolveGenericParameter(token, List.class, 0); Class componentType = ReflectionUtil.getRawType(elementType); @@ -901,7 +933,7 @@ private Object convertListToNative(List list, Class targetType, Type gener // Safe reflection collection cast. @SuppressWarnings("unchecked") - private Object convertMapToNative(Map map, Class targetType, Type genericType) { + private Map convertMapToNative(Map map, Class targetType, Type genericType) { TypeToken token = TypeToken.of(genericType); Type keyType = ReflectionUtil.resolveGenericParameter(token, Map.class, 0); Type valueType = ReflectionUtil.resolveGenericParameter(token, Map.class, 1); @@ -962,6 +994,36 @@ private Object convertMapToNative(Map map, Class targetType, Type gener return builder.buildOrThrow(); } + private Object convertListToArray(List list, Class targetType, Type genericType) { + Class componentType = targetType.getComponentType(); + Object array = Array.newInstance(componentType, list.size()); + TypeToken token = TypeToken.of(genericType); + TypeToken componentToken = + Preconditions.checkNotNull( + token.getComponentType(), "Array component type cannot be null"); + Type componentGenericType = componentToken.getType(); + + for (int i = 0; i < list.size(); i++) { + Object element = list.get(i); + Object converted = toNative(element, componentType, componentGenericType); + Array.set(array, i, converted); + } + return array; + } + + private ImmutableList convertArrayToList(Object array) { + int length = Array.getLength(array); + ImmutableList.Builder builder = ImmutableList.builderWithExpectedSize(length); + for (int i = 0; i < length; i++) { + Object element = Array.get(array, i); + if (element == null) { + throw new CelInvalidArgumentException(String.format("Element at index %d is null.", i)); + } + builder.add(toRuntimeValue(element)); + } + return builder.build(); + } + private Object downcastPrimitives(Object value, Class targetType) { Class wrappedTargetType = Primitives.wrap(targetType); if (wrappedTargetType == Integer.class && value instanceof Long) { diff --git a/extensions/src/main/java/dev/cel/extensions/README.md b/extensions/src/main/java/dev/cel/extensions/README.md index fcf019d15..c2aca5e98 100644 --- a/extensions/src/main/java/dev/cel/extensions/README.md +++ b/extensions/src/main/java/dev/cel/extensions/README.md @@ -1114,14 +1114,14 @@ The type-mapping between Java and CEL is as follows: | `String` | `string` | | `java.time.Duration` | `duration` | | `java.time.Instant` | `timestamp` | -| `java.util.List` | `list` | +| `java.util.List`, `T[]` (except `byte[]`) | `list` | | `java.util.Map` | `map` | | `java.util.Optional` | `optional_type` | ### Notes * This is only supported for the planner runtime (e.g., `CelRuntimeFactory.plannerRuntimeBuilder()`). -* Native Java arrays (except `byte[]`) are not supported. Use `java.util.List` instead. +* Native Java arrays are supported. `byte[]` maps to `bytes`, while other arrays map to `list`. * If there is a name collision with a Protobuf type, the protobuf type will take precedence. * Instantiating new struct values (e.g., `Account{id: 1234}`) requires the class to have a no-argument constructor (public, protected, package-private, or private). * Final fields are supported only in a **read-only** capacity; they cannot be populated when instantiating new struct values. diff --git a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel index 31720917f..9fda186cf 100644 --- a/extensions/src/test/java/dev/cel/extensions/BUILD.bazel +++ b/extensions/src/test/java/dev/cel/extensions/BUILD.bazel @@ -20,6 +20,7 @@ java_library( "//common/exceptions:attribute_not_found", "//common/exceptions:divide_by_zero", "//common/exceptions:index_out_of_bounds", + "//common/exceptions:invalid_argument", "//common/types", "//common/types:type_providers", "//common/values", diff --git a/extensions/src/test/java/dev/cel/extensions/CelNativeTypesExtensionsTest.java b/extensions/src/test/java/dev/cel/extensions/CelNativeTypesExtensionsTest.java index dcd3e811c..14d05a6bf 100644 --- a/extensions/src/test/java/dev/cel/extensions/CelNativeTypesExtensionsTest.java +++ b/extensions/src/test/java/dev/cel/extensions/CelNativeTypesExtensionsTest.java @@ -30,6 +30,7 @@ import dev.cel.common.CelContainer; import dev.cel.common.CelValidationException; import dev.cel.common.exceptions.CelAttributeNotFoundException; +import dev.cel.common.exceptions.CelInvalidArgumentException; import dev.cel.common.types.CelType; import dev.cel.common.types.ListType; import dev.cel.common.types.MapType; @@ -89,7 +90,8 @@ public final class CelNativeTypesExtensionsTest { TestNestedSimplePojo.class, TestGetterFieldTypeMismatchPojo.class, TestAbstractPojo.class, - TestURLPojo.class); + TestURLPojo.class, + TestArrayPojo.class); private static final Cel CEL = CelFactory.plannerCelBuilder() @@ -322,10 +324,10 @@ public void nativeTypes_anonymousClass_throwsException() { @Test public void nativeTypes_createStruct_privateConstructor() throws Exception { - Object result = eval("TestPrivateConstructorPojo{value:" + " 'hello'}"); + TestPrivateConstructorPojo result = + (TestPrivateConstructorPojo) eval("TestPrivateConstructorPojo{value:" + " 'hello'}"); - assertThat(result).isInstanceOf(TestPrivateConstructorPojo.class); - assertThat(((TestPrivateConstructorPojo) result).value).isEqualTo("hello"); + assertThat(result.value).isEqualTo("hello"); } @Test @@ -374,10 +376,9 @@ public void nativeTypes_missingNoArgConstructor_throws() throws Exception { @Test public void nativeTypes_createWithDeepConversion() throws Exception { - Object result = eval("TestDeepConversionPojo{ints: [1, 2], floats: {'a': 1.0, 'b': 2.0}}"); - - assertThat(result).isInstanceOf(TestDeepConversionPojo.class); - TestDeepConversionPojo pojo = (TestDeepConversionPojo) result; + TestDeepConversionPojo pojo = + (TestDeepConversionPojo) + eval("TestDeepConversionPojo{ints: [1, 2], floats: {'a': 1.0, 'b': 2.0}}"); assertThat(pojo.ints.get(0)).isEqualTo(1); assertThat(pojo.floats).containsEntry("a", 1.0f); } @@ -397,11 +398,92 @@ public void nativeTypes_unsupportedTypeSet_throwsOnRegistration() throws Excepti } @Test - public void nativeTypes_arrayType_throwsOnRegistration() throws Exception { - IllegalArgumentException e = + public void nativeTypes_arrayType_construction() throws Exception { + String expr = + "TestArrayPojo{" + + " strings: ['a', 'b']," + + " ints: [1, 2]," + + " nesteds: [TestNestedType{value: 'nested'}]," + + " matrix: [[1, 2], [3, 4]]," + + " nestedMatrix: [[TestNestedType{value: 'm1'}], [TestNestedType{value: 'm2'}]]," + + " byteArrays: [b'foo', b'bar']" + + "}"; + + TestArrayPojo pojo = (TestArrayPojo) eval(expr); + + assertThat(pojo.strings).isEqualTo(new String[] {"a", "b"}); + assertThat(pojo.ints).isEqualTo(new int[] {1, 2}); + assertThat(pojo.nesteds).hasLength(1); + assertThat(pojo.nesteds[0].value).isEqualTo("nested"); + assertThat(pojo.matrix).hasLength(2); + assertThat(pojo.matrix[0]).isEqualTo(new int[] {1, 2}); + assertThat(pojo.matrix[1]).isEqualTo(new int[] {3, 4}); + assertThat(pojo.nestedMatrix).hasLength(2); + assertThat(pojo.nestedMatrix[0][0].value).isEqualTo("m1"); + assertThat(pojo.nestedMatrix[1][0].value).isEqualTo("m2"); + assertThat(pojo.byteArrays).hasLength(2); + assertThat(pojo.byteArrays[0]).isEqualTo("foo".getBytes(UTF_8)); + assertThat(pojo.byteArrays[1]).isEqualTo("bar".getBytes(UTF_8)); + } + + @Test + public void nativeTypes_arrayType_selection() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestArrayPojo.class); + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addCompilerLibraries(extensions) + .addRuntimeLibraries(extensions) + .addVar("pojo", StructTypeReference.create(TestArrayPojo.class.getCanonicalName())) + .build(); + String expr = + "pojo.strings[1] == 'b'" + + " && pojo.ints[0] == 1" + + " && pojo.nesteds[0].value == 'nested'" + + " && pojo.matrix[1][0] == 3" + + " && pojo.nestedMatrix[1][0].value == 'm2'" + + " && pojo.byteArrays[1] == b'bar'"; + CelAbstractSyntaxTree ast = cel.compile(expr).getAst(); + CelRuntime.Program program = cel.createProgram(ast); + + TestArrayPojo input = new TestArrayPojo(); + input.strings = new String[] {"a", "b"}; + input.ints = new int[] {1, 2}; + TestNestedType nested = new TestNestedType(); + nested.value = "nested"; + input.nesteds = new TestNestedType[] {nested}; + input.matrix = new int[][] {{1, 2}, {3, 4}}; + TestNestedType m1 = new TestNestedType(); + m1.value = "m1"; + TestNestedType m2 = new TestNestedType(); + m2.value = "m2"; + input.nestedMatrix = new TestNestedType[][] {{m1}, {m2}}; + input.byteArrays = new byte[][] {"foo".getBytes(UTF_8), "bar".getBytes(UTF_8)}; + + assertThat(program.eval(ImmutableMap.of("pojo", input))).isEqualTo(true); + } + + @Test + public void nativeTypes_arrayWithNullElement_throws() throws Exception { + CelNativeTypesExtensions extensions = CelExtensions.nativeTypes(TestArrayPojo.class); + Cel cel = + CelFactory.plannerCelBuilder() + .setContainer(CelContainer.ofName("dev.cel.extensions.CelNativeTypesExtensionsTest")) + .addCompilerLibraries(extensions) + .addRuntimeLibraries(extensions) + .addVar("pojo", StructTypeReference.create(TestArrayPojo.class.getCanonicalName())) + .build(); + CelAbstractSyntaxTree ast = cel.compile("pojo.strings").getAst(); + CelRuntime.Program program = cel.createProgram(ast); + + TestArrayPojo input = new TestArrayPojo(); + input.strings = new String[] {"a", null, "c"}; + + CelEvaluationException e = assertThrows( - IllegalArgumentException.class, () -> CelExtensions.nativeTypes(TestArrayPojo.class)); - assertThat(e).hasMessageThat().contains("Unsupported type for property 'values'"); + CelEvaluationException.class, () -> program.eval(ImmutableMap.of("pojo", input))); + assertThat(e).hasCauseThat().isInstanceOf(CelInvalidArgumentException.class); + assertThat(e).hasCauseThat().hasMessageThat().contains("Element at index 1 is null."); } @Test @@ -653,10 +735,7 @@ public void nativeTypes_createWithUint_fromUnsignedLong() throws Exception { .getAst(); CelRuntime.Program program = celRuntime.createProgram(ast); - Object result = program.eval(); - - assertThat(result).isInstanceOf(TestAllTypesPublicFieldsPojo.class); - TestAllTypesPublicFieldsPojo pojo = (TestAllTypesPublicFieldsPojo) result; + TestAllTypesPublicFieldsPojo pojo = (TestAllTypesPublicFieldsPojo) program.eval(); assertThat(pojo.uintVal).isEqualTo(UnsignedLong.fromLongBits(42L)); } @@ -773,6 +852,8 @@ public void nativeTypes_nullSafeTraversal() throws Exception { assertThat(cel.createProgram(cel.compile("pojo.int64Val").getAst()).eval(vars)).isEqualTo(0L); assertThat(cel.createProgram(cel.compile("pojo.nestedVal.value").getAst()).eval(vars)) .isEqualTo(""); + assertThat(cel.createProgram(cel.compile("size(pojo.arrayVal) == 0").getAst()).eval(vars)) + .isEqualTo(true); CelAbstractSyntaxTree abstractPojoAst = cel.compile("pojo.abstractPojo.value").getAst(); CelRuntime.Program abstractPojoProgram = cel.createProgram(abstractPojoAst); CelEvaluationException e = @@ -933,6 +1014,7 @@ public String get() { public double doubleVal; public float floatVal; public byte[] bytesVal; + public String[] arrayVal; public Duration durationVal; public Instant timestampVal; public TestNestedType nestedVal; @@ -1245,7 +1327,12 @@ public static class TestWildcardPojo { } public static class TestArrayPojo { - public String[] values; + public String[] strings; + public int[] ints; + public TestNestedType[] nesteds; + public int[][] matrix; + public TestNestedType[][] nestedMatrix; + public byte[][] byteArrays; } public static class TestOptionalUrlPojo {