diff --git a/pyodi/apps/coco/coco_split.py b/pyodi/apps/coco/coco_split.py index 7b277f1..e997f5f 100644 --- a/pyodi/apps/coco/coco_split.py +++ b/pyodi/apps/coco/coco_split.py @@ -133,15 +133,15 @@ def property_split( train_split = { "images": train_images, "annotations": train_annotations, - "info": data["info"], - "licenses": data["licenses"], + "info": data.get("info", {}), + "licenses": data.get("licenses", []), "categories": data["categories"], } val_split = { "images": val_images, "annotations": val_annotations, - "info": data["info"], - "licenses": data["licenses"], + "info": data.get("info", {}), + "licenses": data.get("licenses", []), "categories": data["categories"], } @@ -202,16 +202,16 @@ def random_split( train_split = { "images": train_images, "annotations": train_annotations, - "info": data["info"], - "licenses": data["licenses"], + "info": data.get("info", {}), + "licenses": data.get("licenses", []), "categories": data["categories"], } val_split = { "images": val_images, "annotations": val_annotations, - "info": data["info"], - "licenses": data["licenses"], + "info": data.get("info", {}), + "licenses": data.get("licenses", []), "categories": data["categories"], } diff --git a/tests/apps/test_coco_split.py b/tests/apps/test_coco_split.py index 136a447..12f8313 100644 --- a/tests/apps/test_coco_split.py +++ b/tests/apps/test_coco_split.py @@ -1,6 +1,8 @@ import json from pathlib import Path +import pytest + from pyodi.apps.coco.coco_split import property_split, random_split @@ -102,3 +104,41 @@ def test_property_coco_split(tmpdir): assert len(val_data["images"]) == 4 assert len(train_data["annotations"]) == 1 assert len(val_data["annotations"]) == 8 + + +@pytest.mark.parametrize("split_type", ["random", "property"]) +def test_split_without_info_and_licenses(tmpdir, split_type): + tmpdir = Path(tmpdir) + + coco_data = get_coco_data() + coco_data.pop("licenses") + coco_data.pop("info") + + assert "licenses" not in coco_data + assert "info" not in coco_data + + json.dump(coco_data, open(tmpdir / "coco.json", "w")) + + if split_type == "random": + train_path, val_path = random_split( + annotations_file=str(tmpdir / "coco.json"), + output_filename=str(tmpdir / "random_coco_split"), + val_percentage=0.25, + seed=49, + ) + else: + config = dict( + val={"file_name": {"frame 0": "vidA-0.jpg", "frame 1": "vidA-1.jpg"}} + ) + json.dump(config, open(tmpdir / "split_config.json", "w")) + + train_path, val_path = property_split( + annotations_file=str(tmpdir / "coco.json"), + output_filename=str(tmpdir / "property_coco_split"), + split_config_file=str(tmpdir / "split_config.json"), + ) + + for path in [train_path, val_path]: + data = json.load(open(path)) + assert data["licenses"] == [] + assert data["info"] == {}