Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 46 additions & 34 deletions RLTest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,47 +98,66 @@ def dicty(args):
def join_lists(lists):
    """Flatten an iterable of iterables by one level into a new list."""
    return [item for sublist in lists for item in sublist]

def _merge_by_words(explicit_str, defaultArgs):
"""Merge a plain explicit arg string with defaults using word-level key matching.
For each default arg, if its key doesn't appear as a word in the explicit string,
append the entire default arg to the string.
Returns the merged string wrapped as [[merged_string]].
"""
if not defaultArgs or not defaultArgs[0]:
return [[explicit_str]]
explicit_words_upper = [w.upper() for w in explicit_str.split()]
merged = explicit_str
for arg in defaultArgs[0]:
key = arg.split()[0].upper()
if key not in explicit_words_upper:
merged += ' ' + arg
return [[merged.strip()]]

def _merge_by_dict(modulesArgs, defaultArgs):
    """Fill in missing defaults for structured (already-split) module args.

    The explicit args are indexed per module through ``args_list_to_dict``.
    For each module position in ``defaultArgs``, every default whose
    upper-cased leading word is not found in that module's index is appended
    to the module's explicit arg list. ``modulesArgs`` is modified in place
    and also returned.
    """
    explicit_keys = args_list_to_dict(modulesArgs)
    for module_index, module_defaults in enumerate(defaultArgs):
        for default_arg in module_defaults:
            # The key is everything before the first single space.
            key = default_arg.partition(' ')[0].upper()
            if key in explicit_keys[module_index]:
                continue
            modulesArgs[module_index].append(default_arg)
    return modulesArgs

def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True):
# modulesArgs is one of the following:
# None
# 'args ...': arg string for a single module
# ['args ...', ...]: arg list for a single module
# [['arg', ...], ...]: arg strings for multiple modules

# arg string is a string of words separated by whitespace.
# arg string can be separated by semicolons into (logical) arg lists.
# semicolons can be escaped with a backslash.
# if no semicolons are present, the string is treated as space-separated key-value pairs,
# where each consecutive pair of words forms a 'KEY VALUE' arg.
# thus, 'K1 V1 K2 V2' becomes ['K1 V1', 'K2 V2']
# an odd number of words without semicolons is an error.
# for args with multiple values, semicolons are required:
# thus, 'K1 V1; K2 V2 V3' becomes ['K1 V1', 'K2 V2 V3']
# arg list is a list of arg strings.
# arg list starts with an arg name that can later be used for argument overriding.
# For a plain string without semicolons:
# If defaultArgs exist, merge by checking if each default key appears as
# a word in the explicit string. Missing defaults are appended.
# If no defaultArgs, keep the string as-is (no splitting needed).
# For strings with semicolons, split by semicolons and use dict-based merge.
# For list inputs, use dict-based merge.

is_plain_str = False # tracks if input was a plain string without semicolons

if type(modulesArgs) == str:
# case # 'args ...': arg string for a single module
# transformed into [['arg', ...]]
parts = split_by_semicolon(modulesArgs)
if len(parts) == 1:
# No semicolons found - treat as space-separated key-value pairs
words = parts[0].split()
if len(words) % 2 != 0:
print(Colors.Bred(f"Error in args: odd number of words in key-value pairs: '{modulesArgs}'. "
f"Use semicolons to separate args with multiple values (e.g. 'KEY1 V1; KEY2 V2 V3')."))
sys.exit(1)
if len(words) > 2:
parts = [f"{words[i]} {words[i + 1]}" for i in range(0, len(words), 2)]
modulesArgs = [parts]
# No semicolons - keep as plain string
is_plain_str = True
modulesArgs = [[modulesArgs.strip()]]
else:
# Has semicolons - already split
modulesArgs = [parts]
elif type(modulesArgs) == list:
args = []
is_list = False
is_str = False
for argx in modulesArgs:
if type(argx) == list:
# case [['arg', ...], ...]: arg strings for multiple modules
# already transformed into [['arg', ...], ...]
if is_str:
print(Colors.Bred('Error in args: %s' % str(modulesArgs)))
sys.exit(1)
Expand All @@ -150,7 +169,6 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True):
args += [argx]
else:
# case ['args ...', ...]: arg list for a single module
# transformed into [['arg', ...], ...]
if is_list:
print(Colors.Bred('Error in args: %s' % str(modulesArgs)))
sys.exit(1)
Expand Down Expand Up @@ -183,19 +201,13 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True):
return modulesArgs

