From 888b67f93ce5de1835318fe094f1d2f84740876a Mon Sep 17 00:00:00 2001 From: adeshmukh-ks Date: Tue, 3 Jun 2025 17:57:31 +0530 Subject: [PATCH 1/3] record_type_info and load_record_types functions added --- .../src/keepercli/commands/record_type.py | 88 ++++++++++ .../src/keepercli/register_commands.py | 2 + .../keepersdk/vault/record_type_management.py | 140 ++++++++++++++-- .../src/keepersdk/vault/record_type_utils.py | 148 +++++++++++++++++ .../unit_tests/test_record_type_management.py | 155 +++++++++++++++++- 5 files changed, 519 insertions(+), 14 deletions(-) create mode 100644 keepersdk-package/src/keepersdk/vault/record_type_utils.py diff --git a/keepercli-package/src/keepercli/commands/record_type.py b/keepercli-package/src/keepercli/commands/record_type.py index 643da9c6..43d8e37b 100644 --- a/keepercli-package/src/keepercli/commands/record_type.py +++ b/keepercli-package/src/keepercli/commands/record_type.py @@ -120,6 +120,94 @@ def execute(self, context: KeeperParams, **kwargs) -> None: logger.info(f"Custom record type deleted successfully with record type id: {result.recordTypeId}") +class RecordTypeInfoCommand(base.ArgparseCommand): + + def __init__(self): + self.parser = argparse.ArgumentParser( + prog='record-type-info', + description='Get record type info' + ) + RecordTypeInfoCommand.add_arguments_to_parser(self.parser) + super().__init__(self.parser) + + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument( + '-lr', + '--list-record-type', + type=str, + dest='record_name', + action='store', + default=None, + const = '*', + nargs='?', + help='list record type by name or use * to list all' + ) + parser.add_argument( + '-lf', + '--list-field', + type=str, + dest='field_name', + action='store', + default=None, + help='list field type by name or use * to list all' + ) + parser.add_argument( + '-e', + '--example', + dest='example', + action='store_true', + help='Set to "true" to generate example JSON' + ) + + def execute(self, context: KeeperParams, **kwargs) -> None: + if not context.vault: + raise ValueError("Vault is not initialized.") + example = kwargs.get('example', False) + field = kwargs.get('field_name') + record_type = kwargs.get('record_name') + + result = record_type_management.record_type_info( + vault=context.vault, + field_name=field, + record_type_name=record_type, + example=example + ) + + logger.info(result) + + +class LoadRecordTypesCommand(base.ArgparseCommand): + + def __init__(self): + parser = argparse.ArgumentParser( + prog='load-record-types', + description='Loads custom record types from a JSON file.' + ) + parser.add_argument( + '--file', + dest='file', + action='store', + required=True, + help='Path to the JSON file containing the record type definition.' + ) + super().__init__(parser) + + def execute(self, context: KeeperParams, **kwargs) -> None: + if not context.vault: + raise ValueError("Vault is not initialized.") + + filepath = kwargs.get('file') + if not filepath: + raise ValueError("Missing required argument: --file") + + response = record_type_management.load_record_types(context.vault, filepath) + + if response != 0: + logger.info(f"Custom record types imported successfully. {response} record types were added.") + else: + logger.info("No custom record types were imported. Record types already exist in the vault or the file is empty.") + + record_implicit_fields = { 'title': '', # string 'custom': [], # Array of Field Data objects diff --git a/keepercli-package/src/keepercli/register_commands.py b/keepercli-package/src/keepercli/register_commands.py index 7d1cf747..04dff18b 100644 --- a/keepercli-package/src/keepercli/register_commands.py +++ b/keepercli-package/src/keepercli/register_commands.py @@ -45,6 +45,8 @@ def register_commands(commands: base.CliCommands, scopes: Optional[base.CommandS commands.register_command('record-type-add', record_type.RecordTypeAddCommand(), base.CommandScope.Vault) commands.register_command('record-type-edit', record_type.RecordTypeEditCommand(), base.CommandScope.Vault) commands.register_command('record-type-delete', record_type.RecordTypeDeleteCommand(), base.CommandScope.Vault) + commands.register_command('record-type-info', record_type.RecordTypeInfoCommand(), base.CommandScope.Vault, 'rti') + commands.register_command('load-record-types', record_type.LoadRecordTypesCommand(), base.CommandScope.Vault) if not scopes or bool(scopes & base.CommandScope.Enterprise): diff --git a/keepersdk-package/src/keepersdk/vault/record_type_management.py b/keepersdk-package/src/keepersdk/vault/record_type_management.py index 55168629..1388c01b 100644 --- a/keepersdk-package/src/keepersdk/vault/record_type_management.py +++ b/keepersdk-package/src/keepersdk/vault/record_type_management.py @@ -1,10 +1,14 @@ import json +import os +import tabulate -from typing import List, Dict +from typing import List, Dict, Optional -from . import vault_online, record_types +from . import vault_online, record_types, record_type_utils from ..proto import record_pb2 +from ..utils import get_logger +logger = get_logger() def create_custom_record_type(vault: vault_online.VaultOnline, title: str, fields: List[Dict[str, str]], description: str, categories: List[str] = None): is_enterprise_admin = vault.keeper_auth.auth_context.is_enterprise_admin @@ -47,7 +51,7 @@ def edit_custom_record_types(vault: vault_online.VaultOnline, record_type_id: in if not fields: raise ValueError('At least one field must be specified.') - is_enterprise_rt, real_type_id = isEnterpriseRecordType(record_type_id) + is_enterprise_rt, real_type_id = record_type_utils.isEnterpriseRecordType(record_type_id) if not is_enterprise_rt: raise ValueError('Only custom record types can be modified.') @@ -83,7 +87,7 @@ def delete_custom_record_types(vault: vault_online.VaultOnline, record_type_id: if not is_enterprise_admin: raise ValueError('This command is restricted to Keeper Enterprise administrators.') - is_enterprise_rt, real_type_id = isEnterpriseRecordType(record_type_id) + is_enterprise_rt, real_type_id = record_type_utils.isEnterpriseRecordType(record_type_id) if not is_enterprise_rt: raise ValueError('Only custom record types can be removed.') @@ -97,12 +101,122 @@ def delete_custom_record_types(vault: vault_online.VaultOnline, record_type_id: return response -def isEnterpriseRecordType(record_type_id: int) -> bool: - num_rts_per_scope = 1_000_000 - enterprise_scope = record_pb2.RT_ENTERPRISE - min_id = num_rts_per_scope * enterprise_scope - max_id = min_id + num_rts_per_scope - is_enterprise_rt = min_id < record_type_id <= max_id - real_type_id = record_type_id % num_rts_per_scope - - return is_enterprise_rt, real_type_id \ No newline at end of file +def record_type_info( + vault: vault_online.VaultOnline, + field_name: Optional[str] = None, + record_type_name: Optional[str] = None, + example: Optional[bool] = None, +): + #field types + if field_name is not None: + headers = ('Field Type ID', 'Lookup', 'Multiple', 'Description') + show_all_fields = field_name.strip() == '' or field_name.strip() == '*' + if show_all_fields: + rows = [] + for ft in record_types.FieldTypes.values(): + rows.append(record_type_utils.get_field_definitions(ft)) + return tabulate.tabulate(rows, headers=headers, tablefmt='simple') + else: + # Fetch a specific field type + ft = record_types.FieldTypes.get(field_name) + if not ft: + raise ValueError(f"Field type '{field_name}' is not a valid RecordField.") + row = record_type_utils.get_field_definitions(ft) + return tabulate.tabulate([row], headers=headers, tablefmt='simple') + + # Handle record type example + if record_type_name and record_type_name != '*' and record_type_name != '' and example: + record_type_example = record_type_utils.get_record_type_example(vault, record_type_name) + return record_type_example + + # Record Types + if record_type_name and record_type_name != '*' and record_type_name != '': + #Fetch a specific record type + record_type = vault.vault_data.get_record_type_by_name(record_type_name) + if not record_type: + raise ValueError(f"Record type '{record_type_name}' not found.") + + rows = [] + fields = record_type.fields + scope = record_type_utils.get_record_type_scope(record_type.scope) + rows.append([ + record_type.id, + record_type.name, + scope, + fields[0].label if hasattr(fields[0], 'label') else str(fields[0]) + ]) + for field in fields[1:]: + rows.append(['', '', '', field.label if hasattr(field, 'label') else str(field)]) + + headers = ('id', 'name', 'scope', 'fields') + return tabulate.tabulate(rows, headers=headers, tablefmt='simple') + else: + #Show all record types + record_types_list = record_type_utils.get_record_types(vault) + if not record_types_list: + raise ValueError("No record types found.") + + rows = [] + for rtid, name, scope in record_types_list: + rows.append([rtid, name, scope]) + + headers = ('Record Type ID', 'Record Type Name', 'Record Type Scope') + return tabulate.tabulate(rows, headers=headers, tablefmt='simple') + + +def load_record_types(vault: vault_online.VaultOnline, filepath) -> int: + count = 0 + + record_types_list = record_type_utils.validate_record_type_file(filepath) + + loaded_record_types = set() + existing_record_types = record_type_utils.get_record_types(vault) + if existing_record_types: + for existing_record_type in existing_record_types: + loaded_record_types.add(existing_record_type[1].lower()) + + for record_type in record_types_list: + record_type_name = record_type.get('record_type_name') + if not record_type_name: + logger.error('Record type name is missing in the record type definition.', record_type) + continue + + record_type_name = record_type_name[:30] + if record_type_name.lower() in loaded_record_types: + logger.info(f'Record type "{record_type_name}" already exists. Skipping.') + continue + + fields = record_type.get('fields') + if not isinstance(fields, list): + logger.error('Fields must be a list in the record type definition.', record_type) + continue + + is_valid = True + add_fields = [] + for field in fields: + field_type = field.get('$type') + if field_type not in record_types.RecordFields: + is_valid = False + break + fo = {'$ref': field.get('$type')} + if field.get('required') is True: + fo['required'] = True + add_fields.append(fo) + if not is_valid: + logger.error('Invalid field type in the record type definition.', record_type) + continue + + if len(add_fields) == 0: + logger.error('No fields found in the record type definition.', record_type) + continue + + create_custom_record_type( + vault=vault, + title=record_type_name, + fields=add_fields, + description=record_type.get('description') or '', + categories=record_type.get('categories') or [] + ) + count += 1 + + return count \ No newline at end of file diff --git a/keepersdk-package/src/keepersdk/vault/record_type_utils.py b/keepersdk-package/src/keepersdk/vault/record_type_utils.py new file mode 100644 index 00000000..60d10e75 --- /dev/null +++ b/keepersdk-package/src/keepersdk/vault/record_type_utils.py @@ -0,0 +1,148 @@ +import json + +from . import vault_online, storage_types, record_types, vault_types +from ..proto import record_pb2 + +def get_record_type_example(vault: vault_online.VaultOnline, record_type_name: str) -> str: + STR_VALUE = 'text' + + result = '' + rte = {} + record_type = vault.vault_data.get_record_type_by_name(record_type_name) + if record_type: + record_type_fields = record_type.fields + rte = { + 'type': record_type_name, + 'title': STR_VALUE, + 'notes': STR_VALUE, + 'fields': [], + 'custom': [] + } + + fields = record_type.fields or [] + fields = [x.label for x in fields] + for fname in fields: + ft = get_field_type(fname) + + required = next((x.required for x in record_type_fields if x.label == fname), None) + label = next((x.label for x in record_type_fields if x.label == fname), None) + + val = { + 'type': fname, + 'value': [ft.get('value') or ''], + 'required': required, + 'label': label + } + + if fname not in ('fileRef', 'addressRef', 'cardRef'): + if fname == 'phone' and ft and 'sample' in ft and 'region' in ft['sample']: + ft['sample']['region'] = 'US' + + rte['fields'].append(val) + else: + raise ValueError(f'No record type found with name {record_type_name}. Use "record-type-info" to list all record types') + + result = json.dumps(rte, indent=2) if rte else '' + return result + + +def get_record_types(vault:vault_online.VaultOnline) -> list[vault_types.RecordType]: + records = [] # (recordTypeId, name, scope) + record_types = vault.vault_data.get_record_types() + + if record_types: + for record_type in record_types: + name = record_type.name + scope = get_record_type_scope(record_type.scope) + records.append((record_type.id, name, scope)) + + return records + + +def get_field_type(id): + ftypes = [ + {**vars(record_types.RecordFields[rkey]), **vars(record_types.FieldTypes[fkey])} + for rkey in record_types.RecordFields + for fkey in record_types.FieldTypes + if record_types.RecordFields[rkey].type == record_types.FieldTypes[fkey].name + ] + result = next((ft for ft in ftypes if id.lower() == ft.get('name').lower()), {}) + if result: + # Determine value based on whether the id matches a FieldType or RecordField + field_type_obj = next((ft for ft in record_types.FieldTypes.values() if ft.name.lower() == id.lower()), None) + + if field_type_obj: + value = getattr(field_type_obj, 'value', None) + else: + value = result.get('type', None) + + result = { + 'id': result.get('$id') or result.get('name') or '', + 'type': result.get('type') or result.get('name') or '', + 'value': value, + } + return result + + +def isEnterpriseRecordType(record_type_id: int) -> bool: + num_rts_per_scope = 1_000_000 + enterprise_scope = record_pb2.RT_ENTERPRISE + min_id = num_rts_per_scope * enterprise_scope + max_id = min_id + num_rts_per_scope + is_enterprise_rt = min_id < record_type_id <= max_id + real_type_id = record_type_id % num_rts_per_scope + + return is_enterprise_rt, real_type_id + + +def get_field_definitions(field: record_types.FieldType): + recordfield_names = {rf.name for rf in record_types.RecordFields.values()} + lookup = field.name if field.name in recordfield_names else "" + multiple = ( + record_types.RecordFields[field.name].multiple.name + if lookup else "Optional" + ) + row = [ + field.name, + lookup, + multiple, + field.description + ] + return row + + +scope_map = { + storage_types.RecordTypeScope.Standard: 'Standard', + storage_types.RecordTypeScope.User: 'User', + storage_types.RecordTypeScope.Enterprise: 'Enterprise' +} + + +def get_record_type_scope(scope: storage_types.RecordTypeScope) -> str: + return scope_map.get(scope, str(scope)) + + +def validate_record_type_file(file_path: str) -> list: + if not file_path: + raise ValueError('File path is required.') + + if not file_path.endswith('.json'): + raise ValueError('Record type file must be a JSON file.') + + try: + with open(file_path, 'r') as f: + json_obj = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f'Invalid JSON in record type file: {e}') + except FileNotFoundError: + raise ValueError(f'Record type file not found: {file_path}') + + if not isinstance(json_obj, dict): + raise ValueError('Invalid custom record types file') + + record_types_list = json_obj.get('record_types') + + if not isinstance(record_types_list, list): + raise ValueError('Invalid custom record types list') + + return record_types_list \ No newline at end of file diff --git a/keepersdk-package/unit_tests/test_record_type_management.py b/keepersdk-package/unit_tests/test_record_type_management.py index 6f63e374..df70bb3f 100644 --- a/keepersdk-package/unit_tests/test_record_type_management.py +++ b/keepersdk-package/unit_tests/test_record_type_management.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from keepersdk.proto import record_pb2 from keepersdk.vault import record_type_management @@ -136,5 +136,158 @@ def test_not_enterprise_record_type_id(self): self.assertIn("can be removed", str(cm.exception)) +class RecordTypeInfoTestCase(unittest.TestCase): + def setUp(self): + self.vault = MagicMock() + self.vault.vault_data.get_record_type_by_name = MagicMock() + self.vault.vault_data.get_record_types = MagicMock() + + @patch('keepersdk.vault.record_type_management.tabulate') + @patch('keepersdk.vault.record_type_management.record_types') + def test_field_name_all(self, mock_record_types, mock_tabulate): + # Setup mock FieldTypes and RecordFields + mock_ft = MagicMock() + mock_ft.name = 'login' + mock_ft.description = 'desc' + mock_record_types.FieldTypes.values.return_value = [mock_ft] + mock_rf = MagicMock() + mock_rf.name = 'login' + mock_rf.multiple.name = 'Optional' + mock_record_types.RecordFields.values.return_value = [mock_rf] + mock_tabulate.tabulate.return_value = 'table' + result = record_type_management.record_type_info(self.vault, field_name='*') + self.assertEqual(result, 'table') + + @patch('keepersdk.vault.record_type_management.tabulate') + @patch('keepersdk.vault.record_type_management.record_types') + def test_field_name_specific(self, mock_record_types, mock_tabulate): + mock_ft = MagicMock() + mock_ft.name = 'login' + mock_ft.description = 'desc' + mock_record_types.FieldTypes.get.return_value = mock_ft + mock_rf = MagicMock() + mock_rf.name = 'login' + mock_rf.multiple.name = 'Optional' + mock_record_types.RecordFields.values.return_value = [mock_rf] + mock_tabulate.tabulate.return_value = 'table' + result = record_type_management.record_type_info(self.vault, field_name='login') + self.assertEqual(result, 'table') + + @patch('keepersdk.vault.record_type_management.record_type_utils') + def test_record_type_example(self, mock_utils): + mock_utils.get_record_type_example.return_value = '{"type": "login"}' + result = record_type_management.record_type_info(self.vault, record_type_name='login', example=True) + self.assertEqual(result, '{"type": "login"}') + + @patch('keepersdk.vault.record_type_management.tabulate') + @patch('keepersdk.vault.record_type_management.record_type_utils') + def test_record_type_name_all(self, mock_utils, mock_tabulate): + mock_utils.get_record_types.return_value = [(1, 'login', 'Standard')] + mock_tabulate.tabulate.return_value = 'table' + result = record_type_management.record_type_info(self.vault, record_type_name='*') + self.assertEqual(result, 'table') + + def test_record_type_name_not_found(self): + self.vault.vault_data.get_record_type_by_name.return_value = None + with self.assertRaises(ValueError) as cm: + record_type_management.record_type_info(self.vault, record_type_name='notfound') + self.assertIn('not found', str(cm.exception)) + + @patch('keepersdk.vault.record_type_management.tabulate') + def test_record_type_name_details(self, mock_tabulate): + mock_record_type = MagicMock() + mock_record_type.id = 1 + mock_record_type.name = 'login' + mock_record_type.scope = 0 + field = MagicMock() + field.label = 'username' + mock_record_type.fields = [field] + self.vault.vault_data.get_record_type_by_name.return_value = mock_record_type + mock_tabulate.tabulate.return_value = 'table' + result = record_type_management.record_type_info(self.vault, record_type_name='login') + self.assertEqual(result, 'table') + + +class LoadRecordTypesTestCase(unittest.TestCase): + def setUp(self): + self.vault = MagicMock() + self.filepath = 'dummy.json' + self.patcher_validate = patch('keepersdk.vault.record_type_management.record_type_utils.validate_record_type_file') + self.mock_validate = self.patcher_validate.start() + self.addCleanup(self.patcher_validate.stop) + self.patcher_create = patch('keepersdk.vault.record_type_management.create_custom_record_type') + self.mock_create = self.patcher_create.start() + self.addCleanup(self.patcher_create.stop) + self.patcher_get_types = patch('keepersdk.vault.record_type_management.record_type_utils.get_record_types') + self.mock_get_types = self.patcher_get_types.start() + self.addCleanup(self.patcher_get_types.stop) + self.patcher_record_fields = patch('keepersdk.vault.record_type_management.record_types.RecordFields', {}) + self.mock_record_fields = self.patcher_record_fields.start() + self.addCleanup(self.patcher_record_fields.stop) + + def test_file_not_found(self): + self.mock_validate.side_effect = ValueError('Record type file not found: dummy.json') + with self.assertRaises(ValueError) as cm: + record_type_management.load_record_types(self.vault, self.filepath) + self.assertIn('Record type file not found', str(cm.exception)) + + def test_invalid_json(self): + self.mock_validate.side_effect = ValueError('Invalid JSON in record type file: ...') + with self.assertRaises(ValueError) as cm: + record_type_management.load_record_types(self.vault, self.filepath) + self.assertIn('Invalid JSON in record type file', str(cm.exception)) + + def test_json_not_dict(self): + self.mock_validate.side_effect = ValueError('Invalid custom record types file') + with self.assertRaises(ValueError) as cm: + record_type_management.load_record_types(self.vault, self.filepath) + self.assertIn('Invalid custom record types file', str(cm.exception)) + + def test_missing_record_types_list(self): + self.mock_validate.side_effect = ValueError('Invalid custom record types list') + with self.assertRaises(ValueError) as cm: + record_type_management.load_record_types(self.vault, self.filepath) + self.assertIn('Invalid custom record types list', str(cm.exception)) + + def test_record_types_list_not_list(self): + self.mock_validate.side_effect = ValueError('Invalid custom record types list') + with self.assertRaises(ValueError) as cm: + record_type_management.load_record_types(self.vault, self.filepath) + self.assertIn('Invalid custom record types list', str(cm.exception)) + + def test_skip_record_type_without_name(self): + self.mock_validate.return_value = [{}] + self.mock_get_types.return_value = [] + result = record_type_management.load_record_types(self.vault, self.filepath) + self.assertEqual(result, 0) + self.mock_create.assert_not_called() + + def test_skip_existing_record_type(self): + self.mock_validate.return_value = [{"record_type_name": "foo", "fields": [{"$type": "login", "$ref": "login"}]}] + mock_existing = MagicMock() + mock_existing.name = 'foo' + self.mock_get_types.return_value = [(1, 'foo', 'Enterprise')] + result = record_type_management.load_record_types(self.vault, self.filepath) + self.assertEqual(result, 0) + self.mock_create.assert_not_called() + + def test_skip_invalid_fields(self): + self.mock_validate.return_value = [{"record_type_name": "foo", "fields": [{"$type": "invalid", "$ref": "login"}]}] + self.mock_get_types.return_value = [] + with patch.dict('keepersdk.vault.record_type_management.record_types.RecordFields', {'login': MagicMock()}): + result = record_type_management.load_record_types(self.vault, self.filepath) + self.assertEqual(result, 0) + self.mock_create.assert_not_called() + + def test_successful_add(self): + self.mock_validate.return_value = [{"record_type_name": "foo", "fields": [{"$type": "login", "$ref": "login"}]}] + self.mock_get_types.return_value = [] + with patch.dict('keepersdk.vault.record_type_management.record_types.RecordFields', {'login': MagicMock()}): + self.mock_create.return_value = True + result = record_type_management.load_record_types(self.vault, self.filepath) + self.assertEqual(result, 1) + self.mock_create.assert_called_once() + + if __name__ == "__main__": unittest.main() From 3391fdb5dbaeb9429f5ca9351c84200a64013785 Mon Sep 17 00:00:00 2001 From: adeshmukh-ks Date: Mon, 9 Jun 2025 15:02:26 +0530 Subject: [PATCH 2/3] Moved functions to cli commands --- .../src/keepercli/commands/record_type.py | 196 +++++++++++++++++- .../keepercli/commands/record_type_utils.py | 148 +++++++++++++ .../keepersdk/vault/record_type_management.py | 134 +----------- .../unit_tests/test_record_type_management.py | 12 -- 4 files changed, 344 insertions(+), 146 deletions(-) create mode 100644 keepercli-package/src/keepercli/commands/record_type_utils.py diff --git a/keepercli-package/src/keepercli/commands/record_type.py b/keepercli-package/src/keepercli/commands/record_type.py index 43d8e37b..2f452fcd 100644 --- a/keepercli-package/src/keepercli/commands/record_type.py +++ b/keepercli-package/src/keepercli/commands/record_type.py @@ -2,11 +2,12 @@ import json import logging -from keepersdk.vault import record_type_management +from keepersdk.vault import record_type_management, record_types -from . import base +from . import base, record_type_utils from ..params import KeeperParams from .. import api +from ..helpers import report_utils logger = api.get_logger() @@ -45,6 +46,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: context.vault, title, fields, description, categories ) logger.info(f"Custom record type '{title}' created successfully with fields: {[f['$ref'] for f in fields]} and recordTypeId: {result.recordTypeId}") + return class RecordTypeEditCommand(base.ArgparseCommand): @@ -92,6 +94,7 @@ def execute(self, context: KeeperParams, **kwargs) -> None: context.vault, record_type_id, title, fields, description, categories ) logger.info(f"Custom record type (ID: {record_type_id}) updated successfully with fields: {[f['$ref'] for f in fields]} and recordTypeId: {result.recordTypeId}") + return class RecordTypeDeleteCommand(base.ArgparseCommand): @@ -118,6 +121,195 @@ def execute(self, context: KeeperParams, **kwargs) -> None: result = record_type_management.delete_custom_record_types(context.vault, record_type_id) logger.info(f"Custom record type deleted successfully with record type id: {result.recordTypeId}") + return + + +class RecordTypeInfoCommand(base.ArgparseCommand): + + def __init__(self): + self.parser = argparse.ArgumentParser( + prog='record-type-info', + description='Get record type info' + ) + RecordTypeInfoCommand.add_arguments_to_parser(self.parser) + super().__init__(self.parser) + + def add_arguments_to_parser(parser: argparse.ArgumentParser): + parser.add_argument( + '-lr', + '--list-record-type', + type=str, + dest='record_name', + action='store', + default=None, + const = '*', + nargs='?', + help='list record type by name or use * to list all' + ) + parser.add_argument( + '-lf', + '--list-field', + type=str, + dest='field_name', + action='store', + default=None, + help='list field type by name or use * to list all' + ) + parser.add_argument( + '-e', + '--example', + dest='example', + action='store_true', + help='Set to "true" to generate example JSON' + ) + + def execute(self, context: KeeperParams, **kwargs) -> None: + if not context.vault: + raise ValueError("Vault is not initialized.") + + vault = context.vault + example = kwargs.get('example', False) + field_name = kwargs.get('field_name') + record_type_name = kwargs.get('record_name') + + if field_name is not None: + headers = ('Field Type ID', 'Lookup', 'Multiple', 'Description') + show_all_fields = field_name.strip() == '' or field_name.strip() == '*' + if show_all_fields: + rows = [] + for ft in record_types.FieldTypes.values(): + rows.append(record_type_utils.get_field_definitions(ft)) + return report_utils.dump_report_data(rows, headers, column_width='auto', fmt='simple') + else: + # Fetch a specific field type + ft = record_types.FieldTypes.get(field_name) + if not ft: + raise ValueError(f"Field type '{field_name}' is not a valid RecordField.") + row = record_type_utils.get_field_definitions(ft) + return report_utils.dump_report_data([row], headers, column_width='auto', fmt='simple') + + if record_type_name and record_type_name != '*' and record_type_name != '' and example: + record_type_example = record_type_utils.get_record_type_example(vault, record_type_name) + logger.info(record_type_example) + return + + # Record Types + if record_type_name and record_type_name != '*' and record_type_name != '': + #Fetch a specific record type + record_type = vault.vault_data.get_record_type_by_name(record_type_name) + if not record_type: + raise ValueError(f"Record type '{record_type_name}' not found.") + + rows = [] + fields = record_type.fields + scope = record_type_utils.get_record_type_scope(record_type.scope) + rows.append([ + record_type.id, + record_type.name, + scope, + fields[0].label if hasattr(fields[0], 'label') else str(fields[0]) + ]) + for field in fields[1:]: + rows.append(['', '', '', field.label if hasattr(field, 'label') else str(field)]) + + headers = ('id', 'name', 'scope', 'fields') + return report_utils.dump_report_data(rows, headers, column_width='auto', fmt='simple') + else: + #Show all record types + record_types_list = record_type_utils.get_record_types(vault) + if not record_types_list: + raise ValueError("No record types found.") + + rows = [] + for rtid, name, scope in record_types_list: + rows.append([rtid, name, scope]) + + headers = ('Record Type ID', 'Record Type Name', 'Record Type Scope') + return report_utils.dump_report_data(rows, headers, column_width='auto', fmt='simple') + + +class LoadRecordTypesCommand(base.ArgparseCommand): + + def __init__(self): + parser = argparse.ArgumentParser( + prog='load-record-types', + description='Loads custom record types from a JSON file.' + ) + parser.add_argument( + '--file', + dest='file', + action='store', + required=True, + help='Path to the JSON file containing the record type definition.' + ) + super().__init__(parser) + + def execute(self, context: KeeperParams, **kwargs) -> None: + if not context.vault: + raise ValueError("Vault is not initialized.") + + filepath = kwargs.get('file') + if not filepath: + raise ValueError("Missing required argument: --file") + + count = 0 + record_types_list = record_type_utils.validate_record_type_file(filepath) + + loaded_record_types = set() + existing_record_types = record_type_utils.get_record_types(context.vault) + if existing_record_types: + for existing_record_type in existing_record_types: + loaded_record_types.add(existing_record_type[1].lower()) + + for record_type in record_types_list: + record_type_name = record_type.get('record_type_name') + if not record_type_name: + logger.error('Record type name is missing in the record type definition.', record_type) + continue + + record_type_name = record_type_name[:30] + if record_type_name.lower() in loaded_record_types: + logger.info(f'Record type "{record_type_name}" already exists. Skipping.') + continue + + fields = record_type.get('fields') + if not isinstance(fields, list): + logger.error('Fields must be a list in the record type definition.', record_type) + continue + + is_valid = True + add_fields = [] + for field in fields: + field_type = field.get('$type') + if field_type not in record_types.RecordFields: + is_valid = False + break + fo = {'$ref': field.get('$type')} + if field.get('required') is True: + fo['required'] = True + add_fields.append(fo) + if not is_valid: + logger.error('Invalid field type in the record type definition.', record_type) + continue + + if len(add_fields) == 0: + logger.error('No fields found in the record type definition.', record_type) + continue + + record_type_management.create_custom_record_type( + vault=context.vault, + title=record_type_name, + fields=add_fields, + description=record_type.get('description') or '', + categories=record_type.get('categories') or [] + ) + count += 1 + + if count != 0: + logger.info(f"Custom record types imported successfully. {count} record types were added.") + else: + logger.info("No custom record types were imported. Record types already exist in the vault or the file is empty.") + return class RecordTypeInfoCommand(base.ArgparseCommand): diff --git a/keepercli-package/src/keepercli/commands/record_type_utils.py b/keepercli-package/src/keepercli/commands/record_type_utils.py new file mode 100644 index 00000000..204f97ac --- /dev/null +++ b/keepercli-package/src/keepercli/commands/record_type_utils.py @@ -0,0 +1,148 @@ +import json + +from keepersdk.vault import vault_online, storage_types, record_types, vault_types +from keepersdk.proto import record_pb2 + +def get_record_type_example(vault: vault_online.VaultOnline, record_type_name: str) -> str: + STR_VALUE = 'text' + + result = '' + rte = {} + record_type = vault.vault_data.get_record_type_by_name(record_type_name) + if record_type: + record_type_fields = record_type.fields + rte = { + 'type': record_type_name, + 'title': STR_VALUE, + 'notes': STR_VALUE, + 'fields': [], + 'custom': [] + } + + fields = record_type.fields or [] + fields = [x.label for x in fields] + for fname in fields: + ft = get_field_type(fname) + + required = next((x.required for x in record_type_fields if x.label == fname), None) + label = next((x.label for x in record_type_fields if x.label == fname), None) + + val = { + 'type': fname, + 'value': [ft.get('value') or ''], + 'required': required, + 'label': label + } + + if fname not in ('fileRef', 'addressRef', 'cardRef'): + if fname == 'phone' and ft and 'sample' in ft and 'region' in ft['sample']: + ft['sample']['region'] = 'US' + + rte['fields'].append(val) + else: + raise ValueError(f'No record type found with name {record_type_name}. Use "record-type-info" to list all record types') + + result = json.dumps(rte, indent=2) if rte else '' + return result + + +def get_record_types(vault:vault_online.VaultOnline) -> list[vault_types.RecordType]: + records = [] # (recordTypeId, name, scope) + record_types = vault.vault_data.get_record_types() + + if record_types: + for record_type in record_types: + name = record_type.name + scope = get_record_type_scope(record_type.scope) + records.append((record_type.id, name, scope)) + + return records + + +def get_field_type(id): + ftypes = [ + {**vars(record_types.RecordFields[rkey]), **vars(record_types.FieldTypes[fkey])} + for rkey in record_types.RecordFields + for fkey in record_types.FieldTypes + if record_types.RecordFields[rkey].type == record_types.FieldTypes[fkey].name + ] + result = next((ft for ft in ftypes if id.lower() == ft.get('name').lower()), {}) + if result: + # Determine value based on whether the id matches a FieldType or RecordField + field_type_obj = next((ft for ft in record_types.FieldTypes.values() if ft.name.lower() == id.lower()), None) + + if field_type_obj: + value = getattr(field_type_obj, 'value', None) + else: + value = result.get('type', None) + + result = { + 'id': result.get('$id') or result.get('name') or '', + 'type': result.get('type') or result.get('name') or '', + 'value': value, + } + return result + + +def isEnterpriseRecordType(record_type_id: int) -> bool: + num_rts_per_scope = 1_000_000 + enterprise_scope = record_pb2.RT_ENTERPRISE + min_id = num_rts_per_scope * enterprise_scope + max_id = min_id + num_rts_per_scope + is_enterprise_rt = min_id < record_type_id <= max_id + real_type_id = record_type_id % num_rts_per_scope + + return is_enterprise_rt, real_type_id + + +def get_field_definitions(field: record_types.FieldType): + recordfield_names = {rf.name for rf in record_types.RecordFields.values()} + lookup = field.name if field.name in recordfield_names else "" + multiple = ( + record_types.RecordFields[field.name].multiple.name + if lookup else "Optional" + ) + row = [ + field.name, + lookup, + multiple, + field.description + ] + return row + + +scope_map = { + storage_types.RecordTypeScope.Standard: 'Standard', + storage_types.RecordTypeScope.User: 'User', + storage_types.RecordTypeScope.Enterprise: 'Enterprise' +} + + +def get_record_type_scope(scope: storage_types.RecordTypeScope) -> str: + return scope_map.get(scope, str(scope)) + + +def validate_record_type_file(file_path: str) -> list: + if not file_path: + raise ValueError('File path is required.') + + if not file_path.endswith('.json'): + raise ValueError('Record type file must be a JSON file.') + + try: + with open(file_path, 'r') as f: + json_obj = json.load(f) + except json.JSONDecodeError as e: + raise ValueError(f'Invalid JSON in record type file: {e}') + except FileNotFoundError: + raise ValueError(f'Record type file not found: {file_path}') + + if not isinstance(json_obj, dict): + raise ValueError('Invalid custom record types file') + + record_types_list = json_obj.get('record_types') + + if not isinstance(record_types_list, list): + raise ValueError('Invalid custom record types list') + + return record_types_list \ No newline at end of file diff --git a/keepersdk-package/src/keepersdk/vault/record_type_management.py b/keepersdk-package/src/keepersdk/vault/record_type_management.py index 1388c01b..66e59dd9 100644 --- a/keepersdk-package/src/keepersdk/vault/record_type_management.py +++ b/keepersdk-package/src/keepersdk/vault/record_type_management.py @@ -51,11 +51,6 @@ def edit_custom_record_types(vault: vault_online.VaultOnline, record_type_id: in if not fields: raise ValueError('At least one field must be specified.') - is_enterprise_rt, real_type_id = record_type_utils.isEnterpriseRecordType(record_type_id) - - if not is_enterprise_rt: - raise ValueError('Only custom record types can be modified.') - field_definitions = [] for field in fields: field_name = field.get("$ref") @@ -75,7 +70,7 @@ def edit_custom_record_types(vault: vault_online.VaultOnline, record_type_id: in request_payload = record_pb2.RecordType() request_payload.content = json.dumps(record_type_data) request_payload.scope = record_pb2.RT_ENTERPRISE - request_payload.recordTypeId = real_type_id + request_payload.recordTypeId = record_type_id response = vault.keeper_auth.execute_auth_rest('vault/record_type_update', request_payload, response_type=record_pb2.RecordTypeModifyResponse) @@ -86,137 +81,12 @@ def delete_custom_record_types(vault: vault_online.VaultOnline, record_type_id: is_enterprise_admin = vault.keeper_auth.auth_context.is_enterprise_admin if not is_enterprise_admin: raise ValueError('This command is restricted to Keeper Enterprise administrators.') - - is_enterprise_rt, real_type_id = record_type_utils.isEnterpriseRecordType(record_type_id) - - if not is_enterprise_rt: - raise ValueError('Only custom record types can be removed.') request_payload = record_pb2.RecordType() request_payload.scope = record_pb2.RT_ENTERPRISE - request_payload.recordTypeId = real_type_id + request_payload.recordTypeId = record_type_id response = vault.keeper_auth.execute_auth_rest('vault/record_type_delete', request_payload, response_type=record_pb2.RecordTypeModifyResponse) return response - -def record_type_info( - vault: vault_online.VaultOnline, - field_name: Optional[str] = None, - record_type_name: Optional[str] = None, - example: Optional[bool] = None, -): - #field types - if field_name is not None: - headers = ('Field Type ID', 'Lookup', 'Multiple', 'Description') - show_all_fields = field_name.strip() == '' or field_name.strip() == '*' - if show_all_fields: - rows = [] - for ft in record_types.FieldTypes.values(): - rows.append(record_type_utils.get_field_definitions(ft)) - return tabulate.tabulate(rows, headers=headers, tablefmt='simple') - else: - # Fetch a specific field type - ft = record_types.FieldTypes.get(field_name) - if not ft: - raise ValueError(f"Field type '{field_name}' is not a valid RecordField.") - row = record_type_utils.get_field_definitions(ft) - return tabulate.tabulate([row], headers=headers, tablefmt='simple') - - # Handle record type example - if record_type_name and record_type_name != '*' and record_type_name != '' and example: - record_type_example = record_type_utils.get_record_type_example(vault, record_type_name) - return record_type_example - - # Record Types - if record_type_name and record_type_name != '*' and record_type_name != '': - #Fetch a specific record type - record_type = vault.vault_data.get_record_type_by_name(record_type_name) - if not record_type: - raise ValueError(f"Record type '{record_type_name}' not found.") - - rows = [] - fields = record_type.fields - scope = record_type_utils.get_record_type_scope(record_type.scope) - rows.append([ - record_type.id, - record_type.name, - scope, - fields[0].label if hasattr(fields[0], 'label') else str(fields[0]) - ]) - for field in fields[1:]: - rows.append(['', '', '', field.label if hasattr(field, 'label') else str(field)]) - - headers = ('id', 'name', 'scope', 'fields') - return tabulate.tabulate(rows, headers=headers, tablefmt='simple') - else: - #Show all record types - record_types_list = record_type_utils.get_record_types(vault) - if not record_types_list: - raise ValueError("No record types found.") - - rows = [] - for rtid, name, scope in record_types_list: - rows.append([rtid, name, scope]) - - headers = ('Record Type ID', 'Record Type Name', 'Record Type Scope') - return tabulate.tabulate(rows, headers=headers, tablefmt='simple') - - -def load_record_types(vault: vault_online.VaultOnline, filepath) -> int: - count = 0 - - record_types_list = record_type_utils.validate_record_type_file(filepath) - - loaded_record_types = set() - existing_record_types = record_type_utils.get_record_types(vault) - if existing_record_types: - for existing_record_type in existing_record_types: - loaded_record_types.add(existing_record_type[1].lower()) - - for record_type in record_types_list: - record_type_name = record_type.get('record_type_name') - if not record_type_name: - logger.error('Record type name is missing in the record type definition.', record_type) - continue - - record_type_name = record_type_name[:30] - if record_type_name.lower() in loaded_record_types: - logger.info(f'Record type "{record_type_name}" already exists. Skipping.') - continue - - fields = record_type.get('fields') - if not isinstance(fields, list): - logger.error('Fields must be a list in the record type definition.', record_type) - continue - - is_valid = True - add_fields = [] - for field in fields: - field_type = field.get('$type') - if field_type not in record_types.RecordFields: - is_valid = False - break - fo = {'$ref': field.get('$type')} - if field.get('required') is True: - fo['required'] = True - add_fields.append(fo) - if not is_valid: - logger.error('Invalid field type in the record type definition.', record_type) - continue - - if len(add_fields) == 0: - logger.error('No fields found in the record type definition.', record_type) - continue - - create_custom_record_type( - vault=vault, - title=record_type_name, - fields=add_fields, - description=record_type.get('description') or '', - categories=record_type.get('categories') or [] - ) - count += 1 - - return count \ No newline at end of file diff --git a/keepersdk-package/unit_tests/test_record_type_management.py b/keepersdk-package/unit_tests/test_record_type_management.py index df70bb3f..a9888c6f 100644 --- a/keepersdk-package/unit_tests/test_record_type_management.py +++ b/keepersdk-package/unit_tests/test_record_type_management.py @@ -79,12 +79,6 @@ def test_not_enterprise_admin(self): record_type_management.edit_custom_record_types(self.vault, record_type_id, "Title", [{"$ref": "login"}], "desc", ["test"]) self.assertIn("restricted to Keeper Enterprise administrators", str(cm.exception)) - def test_not_enterprise_record_type_id(self): - record_type_id = 1 - with self.assertRaises(ValueError) as cm: - record_type_management.edit_custom_record_types(self.vault, record_type_id, "Title", [{"$ref": "login"}], "desc", ["test"]) - self.assertIn("can be modified", str(cm.exception)) - def test_missing_fields(self): record_type_id = 2000001 with self.assertRaises(ValueError) as cm: @@ -129,12 +123,6 @@ def test_not_enterprise_admin(self): record_type_management.delete_custom_record_types(self.vault, record_type_id) self.assertIn("restricted to Keeper Enterprise administrators", str(cm.exception)) - def test_not_enterprise_record_type_id(self): - record_type_id = 1 - with self.assertRaises(ValueError) as cm: - record_type_management.delete_custom_record_types(self.vault, record_type_id) - self.assertIn("can be removed", str(cm.exception)) - class RecordTypeInfoTestCase(unittest.TestCase): def setUp(self): From 470670f067ccd4584a59dfb6ad0ec6e26600ee1c Mon Sep 17 00:00:00 2001 From: adeshmukh-ks Date: Tue, 10 Jun 2025 10:29:46 +0530 Subject: [PATCH 3/3] Corrected review changes --- .../src/keepercli/commands/record_type.py | 90 +---------- .../keepersdk/vault/record_type_management.py | 6 +- .../src/keepersdk/vault/record_type_utils.py | 148 ----------------- .../unit_tests/test_record_type_management.py | 153 ------------------ 4 files changed, 3 insertions(+), 394 deletions(-) delete mode 100644 keepersdk-package/src/keepersdk/vault/record_type_utils.py diff --git a/keepercli-package/src/keepercli/commands/record_type.py b/keepercli-package/src/keepercli/commands/record_type.py index 2f452fcd..4151f471 100644 --- a/keepercli-package/src/keepercli/commands/record_type.py +++ b/keepercli-package/src/keepercli/commands/record_type.py @@ -160,7 +160,7 @@ def add_arguments_to_parser(parser: argparse.ArgumentParser): '--example', dest='example', action='store_true', - help='Set to "true" to generate example JSON' + help='Use --example to generate example JSON' ) def execute(self, context: KeeperParams, **kwargs) -> None: @@ -312,94 +312,6 @@ def execute(self, context: KeeperParams, **kwargs) -> None: return -class RecordTypeInfoCommand(base.ArgparseCommand): - - def __init__(self): - self.parser = argparse.ArgumentParser( - prog='record-type-info', - description='Get record type info' - ) - RecordTypeInfoCommand.add_arguments_to_parser(self.parser) - super().__init__(self.parser) - - def add_arguments_to_parser(parser: argparse.ArgumentParser): - parser.add_argument( - '-lr', - '--list-record-type', - type=str, - dest='record_name', - action='store', - default=None, - const = '*', - nargs='?', - help='list record type by name or use * to list all' - ) - parser.add_argument( - '-lf', - '--list-field', - type=str, - dest='field_name', - action='store', - default=None, - help='list field type by name or use * to list all' - ) - parser.add_argument( - '-e', - '--example', - dest='example', - action='store_true', - help='Set to "true" to generate example JSON' - ) - - def execute(self, context: KeeperParams, **kwargs) -> None: - if not context.vault: - raise ValueError("Vault is not initialized.") - example = kwargs.get('example', False) - field = kwargs.get('field_name') - record_type = kwargs.get('record_name') - - result = record_type_management.record_type_info( - vault=context.vault, - field_name=field, - record_type_name=record_type, - example=example - ) - - logger.info(result) - - -class LoadRecordTypesCommand(base.ArgparseCommand): - - def __init__(self): - parser = argparse.ArgumentParser( - prog='load-record-types', - description='Loads custom record types from a JSON file.' - ) - parser.add_argument( - '--file', - dest='file', - action='store', - required=True, - help='Path to the JSON file containing the record type definition.' - ) - super().__init__(parser) - - def execute(self, context: KeeperParams, **kwargs) -> None: - if not context.vault: - raise ValueError("Vault is not initialized.") - - filepath = kwargs.get('file') - if not filepath: - raise ValueError("Missing required argument: --file") - - response = record_type_management.load_record_types(context.vault, filepath) - - if response != 0: - logger.info(f"Custom record types imported successfully. {response} record types were added.") - else: - logger.info("No custom record types were imported. Record types already exist in the vault or the file is empty.") - - record_implicit_fields = { 'title': '', # string 'custom': [], # Array of Field Data objects diff --git a/keepersdk-package/src/keepersdk/vault/record_type_management.py b/keepersdk-package/src/keepersdk/vault/record_type_management.py index 66e59dd9..3e103a72 100644 --- a/keepersdk-package/src/keepersdk/vault/record_type_management.py +++ b/keepersdk-package/src/keepersdk/vault/record_type_management.py @@ -1,10 +1,8 @@ import json -import os -import tabulate -from typing import List, Dict, Optional +from typing import List, Dict -from . import vault_online, record_types, record_type_utils +from . import vault_online, record_types from ..proto import record_pb2 from ..utils import get_logger diff --git a/keepersdk-package/src/keepersdk/vault/record_type_utils.py b/keepersdk-package/src/keepersdk/vault/record_type_utils.py deleted file mode 100644 index 60d10e75..00000000 --- a/keepersdk-package/src/keepersdk/vault/record_type_utils.py +++ /dev/null @@ -1,148 +0,0 @@ -import json - -from . import vault_online, storage_types, record_types, vault_types -from ..proto import record_pb2 - -def get_record_type_example(vault: vault_online.VaultOnline, record_type_name: str) -> str: - STR_VALUE = 'text' - - result = '' - rte = {} - record_type = vault.vault_data.get_record_type_by_name(record_type_name) - if record_type: - record_type_fields = record_type.fields - rte = { - 'type': record_type_name, - 'title': STR_VALUE, - 'notes': STR_VALUE, - 'fields': [], - 'custom': [] - } - - fields = record_type.fields or [] - fields = [x.label for x in fields] - for fname in fields: - ft = get_field_type(fname) - - required = next((x.required for x in record_type_fields if x.label == fname), None) - label = next((x.label for x in record_type_fields if x.label == fname), None) - - val = { - 'type': fname, - 'value': [ft.get('value') or ''], - 'required': required, - 'label': label - } - - if fname not in ('fileRef', 'addressRef', 'cardRef'): - if fname == 'phone' and ft and 'sample' in ft and 'region' in ft['sample']: - ft['sample']['region'] = 'US' - - rte['fields'].append(val) - else: - raise ValueError(f'No record type found with name {record_type_name}. Use "record-type-info" to list all record types') - - result = json.dumps(rte, indent=2) if rte else '' - return result - - -def get_record_types(vault:vault_online.VaultOnline) -> list[vault_types.RecordType]: - records = [] # (recordTypeId, name, scope) - record_types = vault.vault_data.get_record_types() - - if record_types: - for record_type in record_types: - name = record_type.name - scope = get_record_type_scope(record_type.scope) - records.append((record_type.id, name, scope)) - - return records - - -def get_field_type(id): - ftypes = [ - {**vars(record_types.RecordFields[rkey]), **vars(record_types.FieldTypes[fkey])} - for rkey in record_types.RecordFields - for fkey in record_types.FieldTypes - if record_types.RecordFields[rkey].type == record_types.FieldTypes[fkey].name - ] - result = next((ft for ft in ftypes if id.lower() == ft.get('name').lower()), {}) - if result: - # Determine value based on whether the id matches a FieldType or RecordField - field_type_obj = next((ft for ft in record_types.FieldTypes.values() if ft.name.lower() == id.lower()), None) - - if field_type_obj: - value = getattr(field_type_obj, 'value', None) - else: - value = result.get('type', None) - - result = { - 'id': result.get('$id') or result.get('name') or '', - 'type': result.get('type') or result.get('name') or '', - 'value': value, - } - return result - - -def isEnterpriseRecordType(record_type_id: int) -> bool: - num_rts_per_scope = 1_000_000 - enterprise_scope = record_pb2.RT_ENTERPRISE - min_id = num_rts_per_scope * enterprise_scope - max_id = min_id + num_rts_per_scope - is_enterprise_rt = min_id < record_type_id <= max_id - real_type_id = record_type_id % num_rts_per_scope - - return is_enterprise_rt, real_type_id - - -def get_field_definitions(field: record_types.FieldType): - recordfield_names = {rf.name for rf in record_types.RecordFields.values()} - lookup = field.name if field.name in recordfield_names else "" - multiple = ( - record_types.RecordFields[field.name].multiple.name - if lookup else "Optional" - ) - row = [ - field.name, - lookup, - multiple, - field.description - ] - return row - - -scope_map = { - storage_types.RecordTypeScope.Standard: 'Standard', - storage_types.RecordTypeScope.User: 'User', - storage_types.RecordTypeScope.Enterprise: 'Enterprise' -} - - -def get_record_type_scope(scope: storage_types.RecordTypeScope) -> str: - return scope_map.get(scope, str(scope)) - - -def validate_record_type_file(file_path: str) -> list: - if not file_path: - raise ValueError('File path is required.') - - if not file_path.endswith('.json'): - raise ValueError('Record type file must be a JSON file.') - - try: - with open(file_path, 'r') as f: - json_obj = json.load(f) - except json.JSONDecodeError as e: - raise ValueError(f'Invalid JSON in record type file: {e}') - except FileNotFoundError: - raise ValueError(f'Record type file not found: {file_path}') - - if not isinstance(json_obj, dict): - raise ValueError('Invalid custom record types file') - - record_types_list = json_obj.get('record_types') - - if not isinstance(record_types_list, list): - raise ValueError('Invalid custom record types list') - - return record_types_list \ No newline at end of file diff --git a/keepersdk-package/unit_tests/test_record_type_management.py b/keepersdk-package/unit_tests/test_record_type_management.py index a9888c6f..b250b31c 100644 --- a/keepersdk-package/unit_tests/test_record_type_management.py +++ b/keepersdk-package/unit_tests/test_record_type_management.py @@ -124,158 +124,5 @@ def test_not_enterprise_admin(self): self.assertIn("restricted to Keeper Enterprise administrators", str(cm.exception)) -class RecordTypeInfoTestCase(unittest.TestCase): - def setUp(self): - self.vault = MagicMock() - self.vault.vault_data.get_record_type_by_name = MagicMock() - self.vault.vault_data.get_record_types = MagicMock() - - @patch('keepersdk.vault.record_type_management.tabulate') - @patch('keepersdk.vault.record_type_management.record_types') - def test_field_name_all(self, mock_record_types, mock_tabulate): - # Setup mock FieldTypes and RecordFields - mock_ft = MagicMock() - mock_ft.name = 'login' - mock_ft.description = 'desc' - mock_record_types.FieldTypes.values.return_value = [mock_ft] - mock_rf = MagicMock() - mock_rf.name = 'login' - mock_rf.multiple.name = 'Optional' - mock_record_types.RecordFields.values.return_value = [mock_rf] - mock_tabulate.tabulate.return_value = 'table' - result = record_type_management.record_type_info(self.vault, field_name='*') - self.assertEqual(result, 'table') - - @patch('keepersdk.vault.record_type_management.tabulate') - @patch('keepersdk.vault.record_type_management.record_types') - def test_field_name_specific(self, mock_record_types, mock_tabulate): - mock_ft = MagicMock() - mock_ft.name = 'login' - mock_ft.description = 'desc' - mock_record_types.FieldTypes.get.return_value = mock_ft - mock_rf = MagicMock() - mock_rf.name = 'login' - mock_rf.multiple.name = 'Optional' - mock_record_types.RecordFields.values.return_value = [mock_rf] - mock_tabulate.tabulate.return_value = 'table' - result = record_type_management.record_type_info(self.vault, field_name='login') - self.assertEqual(result, 'table') - - @patch('keepersdk.vault.record_type_management.record_type_utils') - def test_record_type_example(self, mock_utils): - mock_utils.get_record_type_example.return_value = '{"type": "login"}' - result = record_type_management.record_type_info(self.vault, record_type_name='login', example=True) - self.assertEqual(result, '{"type": "login"}') - - @patch('keepersdk.vault.record_type_management.tabulate') - @patch('keepersdk.vault.record_type_management.record_type_utils') - def test_record_type_name_all(self, mock_utils, mock_tabulate): - mock_utils.get_record_types.return_value = [(1, 'login', 'Standard')] - mock_tabulate.tabulate.return_value = 'table' - result = record_type_management.record_type_info(self.vault, record_type_name='*') - self.assertEqual(result, 'table') - - def test_record_type_name_not_found(self): - self.vault.vault_data.get_record_type_by_name.return_value = None - with self.assertRaises(ValueError) as cm: - record_type_management.record_type_info(self.vault, record_type_name='notfound') - self.assertIn('not found', str(cm.exception)) - - @patch('keepersdk.vault.record_type_management.tabulate') - def test_record_type_name_details(self, mock_tabulate): - mock_record_type = MagicMock() - mock_record_type.id = 1 - mock_record_type.name = 'login' - mock_record_type.scope = 0 - field = MagicMock() - field.label = 'username' - mock_record_type.fields = [field] - self.vault.vault_data.get_record_type_by_name.return_value = mock_record_type - mock_tabulate.tabulate.return_value = 'table' - result = record_type_management.record_type_info(self.vault, record_type_name='login') - self.assertEqual(result, 'table') - - -class LoadRecordTypesTestCase(unittest.TestCase): - def setUp(self): - self.vault = MagicMock() - self.filepath = 'dummy.json' - self.patcher_validate = patch('keepersdk.vault.record_type_management.record_type_utils.validate_record_type_file') - self.mock_validate = self.patcher_validate.start() - self.addCleanup(self.patcher_validate.stop) - self.patcher_create = patch('keepersdk.vault.record_type_management.create_custom_record_type') - self.mock_create = self.patcher_create.start() - self.addCleanup(self.patcher_create.stop) - self.patcher_get_types = patch('keepersdk.vault.record_type_management.record_type_utils.get_record_types') - self.mock_get_types = self.patcher_get_types.start() - self.addCleanup(self.patcher_get_types.stop) - self.patcher_record_fields = patch('keepersdk.vault.record_type_management.record_types.RecordFields', {}) - self.mock_record_fields = self.patcher_record_fields.start() - self.addCleanup(self.patcher_record_fields.stop) - - def test_file_not_found(self): - self.mock_validate.side_effect = ValueError('Record type file not found: dummy.json') - with self.assertRaises(ValueError) as cm: - record_type_management.load_record_types(self.vault, self.filepath) - self.assertIn('Record type file not found', str(cm.exception)) - - def test_invalid_json(self): - self.mock_validate.side_effect = ValueError('Invalid JSON in record type file: ...') - with self.assertRaises(ValueError) as cm: - record_type_management.load_record_types(self.vault, self.filepath) - self.assertIn('Invalid JSON in record type file', str(cm.exception)) - - def test_json_not_dict(self): - self.mock_validate.side_effect = ValueError('Invalid custom record types file') - with self.assertRaises(ValueError) as cm: - record_type_management.load_record_types(self.vault, self.filepath) - self.assertIn('Invalid custom record types file', str(cm.exception)) - - def test_missing_record_types_list(self): - self.mock_validate.side_effect = ValueError('Invalid custom record types list') - with self.assertRaises(ValueError) as cm: - record_type_management.load_record_types(self.vault, self.filepath) - self.assertIn('Invalid custom record types list', str(cm.exception)) - - def test_record_types_list_not_list(self): - self.mock_validate.side_effect = ValueError('Invalid custom record types list') - with self.assertRaises(ValueError) as cm: - record_type_management.load_record_types(self.vault, self.filepath) - self.assertIn('Invalid custom record types list', str(cm.exception)) - - def test_skip_record_type_without_name(self): - self.mock_validate.return_value = [{}] - self.mock_get_types.return_value = [] - result = record_type_management.load_record_types(self.vault, self.filepath) - self.assertEqual(result, 0) - self.mock_create.assert_not_called() - - def test_skip_existing_record_type(self): - self.mock_validate.return_value = [{"record_type_name": "foo", "fields": [{"$type": "login", "$ref": "login"}]}] - mock_existing = MagicMock() - mock_existing.name = 'foo' - self.mock_get_types.return_value = [(1, 'foo', 'Enterprise')] - result = record_type_management.load_record_types(self.vault, self.filepath) - self.assertEqual(result, 0) - self.mock_create.assert_not_called() - - def test_skip_invalid_fields(self): - self.mock_validate.return_value = [{"record_type_name": "foo", "fields": [{"$type": "invalid", "$ref": "login"}]}] - self.mock_get_types.return_value = [] - with patch.dict('keepersdk.vault.record_type_management.record_types.RecordFields', {'login': MagicMock()}): - result = record_type_management.load_record_types(self.vault, self.filepath) - self.assertEqual(result, 0) - self.mock_create.assert_not_called() - - def test_successful_add(self): - self.mock_validate.return_value = [{"record_type_name": "foo", "fields": [{"$type": "login", "$ref": "login"}]}] - self.mock_get_types.return_value = [] - with patch.dict('keepersdk.vault.record_type_management.record_types.RecordFields', {'login': MagicMock()}): - self.mock_create.return_value = True - result = record_type_management.load_record_types(self.vault, self.filepath) - self.assertEqual(result, 1) - self.mock_create.assert_called_once() - - if __name__ == "__main__": unittest.main()