diff --git a/RLTest/utils.py b/RLTest/utils.py index c012ee9..9ed3f8d 100644 --- a/RLTest/utils.py +++ b/RLTest/utils.py @@ -98,6 +98,34 @@ def dicty(args): def join_lists(lists): return list(itertools.chain.from_iterable(lists)) +def _merge_by_words(explicit_str, defaultArgs): + """Merge a plain explicit arg string with defaults using word-level key matching. + For each default arg, if its key doesn't appear as a word in the explicit string, + append the entire default arg to the string. + Returns the merged string wrapped as [[merged_string]]. + """ + if not defaultArgs or not defaultArgs[0]: + return [[explicit_str]] + explicit_words_upper = [w.upper() for w in explicit_str.split()] + merged = explicit_str + for arg in defaultArgs[0]: + key = arg.split()[0].upper() + if key not in explicit_words_upper: + merged += ' ' + arg + return [[merged.strip()]] + +def _merge_by_dict(modulesArgs, defaultArgs): + """Merge structured (already-split) modulesArgs with defaults using dict-based key matching. + For each module, any default key not present in the explicit args is appended. + """ + modules_args_dict = args_list_to_dict(modulesArgs) + for imod, args_list in enumerate(defaultArgs): + for arg in args_list: + name = arg.split(' ')[0].upper() + if name not in modules_args_dict[imod]: + modulesArgs[imod] += [arg] + return modulesArgs + def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True): # modulesArgs is one of the following: # None @@ -105,32 +133,24 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True): # ['args ...', ...]: arg list for a single module # [['arg', ...', ...], ...]: arg strings for multiple modules - # arg string is a string of words separated by whitespace. - # arg string can be separated by semicolons into (logical) arg lists. - # semicolons can be escaped with a backslash. - # if no semicolons are present, the string is treated as space-separated key-value pairs, - # where each consecutive pair of words forms a 'KEY VALUE' arg. - # thus, 'K1 V1 K2 V2' becomes ['K1 V1', 'K2 V2'] - # an odd number of words without semicolons is an error. - # for args with multiple values, semicolons are required: - # thus, 'K1 V1; K2 V2 V3' becomes ['K1 V1', 'K2 V2 V3'] - # arg list is a list of arg strings. - # arg list starts with an arg name that can later be used for argument overriding. + # For a plain string without semicolons: + # If defaultArgs exist, merge by checking if each default key appears as + # a word in the explicit string. Missing defaults are appended. + # If no defaultArgs, keep the string as-is (no splitting needed). + # For strings with semicolons, split by semicolons and use dict-based merge. + # For list inputs, use dict-based merge. + + is_plain_str = False # tracks if input was a plain string without semicolons if type(modulesArgs) == str: - # case # 'args ...': arg string for a single module - # transformed into [['arg', ...]] parts = split_by_semicolon(modulesArgs) if len(parts) == 1: - # No semicolons found - treat as space-separated key-value pairs - words = parts[0].split() - if len(words) % 2 != 0: - print(Colors.Bred(f"Error in args: odd number of words in key-value pairs: '{modulesArgs}'. " - f"Use semicolons to separate args with multiple values (e.g. 'KEY1 V1; KEY2 V2 V3').")) - sys.exit(1) - if len(words) > 2: - parts = [f"{words[i]} {words[i + 1]}" for i in range(0, len(words), 2)] - modulesArgs = [parts] + # No semicolons - keep as plain string + is_plain_str = True + modulesArgs = [[modulesArgs.strip()]] + else: + # Has semicolons - already split + modulesArgs = [parts] elif type(modulesArgs) == list: args = [] is_list = False @@ -138,7 +158,6 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True): for argx in modulesArgs: if type(argx) == list: # case [['arg', ...], ...]: arg strings for multiple modules - # already transformed into [['arg', ...], ...] if is_str: print(Colors.Bred('Error in args: %s' % str(modulesArgs))) sys.exit(1) @@ -150,7 +169,6 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True): args += [argx] else: # case ['args ...', ...]: arg list for a single module - # transformed into [['arg', ...], ...] if is_list: print(Colors.Bred('Error in args: %s' % str(modulesArgs))) sys.exit(1) @@ -183,19 +201,13 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True): return modulesArgs # if there are fewer defaultArgs than modulesArgs, we should bail out - # as we cannot pad the defaults with emply arg lists if defaultArgs and len(modulesArgs) > len(defaultArgs): print(Colors.Bred('Number of module args sets in Env does not match number of modules')) print(defaultArgs) print(modulesArgs) sys.exit(1) - # for each module, sync defaultArgs to modulesARgs - modules_args_dict = args_list_to_dict(modulesArgs) - for imod, args_list in enumerate(defaultArgs): - for arg in args_list: - name = arg.split(' ')[0].upper() - if name not in modules_args_dict[imod]: - modulesArgs[imod] += [arg] - - return modulesArgs + if is_plain_str: + return _merge_by_words(modulesArgs[0][0], defaultArgs) + else: + return _merge_by_dict(modulesArgs, defaultArgs) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index eaf7d0c..af01ad0 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -5,41 +5,39 @@ class TestFixModulesArgs(TestCase): - # 1. Single key-value pair string + # 1. Single key-value pair string, no defaults - kept as single string def test_single_key_value_pair(self): result = fix_modulesArgs(['/mod.so'], 'WORKERS 4') self.assertEqual(result, [['WORKERS 4']]) - # 2. Multiple key-value pairs without semicolons (new behavior) - def test_multiple_kv_pairs_no_semicolons(self): + # 2. Multiple key-value pairs without semicolons, no defaults - kept as single string + def test_multiple_kv_pairs_no_semicolons_no_defaults(self): result = fix_modulesArgs(['/mod.so'], '_FREE_RESOURCE_ON_THREAD FALSE TIMEOUT 80 WORKERS 4') - self.assertEqual(result, [['_FREE_RESOURCE_ON_THREAD FALSE', 'TIMEOUT 80', 'WORKERS 4']]) + self.assertEqual(result, [['_FREE_RESOURCE_ON_THREAD FALSE TIMEOUT 80 WORKERS 4']]) # 3. Semicolon-separated args (existing behavior) def test_semicolon_separated_args(self): result = fix_modulesArgs(['/mod.so'], 'KEY1 V1; KEY2 V2') self.assertEqual(result, [['KEY1 V1', 'KEY2 V2']]) - # 4a. Odd number of words without semicolons - should error - def test_odd_words_no_semicolons_exits(self): - with self.assertRaises(SystemExit): - fix_modulesArgs(['/mod.so'], 'FLAG TIMEOUT 80') + # 4. Odd number of words without semicolons, no defaults - kept as single string, no error + def test_odd_words_no_semicolons_no_error(self): + result = fix_modulesArgs(['/mod.so'], 'FLAG TIMEOUT 80 ') + self.assertEqual(result, [['FLAG TIMEOUT 80']]) # 4b. Odd number of words with semicolons - valid, semicolons split first def test_odd_words_with_semicolons_valid(self): result = fix_modulesArgs(['/mod.so'], 'FLAG; TIMEOUT 80') self.assertEqual(result, [['FLAG', 'TIMEOUT 80']]) - # 5a. Space-separated string overrides matching defaults, non-matching defaults added - def test_space_separated_overrides_defaults(self): + # 5a. Plain string with defaults - word-based merge, missing defaults appended + def test_plain_string_overrides_defaults(self): defaults = [['WORKERS 8', 'TIMEOUT 60', 'EXTRA 1']] result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults) - result_dict = {arg.split(' ')[0]: arg for arg in result[0]} - self.assertEqual(result_dict['WORKERS'], 'WORKERS 4') - self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80') - self.assertEqual(result_dict['EXTRA'], 'EXTRA 1') + # Result is a single merged string + self.assertEqual(result, [['WORKERS 4 TIMEOUT 80 EXTRA 1']]) - # 5b. Semicolon-separated string overrides matching defaults + # 5b. Semicolon-separated string overrides matching defaults (dict-based merge) def test_semicolon_separated_overrides_defaults(self): defaults = [['WORKERS 8', 'TIMEOUT 60', 'EXTRA 1']] result = fix_modulesArgs(['/mod.so'], 'WORKERS 4; TIMEOUT 80', defaults) @@ -48,14 +46,11 @@ def test_semicolon_separated_overrides_defaults(self): self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80') self.assertEqual(result_dict['EXTRA'], 'EXTRA 1') - # 5c. Space-separated explicit overrides some defaults, non-overlapping defaults are merged - def test_space_separated_partial_override_with_defaults(self): + # 5c. Plain string partial override - missing defaults appended + def test_plain_string_partial_override_with_defaults(self): defaults = [['_FREE_RESOURCE_ON_THREAD TRUE', 'TIMEOUT 100', 'WORKERS 8']] result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults) - result_dict = {arg.split(' ')[0]: arg for arg in result[0]} - self.assertEqual(result_dict['WORKERS'], 'WORKERS 4') - self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80') - self.assertEqual(result_dict['_FREE_RESOURCE_ON_THREAD'], '_FREE_RESOURCE_ON_THREAD TRUE') + self.assertEqual(result, [['WORKERS 4 TIMEOUT 80 _FREE_RESOURCE_ON_THREAD TRUE']]) # 6. None input with defaults - deep copy of defaults def test_none_uses_defaults(self): @@ -66,7 +61,7 @@ def test_none_uses_defaults(self): result[0][0] = 'MODIFIED' self.assertEqual(defaults[0][0], 'WORKERS 8') - # 7. List of strings with defaults - overlapping and non-overlapping keys + # 7. List of strings with defaults - dict-based merge def test_list_of_strings_with_defaults(self): defaults = [['K1 default1', 'K2 default2', 'K4 default4']] result = fix_modulesArgs(['/mod.so'], ['K1 override1', 'K2 override2', 'K3 new3'], defaults) @@ -76,35 +71,40 @@ def test_list_of_strings_with_defaults(self): self.assertEqual(result_dict['K3'], 'K3 new3') self.assertEqual(result_dict['K4'], 'K4 default4') - # 8. List of lists (multi-module) with defaults - overlapping and non-overlapping keys + # 8. List of lists (multi-module) with defaults - dict-based merge def test_multi_module_with_defaults(self): modules = ['/mod1.so', '/mod2.so'] explicit = [['K1 v1', 'K2 v2'], ['K3 v3']] defaults = [['K1 d1', 'K5 d5'], ['K3 d3', 'K4 d4']] result = fix_modulesArgs(modules, explicit, defaults) - # Module 1: K1 overridden, K5 added from defaults dict1 = {arg.split(' ')[0]: arg for arg in result[0]} self.assertEqual(dict1['K1'], 'K1 v1') self.assertEqual(dict1['K2'], 'K2 v2') self.assertEqual(dict1['K5'], 'K5 d5') - # Module 2: K3 overridden, K4 added from defaults dict2 = {arg.split(' ')[0]: arg for arg in result[1]} self.assertEqual(dict2['K3'], 'K3 v3') self.assertEqual(dict2['K4'], 'K4 d4') + # 9. Odd words with defaults - word-based merge, flags and multi-value args handled + def test_odd_words_with_defaults(self): + defaults = [['FORK_GC_CLEAN_NUMERIC_EMPTY_NODES', 'TIMEOUT 90']] + result = fix_modulesArgs(['/mod.so'], 'workers 0 nogc FORK_GC_CLEAN_NUMERIC_EMPTY_NODES timeout 90', defaults) + self.assertEqual(result, [['workers 0 nogc FORK_GC_CLEAN_NUMERIC_EMPTY_NODES timeout 90']]) - # 9. Case-insensitive matching between explicit args and defaults (both directions) - def test_case_insensitive_override(self): - # Uppercase explicit overrides lowercase defaults - defaults = [['workers 8', 'timeout 60', 'EXTRA 1', 'MIxEd 7', 'lower true']] - result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80 miXed 0 LOWER false', defaults) - result_dict = {arg.split(' ')[0]: arg for arg in result[0]} - self.assertEqual(result_dict['WORKERS'], 'WORKERS 4') - self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80') - self.assertEqual(result_dict['EXTRA'], 'EXTRA 1') - self.assertEqual(result_dict['miXed'], 'miXed 0') - self.assertEqual(result_dict['LOWER'], 'LOWER false') - self.assertNotIn('workers', result_dict) - self.assertNotIn('timeout', result_dict) - self.assertNotIn('MIxEd', result_dict) - self.assertNotIn('lower', result_dict) + # 10. Plain string with defaults - unknown keys not in defaults stay, missing defaults appended + def test_plain_string_new_keys_with_defaults(self): + defaults = [['TIMEOUT 60']] + result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults) + self.assertEqual(result, [['WORKERS 4 TIMEOUT 80']]) + + # 11. Case-insensitive word matching for plain string merge + def test_case_insensitive_word_merge(self): + defaults = [['workers 8', 'TIMEOUT 60', 'EXTRA 1']] + result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 timeout 80', defaults) + self.assertEqual(result, [['WORKERS 4 timeout 80 EXTRA 1']]) + + # 12. Substring key should not falsely match (GC should not match nogc) + def test_no_substring_match(self): + defaults = [['GC enabled']] + result = fix_modulesArgs(['/mod.so'], 'nogc TIMEOUT 80', defaults) + self.assertEqual(result, [['nogc TIMEOUT 80 GC enabled']])