# if there are fewer defaultArgs than modulesArgs, we should bail out
# as we cannot pad the defaults with empty arg lists
if defaultArgs and len(modulesArgs) > len(defaultArgs):
print(Colors.Bred('Number of module args sets in Env does not match number of modules'))
print(defaultArgs)
print(modulesArgs)
sys.exit(1)

# for each module, sync defaultArgs to modulesARgs
modules_args_dict = args_list_to_dict(modulesArgs)
for imod, args_list in enumerate(defaultArgs):
for arg in args_list:
name = arg.split(' ')[0].upper()
if name not in modules_args_dict[imod]:
modulesArgs[imod] += [arg]

return modulesArgs
if is_plain_str:
return _merge_by_words(modulesArgs[0][0], defaultArgs)
else:
return _merge_by_dict(modulesArgs, defaultArgs)
80 changes: 40 additions & 40 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,39 @@

class TestFixModulesArgs(TestCase):

# 1. Single key-value pair string
# 1. Single key-value pair string, no defaults - kept as single string
def test_single_key_value_pair(self):
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4')
self.assertEqual(result, [['WORKERS 4']])

# 2. Multiple key-value pairs without semicolons (new behavior)
def test_multiple_kv_pairs_no_semicolons(self):
# 2. Multiple key-value pairs without semicolons, no defaults - kept as single string
def test_multiple_kv_pairs_no_semicolons_no_defaults(self):
result = fix_modulesArgs(['/mod.so'], '_FREE_RESOURCE_ON_THREAD FALSE TIMEOUT 80 WORKERS 4')
self.assertEqual(result, [['_FREE_RESOURCE_ON_THREAD FALSE', 'TIMEOUT 80', 'WORKERS 4']])
self.assertEqual(result, [['_FREE_RESOURCE_ON_THREAD FALSE TIMEOUT 80 WORKERS 4']])

# 3. Semicolon-separated args (existing behavior)
def test_semicolon_separated_args(self):
result = fix_modulesArgs(['/mod.so'], 'KEY1 V1; KEY2 V2')
self.assertEqual(result, [['KEY1 V1', 'KEY2 V2']])

# 4a. Odd number of words without semicolons - should error
def test_odd_words_no_semicolons_exits(self):
with self.assertRaises(SystemExit):
fix_modulesArgs(['/mod.so'], 'FLAG TIMEOUT 80')
# 4. Odd number of words without semicolons, no defaults - kept as single string, no error
def test_odd_words_no_semicolons_no_error(self):
result = fix_modulesArgs(['/mod.so'], 'FLAG TIMEOUT 80 ')
self.assertEqual(result, [['FLAG TIMEOUT 80']])

# 4b. Odd number of words with semicolons - valid, semicolons split first
def test_odd_words_with_semicolons_valid(self):
result = fix_modulesArgs(['/mod.so'], 'FLAG; TIMEOUT 80')
self.assertEqual(result, [['FLAG', 'TIMEOUT 80']])

# 5a. Space-separated string overrides matching defaults, non-matching defaults added
def test_space_separated_overrides_defaults(self):
# 5a. Plain string with defaults - word-based merge, missing defaults appended
def test_plain_string_overrides_defaults(self):
defaults = [['WORKERS 8', 'TIMEOUT 60', 'EXTRA 1']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['WORKERS'], 'WORKERS 4')
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['EXTRA'], 'EXTRA 1')
# Result is a single merged string
self.assertEqual(result, [['WORKERS 4 TIMEOUT 80 EXTRA 1']])

# 5b. Semicolon-separated string overrides matching defaults
# 5b. Semicolon-separated string overrides matching defaults (dict-based merge)
def test_semicolon_separated_overrides_defaults(self):
defaults = [['WORKERS 8', 'TIMEOUT 60', 'EXTRA 1']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4; TIMEOUT 80', defaults)
Expand All @@ -48,14 +46,11 @@ def test_semicolon_separated_overrides_defaults(self):
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['EXTRA'], 'EXTRA 1')

# 5c. Space-separated explicit overrides some defaults, non-overlapping defaults are merged
def test_space_separated_partial_override_with_defaults(self):
# 5c. Plain string partial override - missing defaults appended
def test_plain_string_partial_override_with_defaults(self):
defaults = [['_FREE_RESOURCE_ON_THREAD TRUE', 'TIMEOUT 100', 'WORKERS 8']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['WORKERS'], 'WORKERS 4')
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['_FREE_RESOURCE_ON_THREAD'], '_FREE_RESOURCE_ON_THREAD TRUE')
self.assertEqual(result, [['WORKERS 4 TIMEOUT 80 _FREE_RESOURCE_ON_THREAD TRUE']])

