Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paramtools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
"select_ne",
"read_json",
"get_example_paths",
"get_defaults",
"LeafGetter",
"get_leaves",
"ravel",
Expand Down
1 change: 1 addition & 0 deletions paramtools/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class InconsistentLabelsException(ParamToolsError):
"to_dict",
"_parse_validation_messages",
"sel",
"get_defaults",
]


Expand Down
14 changes: 13 additions & 1 deletion paramtools/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(
sort_values: bool = True,
**ops,
):
schemafactory = SchemaFactory(self.defaults)
schemafactory = SchemaFactory(self.get_defaults())
(
self._defaults_schema,
self._validator_schema,
Expand Down Expand Up @@ -1403,3 +1403,15 @@ def keyfunc(vo, label, label_values):
)[param]
setattr(self, param, sorted_values)
return data

def get_defaults(self):
"""
Hook for implementing custom behavior for getting the default parameters.


**Returns**

- `params`: String if URL or file path. Dict if this is the loaded params
dict.
"""
return utils.read_json(self.defaults)
2 changes: 0 additions & 2 deletions paramtools/schema_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
get_param_schema,
ParamToolsSchema,
)
from paramtools import utils


class SchemaFactory:
Expand All @@ -26,7 +25,6 @@ class SchemaFactory:
"""

def __init__(self, defaults):
defaults = utils.read_json(defaults)
self.defaults = {k: v for k, v in defaults.items() if k != "schema"}
self.schema = ParamToolsSchema().load(defaults.get("schema", {}))
(self.BaseParamSchema, self.label_validators) = get_param_schema(
Expand Down
30 changes: 30 additions & 0 deletions paramtools/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,36 @@ class Params(Parameters):
assert params.hello_world == "hello world"
assert params.label_grid == {}

def test_get_defaults_override(self):
class Params(Parameters):
array_first = True
defaults = {
"schema": {
"labels": {
"somelabel": {
"type": "int",
"validators": {"range": {"min": 0, "max": 2}},
}
}
},
"hello_world": {
"title": "Hello, World!",
"description": "Simplest config possible.",
"type": "str",
"value": "hello world",
},
}

def get_defaults(self):
label = self.defaults["schema"]["labels"]["somelabel"]
label["validators"]["range"]["max"] = 5
return self.defaults

params = Params()
assert params.hello_world == "hello world"
assert params.label_grid == {"somelabel": [0, 1, 2, 3, 4, 5]}


def test_schema_not_dropped(self, defaults_spec_path):
with open(defaults_spec_path, "r") as f:
defaults_ = json.loads(f.read())
Expand Down