diff --git a/fire/core.py b/fire/core.py index 8ca142c7..6fd1bf7a 100644 --- a/fire/core.py +++ b/fire/core.py @@ -78,7 +78,7 @@ def main(argv): import asyncio # pylint: disable=import-error,g-import-not-at-top # pytype: disable=import-error -def Fire(component=None, command=None, name=None): +def Fire(component=None, command=None, name=None, serialize=None): """This function, Fire, is the main entrypoint for Python Fire. Executes a command either from the `command` argument or from sys.argv by @@ -164,7 +164,7 @@ def Fire(component=None, command=None, name=None): raise FireExit(0, component_trace) # The command succeeded normally; print the result. - _PrintResult(component_trace, verbose=component_trace.verbose) + _PrintResult(component_trace, verbose=component_trace.verbose, serialize=serialize) result = component_trace.GetResult() return result @@ -241,12 +241,19 @@ def _IsHelpShortcut(component_trace, remaining_args): return show_help -def _PrintResult(component_trace, verbose=False): +def _PrintResult(component_trace, verbose=False, serialize=None): """Prints the result of the Fire call to stdout in a human readable way.""" # TODO(dbieber): Design human readable deserializable serialization method # and move serialization to its own module. result = component_trace.GetResult() + # Allow users to modify the return value of the component and provide + # custom formatting. + if serialize: + if not callable(serialize): + raise FireError("serialize argument {} must be empty or callable.".format(serialize)) + result = serialize(result) + if value_types.HasCustomStr(result): # If the object has a custom __str__ method, rather than one inherited from # object, then we use that to serialize the object. diff --git a/fire/core_test.py b/fire/core_test.py index 27c9f418..a0576ee9 100644 --- a/fire/core_test.py +++ b/fire/core_test.py @@ -194,6 +194,30 @@ def testClassMethod(self): 7, ) + def testCustomSerialize(self): + def serialize(x): + if isinstance(x, list): + return ', '.join(str(xi) for xi in x) + if isinstance(x, dict): + return ', '.join('{}={!r}'.format(k, v) for k, v in x.items()) + if x == 'special': + return ['SURPRISE!!', "I'm a list!"] + return x + + ident = lambda x: x + + with self.assertOutputMatches(stdout='a, b', stderr=None): + result = core.Fire(ident, command=['[a,b]'], serialize=serialize) + with self.assertOutputMatches(stdout='a=5, b=6', stderr=None): + result = core.Fire(ident, command=['{a:5,b:6}'], serialize=serialize) + with self.assertOutputMatches(stdout='asdf', stderr=None): + result = core.Fire(ident, command=['asdf'], serialize=serialize) + with self.assertOutputMatches(stdout="SURPRISE!!\nI'm a list!\n", stderr=None): + result = core.Fire(ident, command=['special'], serialize=serialize) + with self.assertRaises(core.FireError): + core.Fire(ident, command=['asdf'], serialize=55) + + @testutils.skipIf(six.PY2, 'lru_cache is Python 3 only.') def testLruCacheDecoratorBoundArg(self): self.assertEqual(