Reviewer notes (text before the first "diff --git" line is not part of any hunk):
- The hunks below are re-flowed to one source line per diff line; leading whitespace
  and blank context lines are reconstructed and should be checked against the targets.
- ocrd_utils/str.py: the "evaluates to the same number" ValueError now interpolates
  (start, end), and generate_range substitutes only the last numeric run (the one that
  was parsed) instead of every occurrence of those digits. The hunk's new-side length
  is therefore 13 and the "index" hash of that file no longer matches the content.

diff --git a/ocrd_models/ocrd_models/ocrd_mets.py b/ocrd_models/ocrd_models/ocrd_mets.py
index 8571254e56..b42fa1ad13 100644
--- a/ocrd_models/ocrd_models/ocrd_mets.py
+++ b/ocrd_models/ocrd_models/ocrd_mets.py
@@ -3,6 +3,7 @@
 """
 from datetime import datetime
 import re
+import typing
 from lxml import etree as ET
 
 from ocrd_utils import (
@@ -159,22 +160,21 @@ def find_files(self, ID=None, fileGrp=None, pageId=None, mimetype=None, url=None
         Yields:
             :py:class:`ocrd_models:ocrd_file:OcrdFile` instantiations
         """
+        pageId_list = []
         if pageId:
-            if pageId.startswith(REGEX_PREFIX):
-                pageIds, pageId = re.compile(pageId[REGEX_PREFIX_LEN:]), list()
-            else:
-                pageIds, pageId = pageId.split(','), list()
-                pageIds_expanded = []
-                for pageId_ in pageIds:
-                    if '..' in pageId_:
-                        pageIds_expanded += generate_range(*pageId_.split('..', 1))
-                pageIds += pageIds_expanded
+            pageId_patterns = []
+            for pageId_token in re.split(r',', pageId):
+                if pageId_token.startswith(REGEX_PREFIX):
+                    pageId_patterns.append(re.compile(pageId_token[REGEX_PREFIX_LEN:]))
+                elif '..' in pageId_token:
+                    pageId_patterns += generate_range(*pageId_token.split('..', 1))
+                else:
+                    pageId_patterns += [pageId_token]
             for page in self._tree.getroot().xpath(
                 '//mets:div[@TYPE="page"]', namespaces=NS):
-                if (page.get('ID') in pageIds if isinstance(pageIds, list) else
-                    pageIds.fullmatch(page.get('ID'))):
-                    pageId.extend(
-                        [fptr.get('FILEID') for fptr in page.findall('mets:fptr', NS)])
+                if page.get('ID') in pageId_patterns or \
+                    any([isinstance(p, typing.Pattern) and p.fullmatch(page.get('ID')) for p in pageId_patterns]):
+                    pageId_list += [fptr.get('FILEID') for fptr in page.findall('mets:fptr', NS)]
         if ID and ID.startswith(REGEX_PREFIX):
             ID = re.compile(ID[REGEX_PREFIX_LEN:])
         if fileGrp and fileGrp.startswith(REGEX_PREFIX):
@@ -190,7 +190,7 @@ def find_files(self, ID=None, fileGrp=None, pageId=None, mimetype=None, url=None
                 else:
                     if not ID.fullmatch(cand.get('ID')): continue
 
-            if pageId is not None and cand.get('ID') not in pageId:
+            if pageId is not None and cand.get('ID') not in pageId_list:
                 continue
 
             if fileGrp:
diff --git a/ocrd_utils/ocrd_utils/str.py b/ocrd_utils/ocrd_utils/str.py
index 2944bb1d33..7211699f0c 100644
--- a/ocrd_utils/ocrd_utils/str.py
+++ b/ocrd_utils/ocrd_utils/str.py
@@ -195,10 +195,13 @@ def generate_range(start, end):
     Generate a list of strings by incrementing the number part of ``start`` until including ``end``.
     """
     ret = []
-    start_num, end_num = re.search(r'\d+', start), re.search(r'\d+', end)
-    if not (start_num and end_num):
-        raise ValueError("Unable to generate range %s .. %s, could not detect number part" % (start, end))
-    start_num, end_num = start_num.group(0), end_num.group(0)
+    try:
+        start_num, end_num = re.findall(r'\d+', start)[-1], re.findall(r'\d+', end)[-1]
+    except IndexError:
+        raise ValueError("Range '%s..%s': could not find numeric part" % (start, end))
+    if start_num == end_num:
+        raise ValueError("Range '%s..%s': evaluates to the same number" % (start, end))
+    prefix, _, suffix = start.rpartition(start_num)
     for i in range(int(start_num), int(end_num) + 1):
-        ret.append(start.replace(start_num, str(i).zfill(len(start_num))))
+        ret.append(prefix + str(i).zfill(len(start_num)) + suffix)
     return ret
diff --git a/tests/model/test_ocrd_mets.py b/tests/model/test_ocrd_mets.py
index e91ab66142..f8ea1c4fe9 100644
--- a/tests/model/test_ocrd_mets.py
+++ b/tests/model/test_ocrd_mets.py
@@ -74,7 +74,8 @@ def test_find_all_files(sbb_sample_01):
     assert len(sbb_sample_01.find_all_files(url='OCR-D-IMG/FILE_0005_IMAGE.tif')) == 1, '1 xlink:href="OCR-D-IMG/FILE_0005_IMAGE.tif"'
     assert len(sbb_sample_01.find_all_files(pageId='PHYS_0001..PHYS_0005')) == 35, '35 files for page "PHYS_0001..PHYS_0005"'
     assert len(sbb_sample_01.find_all_files(pageId='//PHYS_000(1|2)')) == 34, '34 files in PHYS_001 and PHYS_0002'
-
+    assert len(sbb_sample_01.find_all_files(pageId='//PHYS_0001,//PHYS_0005')) == 18, '18 files in PHYS_001 and PHYS_0005 (two regexes)'
+    assert len(sbb_sample_01.find_all_files(pageId='//PHYS_0005,PHYS_0001..PHYS_0002')) == 35, '35 files in //PHYS_0005,PHYS_0001..PHYS_0002'
 
 def test_find_all_files_local_only(sbb_sample_01):
     assert len(sbb_sample_01.find_all_files(pageId='PHYS_0001',
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 467724bd40..b04a9a2722 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -315,8 +315,10 @@ def test_make_file_id_744(self):
 
     def test_generate_range(self):
         assert generate_range('PHYS_0001', 'PHYS_0005') == ['PHYS_0001', 'PHYS_0002', 'PHYS_0003', 'PHYS_0004', 'PHYS_0005']
-        with self.assertRaisesRegex(ValueError, 'Unable to generate range'):
+        with self.assertRaisesRegex(ValueError, 'could not find numeric part'):
             generate_range('NONUMBER', 'ALSO_NONUMBER')
+        with self.assertRaisesRegex(ValueError, 'evaluates to the same number'):
+            generate_range('PHYS_0001_123', 'PHYS_0010_123')
 
     def test_safe_filename(self):
         assert safe_filename('Hello world,!') == 'Hello_world_'