Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,7 +1273,6 @@ def ckpt_export(
config_file_,
filepath_,
ckpt_file_,
bundle_root_,
net_id_,
meta_file_,
key_in_ckpt_,
Expand All @@ -1285,26 +1284,30 @@ def ckpt_export(
"config_file",
filepath=None,
ckpt_file=None,
bundle_root=os.getcwd(),
net_id=None,
meta_file=None,
key_in_ckpt="",
use_trace=False,
input_shape=None,
converter_kwargs={},
)
bundle_root = _args.get("bundle_root", os.getcwd())

parser = ConfigParser()

parser.read_config(f=config_file_)
meta_file_ = os.path.join(bundle_root_, "configs", "metadata.json") if meta_file_ is None else meta_file_
filepath_ = os.path.join(bundle_root_, "models", "model.ts") if filepath_ is None else filepath_
ckpt_file_ = os.path.join(bundle_root_, "models", "model.pt") if ckpt_file_ is None else ckpt_file_
if not os.path.exists(ckpt_file_):
raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".')
meta_file_ = os.path.join(bundle_root, "configs", "metadata.json") if meta_file_ is None else meta_file_
if os.path.exists(meta_file_):
parser.read_meta(f=meta_file_)

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v

filepath_ = os.path.join(bundle_root, "models", "model.ts") if filepath_ is None else filepath_
ckpt_file_ = os.path.join(bundle_root, "models", "model.pt") if ckpt_file_ is None else ckpt_file_
if not os.path.exists(ckpt_file_):
raise FileNotFoundError(f'Checkpoint file "{ckpt_file_}" not found, please specify it in argument "ckpt_file".')

net_id_ = "network_def" if net_id_ is None else net_id_
try:
parser.get_parsed_content(net_id_)
Expand All @@ -1313,10 +1316,6 @@ def ckpt_export(
f'Network definition "{net_id_}" cannot be found in "{config_file_}", specify name with argument "net_id".'
) from e

# the rest key-values in the _args are to override config content
for k, v in _args.items():
parser[k] = v

# When export through torch.jit.trace without providing input_shape, will try to parse one from the parser.
if (not input_shape_) and use_trace:
input_shape_ = _get_fake_input_shape(parser=parser)
Expand Down