From 0b48e38dad4205ddb2966772112135546f54e37c Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 14 Apr 2023 09:26:16 +0100 Subject: [PATCH] applied fix to return multiple rules in a policy --- opa_client/opa.py | 21 +++++----- opa_client/test/test_opa.py | 77 ++++++++++++++++++++++--------------- 2 files changed, 57 insertions(+), 41 deletions(-) diff --git a/opa_client/opa.py b/opa_client/opa.py index 108723b..6331636 100644 --- a/opa_client/opa.py +++ b/opa_client/opa.py @@ -491,12 +491,15 @@ def __get_policies_info(self) -> dict: for path in policy.get('ast').get('package').get('path'): permission_url += '/' + path.get('value') temp_policy.append(permission_url) - for rule in policy.get('ast').get('rules'): - if not rule.get('default'): - continue + + rules = list(set( + [rule.get("head").get("name") for rule in policy.get("ast").get("rules")] + )) + for rule in rules: temp_url = permission_url - temp_url += '/' + rule.get('head').get('name') + temp_url += "/" + rule temp_rules.append(temp_url) + temp_dict[policy.get('id')] = {'path': temp_policy, 'rules': temp_rules} return temp_dict @@ -517,13 +520,11 @@ def __check( for path in result.get('ast').get('package').get('path'): permission_url += '/' + path.get('value') - for rule in result.get('ast').get('rules'): - if not rule.get('default'): - continue - if rule.get('head').get('name') == rule_name: + rules = [rule.get("head").get("name") for rule in result.get("ast").get("rules")] + if rule_name in rules: + permission_url += "/" + rule_name + find = True - permission_url += '/' + rule.get('head').get('name') - find = True if find: encoded_json = json.dumps(input_data).encode('utf-8') permission_url = self.prepare_args(permission_url, query_params) diff --git a/opa_client/test/test_opa.py b/opa_client/test/test_opa.py index c60d40b..7bc2712 100644 --- a/opa_client/test/test_opa.py +++ b/opa_client/test/test_opa.py @@ -21,7 +21,7 @@ def tearDown(self): del self.myclient def test_client(self): - """Set up the test for OpaClient object""" + """Set up the test for OpaClient object""" client = OpaClient('localhost', 8181, 'v1') self.assertEqual('http://localhost:8181/v1', client._root_url) @@ -35,59 +35,74 @@ def test_client(self): self.assertEqual('localhost', self.myclient._host) self.assertEqual(8181, self.myclient._port) - def test_functions(self): - + def test_connection_to_opa(self): self.assertEqual("Yes I'm here :)", self.myclient.check_connection()) - self.assertEqual(list(), self.myclient.get_policies_list()) - - self.assertEqual(dict(), self.myclient.get_policies_info()) - - # _dict = {'test': {'path': [ - # 'http://localhost:8181/v1/data/play'], - # 'rules': ['http://localhost:8181/v1/data/play/hello']} - # } + + def test_functions(self): + new_policy = """ + package test.policy - # self.assertEqual(_dict, self.myclient.get_policies_info()) + import data.test.acl + import input - new_policy = """ - package play + default allow = false - default hello = false + allow { + access := acl[input.user] + access[_] == input.access + } - hello { - m := input.message - m == "world" + authorized_users[user] { + access := acl[user] + access[_] == input.access } """ - self.assertEqual(True, self.myclient.update_opa_policy_fromstring(new_policy, 'test')) - self.assertEqual(['test'], self.myclient.get_policies_list()) _dict = { 'test': { - 'path': ['http://localhost:8181/v1/data/play'], - 'rules': ['http://localhost:8181/v1/data/play/hello'], + 'path': ['http://localhost:8181/v1/data/test/policy'], + 'rules': [ + 'http://localhost:8181/v1/data/test/policy/allow', + 'http://localhost:8181/v1/data/test/policy/authorized_users' + ], } } - self.assertEqual(_dict, self.myclient.get_policies_info()) + my_policy_list = { + "alice": ["read","write"], + "bob": ["read"] + } - my_policy_list = [ - {'resource': '/api/someapi', 'identity': 'your_identity', 'method': 'PUT'}, - {'resource': '/api/someapi', 'identity': 'your_identity', 'method': 'GET'}, - ] + self.assertEqual(list(), self.myclient.get_policies_list()) + self.assertEqual(dict(), self.myclient.get_policies_info()) + self.assertEqual(True, self.myclient.update_opa_policy_fromstring(new_policy, 'test')) + self.assertEqual(['test'], self.myclient.get_policies_list()) + + policy_info = self.myclient.get_policies_info() + self.assertEqual(_dict['test']['path'], policy_info['test']['path']) + for rule in _dict['test']['rules']: + self.assertIn(rule, policy_info['test']['rules']) self.assertTrue( - True, self.myclient.update_or_create_opa_data(my_policy_list, 'exampledata/accesses') + True, self.myclient.update_or_create_opa_data(my_policy_list, 'test/acl') ) - value = {'result': {'hello': False}} self.assertEqual(True, self.myclient.opa_policy_to_file('test')) - self.assertEqual(value, self.myclient.get_opa_raw_data('play')) + value = {'result': {'acl': {'alice': ['read', 'write'], 'bob': ['read']}, 'policy': {'allow': False, 'authorized_users': []}}} + self.assertEqual(value, self.myclient.get_opa_raw_data('test')) + + _input_a = {"input": {"user": "alice", "access": "write"}} + _input_b = {"input": {"access": "read"}} + value_a = {"result": True} + value_b = {"result": ["alice", "bob"]} + self.assertEqual(value_a, self.myclient.check_permission(input_data=_input_a, policy_name="test", rule_name="allow")) + self.assertEqual(value_b, self.myclient.check_permission(input_data=_input_b, policy_name="test", rule_name="authorized_users")) self.assertTrue(True, self.myclient.delete_opa_policy('test')) with self.assertRaises(DeletePolicyError): self.myclient.delete_opa_policy('test') + self.assertTrue(True, self.myclient.delete_opa_data('test/acl')) with self.assertRaises(DeleteDataError): - self.myclient.delete_opa_data('play') + self.myclient.delete_opa_data('test/acl')