diff --git a/monai/apps/manifest/config_item.py b/monai/apps/manifest/config_item.py similarity index 100% rename from monai/apps/manifest/config_item.py rename to monai/apps/manifest/config_item.py diff --git a/monai/apps/manifest/export.py b/monai/apps/manifest/export.py new file mode 100644 index 0000000000..acfeb98291 --- /dev/null +++ b/monai/apps/manifest/export.py @@ -0,0 +1,59 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json + +import torch +from monai.apps import ConfigParser +from ignite.handlers import Checkpoint +from monai.data import save_net_with_metadata +from monai.networks import convert_to_torchscript + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--weights', '-w', type=str, help='file path of the trained model weights', required=True) + parser.add_argument('--config', '-c', type=str, help='file path of config file that defines network', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + args = parser.parse_args() + + # load config file + with open(args.config) as f: + config_dict = json.load(f) + # load meta data + with open(args.meta) as f: + meta_dict = json.load(f) + + net: torch.nn.Module = None + # TODO: parse network definiftion from config file and construct network instance + config_parser = ConfigParser(config_dict) + net = config_parser.get_instance("network") + + checkpoint = torch.load(args.weights) + # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver + Checkpoint.load_objects(to_load={"model": net}, checkpoint=checkpoint) + + # convert to TorchScript model and save with meta data + net = convert_to_torchscript(model=net) + + save_net_with_metadata( + jit_obj=net, + filename_prefix_or_stream="model.ts", + include_config_vals=False, + append_timestamp=False, + meta_values=meta_dict, + more_extra_files={args.config: json.dumps(config_dict).encode()}, + ) + + +if __name__ == '__main__': + main() diff --git a/monai/apps/manifest/inference.py b/monai/apps/manifest/inference.py new file mode 100644 index 0000000000..41b9792869 --- /dev/null +++ b/monai/apps/manifest/inference.py @@ -0,0 +1,68 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json + +import torch +from monai.apps import ConfigParser +from monai.data import decollate_batch +from monai.inferers import Inferer +from monai.transforms import Transform +from monai.utils.enums import CommonKeys + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + parser.add_argument('--override', '-o', type=str, help='config file that override components', required=False) + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + configs = {} + + # load meta data + with open(args.meta) as f: + configs.update(json.load(f)) + # load config file, can override meta data in config + with open(args.config) as f: + configs.update(json.load(f)) + + model: torch.nn.Module = None + dataloader: torch.utils.data.DataLoader = None + inferer: Inferer = None + postprocessing: Transform = None + # TODO: parse inference config file and construct instances + config_parser = ConfigParser(configs) + + # change JSON config content in python code, lazy instantiation + model_conf = config_parser.get_config("model") + model_conf["disabled"] = False + model = config_parser.build(model_conf).to(device) + + # instantialize the components immediately + dataloader = config_parser.get_instance("dataloader") + inferer = config_parser.get_instance("inferer") + postprocessing = config_parser.get_instance("postprocessing") + + model.eval() + with torch.no_grad(): + for d in dataloader: + images = d[CommonKeys.IMAGE].to(device) + # define sliding window size and batch size for windows inference + d[CommonKeys.PRED] = inferer(inputs=images, predictor=model) + # decollate the batch data into a list of dictionaries, then execute postprocessing transforms + [postprocessing(i) for i in decollate_batch(d)] + + +if __name__ == '__main__': + main() diff --git a/monai/apps/manifest/schema/metadata.json b/monai/apps/manifest/schema/metadata.json new file mode 100644 index 0000000000..babee8b30e --- /dev/null +++ b/monai/apps/manifest/schema/metadata.json @@ -0,0 +1,71 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://monai.io/mmar_metadata_schema.json", + "title": "metadata", + "description": "metadata that defines the context information for MMAR.", + "type": "object", + "properties": { + "version": { + "description": "version number of this MMAR.", + "type": "string" + }, + "monai_version": { + "description": "version number of MONAI used in this MMAR.", + "type": "string" + }, + "pytorch_version": { + "description": "version number of PyTorch used in this MMAR.", + "type": "string" + }, + "numpy_version": { + "description": "version number of MONAI used in this MMAR.", + "type": "string" + }, + "network_data_format": { + "description": "define the input and output data format for network.", + "type": "object", + "properties": { + "inputs": { + "type": "object", + "properties": { + "image": { + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "the data format for `image`." + }, + "format": { + "type": "string" + }, + "num_channels": { + "type": "integer", + "minimum": 1 + }, + "spatial_shape": { + "type": "array", + "items": { + "type": "integer", + "minumum": 1 + } + }, + "dtype": { + "type": "string" + }, + "value_range": { + "type": "array", + "items": { + "type": "number", + "unuqueItems": true + } + }, + "required": ["num_channels", "spatial_shape", "value_range"] + } + } + } + } + } + }, + "required": ["monai_version", "pytorch_version", "network_data_format"] + } +} diff --git a/monai/apps/manifest/verify_network.py b/monai/apps/manifest/verify_network.py new file mode 100644 index 0000000000..c0dbf178dc --- /dev/null +++ b/monai/apps/manifest/verify_network.py @@ -0,0 +1,60 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json + +import torch +from monai.apps import ConfigParser +from monai.utils.type_conversion import get_equivalent_dtype + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--config', '-c', type=str, help='config file that defines components', required=True) + parser.add_argument('--meta', '-e', type=str, help='file path of the meta data') + args = parser.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + configs = {} + + # load meta data + with open(args.meta) as f: + configs.update(json.load(f)) + # load config file, can override meta data in config + with open(args.config) as f: + configs.update(json.load(f)) + + model: torch.nn.Module = None + # TODO: parse inference config file and construct instances + config_parser = ConfigParser(configs) + + model = config_parser.get_instance("model") + input_channels = config_parser.get_config("network_data_format#inputs#image#num_channels") + input_spatial_shape = tuple(config_parser.get_config("network_data_format#inputs#image#spatial_shape")) + dtype = config_parser.get_config("network_data_format#inputs#image#dtype") + dtype = get_equivalent_dtype(dtype, data_type=torch.Tensor) + + output_channels = config_parser.get_config("network_data_format#outputs#pred#num_channels") + output_spatial_shape = tuple(config_parser.get_config("network_data_format#outputs#pred#spatial_shape")) + + model.eval() + with torch.no_grad(): + test_data = torch.rand(*(input_channels, *input_spatial_shape), dtype=dtype, device=device) + output = model(test_data) + if output.shape[0] != output_channels: + raise ValueError(f"channel number of output data doesn't match expection: {output_channels}.") + if output.shape[1:] != output_spatial_shape: + raise ValueError(f"spatial shape of output data doesn't match expection: {output_spatial_shape}.") + + +if __name__ == '__main__': + main() diff --git a/tests/test_component_locator.py b/tests/test_component_locator.py deleted file mode 100644 index eafb2152d1..0000000000 --- a/tests/test_component_locator.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from pydoc import locate - -from monai.apps.manifest import ComponentLocator -from monai.utils import optional_import - -_, has_ignite = optional_import("ignite") - - -class TestComponentLocator(unittest.TestCase): - def test_locate(self): - locator = ComponentLocator(excludes=None if has_ignite else ["monai.handlers"]) - # test init mapping table and get the module path of component - self.assertEqual(locator.get_component_module_name("LoadImage"), "monai.transforms.io.array") - self.assertGreater(len(locator._components_table), 0) - for _, mods in locator._components_table.items(): - for i in mods: - self.assertGreater(len(mods), 0) - # ensure we can locate all the items by `name` - self.assertIsNotNone(locate(i), msg=f"can not locate target: {i}.") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_config_item.py b/tests/test_config_item.py deleted file mode 100644 index b2c2fec6c6..0000000000 --- a/tests/test_config_item.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from functools import partial -from typing import Callable - -import torch -from parameterized import parameterized - -import monai -from monai.apps import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem -from monai.data import DataLoader, Dataset -from monai.transforms import LoadImaged, RandTorchVisiond -from monai.utils import optional_import - -_, has_tv = optional_import("torchvision") - -TEST_CASE_1 = [{"lr": 0.001}, 0.0001] - -TEST_CASE_2 = [{"": "LoadImaged", "": {"keys": ["image"]}}, LoadImaged] -# test python `` -TEST_CASE_3 = [{"": "monai.transforms.LoadImaged", "": {"keys": ["image"]}}, LoadImaged] -# test `` -TEST_CASE_4 = [{"": "LoadImaged", "": True, "": {"keys": ["image"]}}, dict] -# test `` -TEST_CASE_5 = [{"": "LoadImaged", "": "true", "": {"keys": ["image"]}}, dict] -# test non-monai modules and excludes -TEST_CASE_6 = [ - {"": "torch.optim.Adam", "": {"params": torch.nn.PReLU().parameters(), "lr": 1e-4}}, - torch.optim.Adam, -] -TEST_CASE_7 = [{"": "decollate_batch", "": {"detach": True, "pad": True}}, partial] -# test args contains "name" field -TEST_CASE_8 = [ - {"": "RandTorchVisiond", "": {"keys": "image", "name": "ColorJitter", "brightness": 0.25}}, - RandTorchVisiond, -] -# test execute some function in args, test pre-imported global packages `monai` -TEST_CASE_9 = ["collate_fn", "$monai.data.list_data_collate"] -# test lambda function, should not execute the lambda function, just change the string -TEST_CASE_10 = ["collate_fn", "$lambda x: monai.data.list_data_collate(x) + torch.tensor(var)"] - - -class TestConfigItem(unittest.TestCase): - @parameterized.expand([TEST_CASE_1]) - def test_item(self, test_input, expected): - item = ConfigItem(config=test_input) - conf = item.get_config() - conf["lr"] = 0.0001 - item.update_config(config=conf) - self.assertEqual(item.get_config()["lr"], expected) - - @parameterized.expand( - [TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7] - + ([TEST_CASE_8] if has_tv else []) - ) - def test_component(self, test_input, output_type): - locator = ComponentLocator(excludes=["metrics"]) - configer = ConfigComponent(id="test", config=test_input, locator=locator) - ret = configer.instantiate() - if test_input.get("", False): - # test `` works fine - self.assertEqual(ret, None) - return - self.assertTrue(isinstance(ret, output_type)) - if isinstance(ret, LoadImaged): - self.assertEqual(ret.keys[0], "image") - - @parameterized.expand([TEST_CASE_9, TEST_CASE_10]) - def test_expression(self, id, test_input): - configer = ConfigExpression(id=id, config=test_input, globals={"monai": monai, "torch": torch}) - var = 100 - ret = configer.evaluate(locals={"var": var}) - self.assertTrue(isinstance(ret, Callable)) - - def test_lazy_instantiation(self): - config = {"": "DataLoader", "": {"dataset": Dataset(data=[1, 2]), "batch_size": 2}} - configer = ConfigComponent(config=config, locator=None) - init_config = configer.get_config() - # modify config content at runtime - init_config[""]["batch_size"] = 4 - configer.update_config(config=init_config) - - ret = configer.instantiate() - self.assertTrue(isinstance(ret, DataLoader)) - self.assertEqual(ret.batch_size, 4) - - -if __name__ == "__main__": - unittest.main()