diff --git a/neat/common.py b/neat/common.py index 4ac3112..76b1261 100644 --- a/neat/common.py +++ b/neat/common.py @@ -20,6 +20,7 @@ from neat.contracts_extra import * import os import time +import json from neat.config import * from neat.db_utils import * @@ -208,4 +209,16 @@ def call_function_by_name(name, args): return getattr(m, function)(*args) +@contract +def parse_parameters(params): + """ Parse algorithm parameters from the config file. + :param params: JSON encoded parameters. + :type params: str + + :return: A dict of parameters. + :rtype: dict(str: *) + """ + return dict((str(k), v) + for k, v in json.loads(params).items()) + diff --git a/tests/test_common.py b/tests/test_common.py index 8db7a21..0a1cfb2 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -138,3 +138,8 @@ class Common(TestCase): expect(common).func_to_call(arg1, arg2).and_return('res').once() assert common.call_function_by_name('neat.common.func_to_call', [arg1, arg2]) == 'res' + + def test_parse_parameters(self): + params = '{"param1": 0.56, "param2": "abc"}' + self.assertEqual(common.parse_parameters(params), {'param1': 0.56, + 'param2': 'abc'})