diff --git a/CedarJava/src/main/java/com/cedarpolicy/model/policy/Policy.java b/CedarJava/src/main/java/com/cedarpolicy/model/policy/Policy.java index 5f1a7a93..702bbb89 100644 --- a/CedarJava/src/main/java/com/cedarpolicy/model/policy/Policy.java +++ b/CedarJava/src/main/java/com/cedarpolicy/model/policy/Policy.java @@ -23,7 +23,8 @@ import com.cedarpolicy.model.Effect; import java.util.concurrent.atomic.AtomicInteger; - +import java.util.HashMap; +import java.util.Map; /** Policies in the Cedar language. */ public class Policy { @@ -36,18 +37,18 @@ public class Policy { public final String policySrc; /** Policy ID. */ public final String policyID; + /** Annotations */ + private Map annotations; /** * Creates a Cedar policy object. * - * @param policy String containing the source code of a Cedar policy in the Cedar policy - * language. - * @param policyID The id of this policy. Must be unique. Note: We may flip the order of the - * arguments here for idiomatic reasons. + * @param policy String containing the source code of a Cedar policy in the Cedar policy language. + * @param policyID The id of this policy. Must be unique. Note: We may flip the order of the arguments here for + * idiomatic reasons. */ @SuppressFBWarnings("CT_CONSTRUCTOR_THROW") - public Policy( - @JsonProperty("policySrc") String policy, @JsonProperty("policyID") String policyID) + public Policy(@JsonProperty("policySrc") String policy, @JsonProperty("policyID") String policyID) throws NullPointerException { if (policy == null) { @@ -82,8 +83,8 @@ public String toString() { /** * Returns the effect of a policy. * - * Determines the policy effect by attempting static policy first, then template. - * In future, it will only support static policies once new class is introduced for Template. + * Determines the policy effect by attempting static policy first, then template. In future, it will only support + * static policies once new class is introduced for Template. * * @return The effect of the policy, either "permit" or "forbid" * @throws InternalException @@ -114,17 +115,69 @@ public static Policy parseStaticPolicy(String policyStr) throws InternalExceptio return new Policy(policyText, null); } - public static Policy parsePolicyTemplate(String templateStr) throws InternalException, NullPointerException { + public static Policy parsePolicyTemplate(String templateStr) throws InternalException, NullPointerException { String templateText = parsePolicyTemplateJni(templateStr); return new Policy(templateText, null); } + /** + * Gets a copy of the policy annotations map. Annotations are loaded lazily when this method is first called. Works + * for both static policies and templates. + * + * @return A new HashMap containing the policy's annotations. For annotations without explicit values, an empty + * string ("") is used as the value + */ + public Map getAnnotations() throws InternalException { + ensureAnnotationsLoaded(); + return new HashMap<>(this.annotations); + } + + /** + * Gets the value of a specific annotation by its key. + * + * @param key The annotation key to look up + * @return The value associated with the annotation key, or null if the key doesn't exist + * @throws InternalException if there is an error loading or parsing the annotations + */ + public String getAnnotation(String key) throws InternalException { + ensureAnnotationsLoaded(); + return this.annotations.getOrDefault(key, null); + } + + /** + * Ensures that the annotations map is loaded for this policy. If annotations haven't been loaded yet, attempts to + * load them first from static policy, then falls back to template if needed. + * + * @throws InternalException if there is an error loading or parsing the annotations + */ + private void ensureAnnotationsLoaded() throws InternalException { + if (annotations == null) { + try { + this.annotations = getPolicyAnnotationsJni(this.policySrc); + } catch (InternalException e) { + if (e.getMessage().contains("expected a static policy")) { + this.annotations = getTemplateAnnotationsJni(this.policySrc); + } else { + throw e; + } + } + } + } + private static native String parsePolicyJni(String policyStr) throws InternalException, NullPointerException; + private static native String parsePolicyTemplateJni(String policyTemplateStr) throws InternalException, NullPointerException; private native String toJsonJni(String policyStr) throws InternalException, NullPointerException; + private static native String fromJsonJni(String policyJsonStr) throws InternalException, NullPointerException; + private native String policyEffectJni(String policyStr) throws InternalException, NullPointerException; + private native String templateEffectJni(String policyStr) throws InternalException, NullPointerException; + + private static native Map getPolicyAnnotationsJni(String policyStr) throws InternalException; + + private static native Map getTemplateAnnotationsJni(String policyStr) throws InternalException; } diff --git a/CedarJava/src/test/java/com/cedarpolicy/PolicyTests.java b/CedarJava/src/test/java/com/cedarpolicy/PolicyTests.java index a58acfc3..9bb71bcc 100644 --- a/CedarJava/src/test/java/com/cedarpolicy/PolicyTests.java +++ b/CedarJava/src/test/java/com/cedarpolicy/PolicyTests.java @@ -19,8 +19,8 @@ import com.cedarpolicy.model.exception.InternalException; import com.cedarpolicy.model.policy.Policy; import com.cedarpolicy.model.Effect; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -28,6 +28,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import java.util.Map; +import java.util.HashMap; + public class PolicyTests { @Test public void parseStaticPolicyTests() { @@ -140,4 +143,88 @@ public void policyEffectTest() throws InternalException { assertEquals(forbidTemplate.effect(), Effect.FORBID); } + + @Test + public void givenStaticPolicyGetAnnotationsReturns() throws InternalException { + Policy staticPolicy = Policy.parseStaticPolicy(""" + @id("policyID1") + @annotation("myAnnotation") + @emptyAnnotation + permit(principal, action, resource); + """); + + Map expectedMap = new HashMap<>(); + expectedMap.put("id", "policyID1"); + expectedMap.put("annotation", "myAnnotation"); + expectedMap.put("emptyAnnotation", ""); + + assertEquals(staticPolicy.getAnnotations(), expectedMap); + + Policy staticPolicyNoAnnotations = Policy.parseStaticPolicy(""" + permit(principal, action, resource); + """); + + assertEquals(staticPolicyNoAnnotations.getAnnotations(), new HashMap<>()); + } + + @Test + public void givenStaticPolicyGetAnnotationReturns() throws InternalException { + Policy staticPolicy = Policy.parseStaticPolicy(""" + @id("policyID1") + @annotation("myAnnotation") + @emptyAnnotation + permit(principal, action, resource); + """); + + assertEquals(staticPolicy.getAnnotation("annotation"), "myAnnotation"); + assertEquals(staticPolicy.getAnnotation("emptyAnnotation"), ""); + + Policy staticPolicyNoAnnotations = Policy.parseStaticPolicy(""" + permit(principal, action, resource); + """); + + assertEquals(staticPolicyNoAnnotations.getAnnotation("invalidAnnotation"), null); + } + + @Test + public void givenTemplatePolicyGetAnnotationsReturns() throws InternalException { + Policy templatePolicy = Policy.parsePolicyTemplate(""" + @id("policyID1") + @annotation("myAnnotation") + @emptyAnnotation + permit(principal == ?principal, action, resource); + """); + + Map expectedMap = new HashMap<>(); + expectedMap.put("id", "policyID1"); + expectedMap.put("annotation", "myAnnotation"); + expectedMap.put("emptyAnnotation", ""); + + assertEquals(templatePolicy.getAnnotations(), expectedMap); + + Policy templatePolicyNoAnnotations = Policy.parsePolicyTemplate(""" + permit(principal == ?principal, action, resource); + """); + + assertEquals(templatePolicyNoAnnotations.getAnnotations(), new HashMap<>()); + } + + @Test + public void givenTemplatePolicyGetAnnotationReturns() throws InternalException { + Policy templatePolicy = Policy.parsePolicyTemplate(""" + @id("policyID1") + @annotation("myAnnotation") + @emptyAnnotation + permit(principal == ?principal, action, resource); + """); + + assertEquals(templatePolicy.getAnnotation("annotation"), "myAnnotation"); + assertEquals(templatePolicy.getAnnotation("emptyAnnotation"), ""); + + Policy templatePolicyNoAnnotations = Policy.parsePolicyTemplate(""" + permit(principal == ?principal, action, resource); + """); + + assertEquals(templatePolicyNoAnnotations.getAnnotation("invalidAnnotation"), null); + } } diff --git a/CedarJavaFFI/src/interface.rs b/CedarJavaFFI/src/interface.rs index 4374b20c..4d607e99 100644 --- a/CedarJavaFFI/src/interface.rs +++ b/CedarJavaFFI/src/interface.rs @@ -35,6 +35,7 @@ use std::{error::Error, str::FromStr, thread}; use crate::objects::JFormatterConfig; use crate::{ answer::Answer, + jmap::Map, jset::Set, objects::{JEntityId, JEntityTypeName, JEntityUID, JPolicy, Object}, utils::raise_npe, @@ -336,6 +337,85 @@ fn create_java_policy_set<'a>( .expect("Failed to create new PolicySet object") } +#[jni_fn("com.cedarpolicy.model.policy.Policy")] +pub fn getPolicyAnnotationsJni<'a>( + mut env: JNIEnv<'a>, + _: JClass, + policy_jstr: JString<'a>, +) -> jvalue { + match get_policy_annotations_internal(&mut env, policy_jstr) { + Err(e) => jni_failed(&mut env, e.as_ref()), + Ok(annotations) => annotations.as_jni(), + } +} + +pub fn get_policy_annotations_internal<'a>( + env: &mut JNIEnv<'a>, + policy_jstr: JString<'a>, +) -> Result> { + if policy_jstr.is_null() { + raise_npe(env) + } else { + let policy_jstring = env.get_string(&policy_jstr)?; + let policy_string = String::from(policy_jstring); + + match Policy::from_str(&policy_string) { + Err(e) => Err(Box::new(e)), + Ok(policy) => { + let java_map = create_java_map_from_annotations(env, policy.annotations()); + Ok(JValueGen::Object(java_map)) + } + } + } +} + +#[jni_fn("com.cedarpolicy.model.policy.Policy")] +pub fn getTemplateAnnotationsJni<'a>( + mut env: JNIEnv<'a>, + _: JClass, + template_jstr: JString<'a>, +) -> jvalue { + match get_template_annotations_internal(&mut env, template_jstr) { + Err(e) => jni_failed(&mut env, e.as_ref()), + Ok(annotations) => annotations.as_jni(), + } +} + +pub fn get_template_annotations_internal<'a>( + env: &mut JNIEnv<'a>, + template_jstr: JString<'a>, +) -> Result> { + if template_jstr.is_null() { + raise_npe(env) + } else { + let template_jstring = env.get_string(&template_jstr)?; + let template_string = String::from(template_jstring); + + match Template::from_str(&template_string) { + Err(e) => Err(Box::new(e)), + Ok(template) => { + let java_map = create_java_map_from_annotations(env, template.annotations()); + Ok(JValueGen::Object(java_map)) + } + } + } +} + +fn create_java_map_from_annotations<'a, 'b>( + env: &mut JNIEnv<'a>, + annotations: impl Iterator, +) -> JObject<'a> { + let mut map = Map::new(env).unwrap(); + + for (annotation_key, annotation_value) in annotations { + let key: JString = env.new_string(annotation_key).unwrap().into(); + let value: JString = env.new_string(annotation_value).unwrap().into(); + map.put(env, key, value).unwrap(); + } + + map.into_inner() +} + #[jni_fn("com.cedarpolicy.model.policy.Policy")] pub fn parsePolicyTemplateJni<'a>( mut env: JNIEnv<'a>, @@ -625,7 +705,7 @@ fn policies_str_to_pretty_internal<'a>( } #[cfg(test)] -mod interface_tests { +mod jvm_based_tests { use super::*; use crate::jvm_test_utils::*; use jni::JavaVM; @@ -652,5 +732,143 @@ mod interface_tests { policy_effect_test_util(&mut env, "permit(principal,action,resource);", "permit"); policy_effect_test_util(&mut env, "forbid(principal,action,resource);", "forbid"); } + + #[track_caller] + fn assert_id_annotation_eq( + env: &mut JNIEnv, + annotations: &JObject, + annotation_key: &str, + expected_annotation_value: &str, + ) { + let annotation_key_jstr = env.new_string(annotation_key).unwrap(); + let actual_annotation_value_obj = env + .call_method( + annotations, + "get", + "(Ljava/lang/Object;)Ljava/lang/Object;", + &[JValueGen::Object(annotation_key_jstr.as_ref())], + ) + .unwrap() + .l() + .unwrap(); + + let actual_annotation_value_jstr = + JString::cast(env, actual_annotation_value_obj).unwrap(); + let actual_annotation_value_str = + String::from(env.get_string(&actual_annotation_value_jstr).unwrap()); + + assert_eq!( + actual_annotation_value_str, expected_annotation_value, + "Returned annotation value should match the annotation in the policy." + ) + } + + #[test] + fn static_policy_annotations_tests() { + let mut env = JVM.attach_current_thread().unwrap(); + let policy_string = env + .new_string("@id(\"policyID1\") @myAnnotationKey(\"myAnnotatedValue\") permit(principal,action,resource);") + .unwrap(); + let annotations = get_policy_annotations_internal(&mut env, policy_string) + .unwrap() + .l() + .unwrap(); + + assert_id_annotation_eq(&mut env, &annotations, "id", "policyID1"); + assert_id_annotation_eq( + &mut env, + &annotations, + "myAnnotationKey", + "myAnnotatedValue", + ); + } + + #[test] + fn template_policy_annotations_tests() { + let mut env = JVM.attach_current_thread().unwrap(); + let policy_string = env + .new_string("@id(\"policyID1\") @myAnnotationKey(\"myAnnotatedValue\") permit(principal==?principal,action,resource);") + .unwrap(); + let annotations = get_template_annotations_internal(&mut env, policy_string) + .unwrap() + .l() + .unwrap(); + + assert_id_annotation_eq(&mut env, &annotations, "id", "policyID1"); + assert_id_annotation_eq( + &mut env, + &annotations, + "myAnnotationKey", + "myAnnotatedValue", + ); + } + } + + mod map_tests { + use super::*; + + #[test] + fn map_new_tests() { + let mut env = JVM.attach_current_thread().unwrap(); + let java_hash_map = Map::::new(&mut env); + + assert!(java_hash_map.is_ok(), "Map creation should succeed"); + + assert!( + env.is_instance_of(java_hash_map.unwrap().into_inner(), "java/util/HashMap") + .unwrap(), + "Object should be a HashMap instance." + ); + } + + #[test] + fn map_put_tests() { + let mut env = JVM.attach_current_thread().unwrap(); + let mut java_hash_map = Map::::new(&mut env).unwrap(); + + let key = env.new_string("test_key").unwrap(); + let value = env.new_string("test_value").unwrap(); + + let result = java_hash_map.put(&mut env, key, value); + + assert!(result.is_ok(), "Map put should succeed."); + + let new_key = env.new_string("test_key").unwrap(); + let new_value = env.new_string("updated_value").unwrap(); + + let update_result = java_hash_map.put(&mut env, new_key, new_value); + + assert!(result.is_ok(), "Map put should succeed."); + + let update_result_jstr = JString::cast(&mut env, update_result.unwrap()).unwrap(); + let update_result_str = String::from(env.get_string(&update_result_jstr).unwrap()); + + assert_eq!( + update_result_str, "test_value", + "Value returned from map update should match the original value of test_key." + ) + } + + #[test] + fn map_get_tests() { + let mut env = JVM.attach_current_thread().unwrap(); + let mut java_hash_map = Map::::new(&mut env).unwrap(); + + let key = env.new_string("test_key").unwrap(); + let value = env.new_string("test_value").unwrap(); + + let _ = java_hash_map.put(&mut env, key, value); + + let retrieval_key = env.new_string("test_key").unwrap(); + let retrieved_value = java_hash_map.get(&mut env, retrieval_key).unwrap(); + + let retrieved_value_jstr = JString::cast(&mut env, retrieved_value).unwrap(); + let retrieved_value_str = String::from(env.get_string(&retrieved_value_jstr).unwrap()); + + assert_eq!( + retrieved_value_str, "test_value", + "Retrieved value should be equal to the inserted value." + ) + } } } diff --git a/CedarJavaFFI/src/jmap.rs b/CedarJavaFFI/src/jmap.rs new file mode 100644 index 00000000..8d9e56d5 --- /dev/null +++ b/CedarJavaFFI/src/jmap.rs @@ -0,0 +1,88 @@ +/* + * Copyright Cedar Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::marker::PhantomData; + +use crate::{objects::Object, utils::Result}; +use jni::{ + objects::{JObject, JValueGen}, + JNIEnv, +}; + +/// Typed wrapper for Java maps +/// (java.util.Map) +#[derive(Debug)] +pub struct Map<'a, T, U> { + /// Underlying Java object + obj: JObject<'a>, + /// ZST for tracking key type info + key_marker: PhantomData, + /// ZST for tracking value type info + value_marker: PhantomData, +} + +impl<'a, T: Object<'a>, U: Object<'a>> Map<'a, T, U> { + /// Construct an empty hash map, which will serve as a map + pub fn new(env: &mut JNIEnv<'a>) -> Result { + let obj = env.new_object("java/util/HashMap", "()V", &[])?; + + Ok(Self { + obj, + key_marker: PhantomData, + value_marker: PhantomData, + }) + } + + /// Get a value mapped to a key + pub fn get(&mut self, env: &mut JNIEnv<'a>, k: T) -> Result> { + let key = JValueGen::Object(k.as_ref()); + let value = env + .call_method( + &self.obj, + "get", + "(Ljava/lang/Object;)Ljava/lang/Object;", + &[key], + )? + .l()?; + Ok(value) + } + + /// Put a key-value pair into the map + pub fn put(&mut self, env: &mut JNIEnv<'a>, k: T, v: U) -> Result> { + let key = JValueGen::Object(k.as_ref()); + let value = JValueGen::Object(v.as_ref()); + let value = env + .call_method( + &self.obj, + "put", + "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", + &[key, value], + )? + .l()?; + Ok(value) + } + + /// Consumes the Map and returns the underlying JObject + pub fn into_inner(self) -> JObject<'a> { + self.obj + } +} + +impl<'a, T, U> AsRef> for Map<'a, T, U> { + fn as_ref(&self) -> &JObject<'a> { + &self.obj + } +} diff --git a/CedarJavaFFI/src/lib.rs b/CedarJavaFFI/src/lib.rs index 194b0ace..5b740146 100644 --- a/CedarJavaFFI/src/lib.rs +++ b/CedarJavaFFI/src/lib.rs @@ -18,6 +18,7 @@ mod answer; mod interface; mod jlist; +mod jmap; mod jset; mod jvm_test_utils; mod objects;