# 6. None input with defaults - deep copy of defaults
def test_none_uses_defaults(self):
Expand All @@ -66,7 +61,7 @@ def test_none_uses_defaults(self):
result[0][0] = 'MODIFIED'
self.assertEqual(defaults[0][0], 'WORKERS 8')

# 7. List of strings with defaults - overlapping and non-overlapping keys
# 7. List of strings with defaults - dict-based merge
def test_list_of_strings_with_defaults(self):
defaults = [['K1 default1', 'K2 default2', 'K4 default4']]
result = fix_modulesArgs(['/mod.so'], ['K1 override1', 'K2 override2', 'K3 new3'], defaults)
Expand All @@ -76,35 +71,40 @@ def test_list_of_strings_with_defaults(self):
self.assertEqual(result_dict['K3'], 'K3 new3')
self.assertEqual(result_dict['K4'], 'K4 default4')

# 8. List of lists (multi-module) with defaults - overlapping and non-overlapping keys
# 8. List of lists (multi-module) with defaults - dict-based merge
def test_multi_module_with_defaults(self):
modules = ['/mod1.so', '/mod2.so']
explicit = [['K1 v1', 'K2 v2'], ['K3 v3']]
defaults = [['K1 d1', 'K5 d5'], ['K3 d3', 'K4 d4']]
result = fix_modulesArgs(modules, explicit, defaults)
# Module 1: K1 overridden, K5 added from defaults
dict1 = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(dict1['K1'], 'K1 v1')
self.assertEqual(dict1['K2'], 'K2 v2')
self.assertEqual(dict1['K5'], 'K5 d5')
# Module 2: K3 overridden, K4 added from defaults
dict2 = {arg.split(' ')[0]: arg for arg in result[1]}
self.assertEqual(dict2['K3'], 'K3 v3')
self.assertEqual(dict2['K4'], 'K4 d4')

# 9. Odd words with defaults - word-based merge, flags and multi-value args handled
def test_odd_words_with_defaults(self):
defaults = [['FORK_GC_CLEAN_NUMERIC_EMPTY_NODES', 'TIMEOUT 90']]
result = fix_modulesArgs(['/mod.so'], 'workers 0 nogc FORK_GC_CLEAN_NUMERIC_EMPTY_NODES timeout 90', defaults)
self.assertEqual(result, [['workers 0 nogc FORK_GC_CLEAN_NUMERIC_EMPTY_NODES timeout 90']])

# 9. Case-insensitive matching between explicit args and defaults (both directions)
def test_case_insensitive_override(self):
# Uppercase explicit overrides lowercase defaults
defaults = [['workers 8', 'timeout 60', 'EXTRA 1', 'MIxEd 7', 'lower true']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80 miXed 0 LOWER false', defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['WORKERS'], 'WORKERS 4')
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['EXTRA'], 'EXTRA 1')
self.assertEqual(result_dict['miXed'], 'miXed 0')
self.assertEqual(result_dict['LOWER'], 'LOWER false')
self.assertNotIn('workers', result_dict)
self.assertNotIn('timeout', result_dict)
self.assertNotIn('MIxEd', result_dict)
self.assertNotIn('lower', result_dict)
# 10. Plain string with defaults - unknown keys not in defaults stay, missing defaults appended
def test_plain_string_new_keys_with_defaults(self):
defaults = [['TIMEOUT 60']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults)
self.assertEqual(result, [['WORKERS 4 TIMEOUT 80']])

# 11. Case-insensitive word matching for plain string merge
def test_case_insensitive_word_merge(self):
defaults = [['workers 8', 'TIMEOUT 60', 'EXTRA 1']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 timeout 80', defaults)
self.assertEqual(result, [['WORKERS 4 timeout 80 EXTRA 1']])

# 12. Substring key should not falsely match (GC should not match nogc)
def test_no_substring_match(self):
defaults = [['GC enabled']]
result = fix_modulesArgs(['/mod.so'], 'nogc TIMEOUT 80', defaults)
self.assertEqual(result, [['nogc TIMEOUT 80 GC enabled']])
Loading