kedro
建议将参数存储在 conf/base/parameters.yml
中。假设它的内容如下:
step_size: 1model_params: learning_rate: 0.01 test_data_ratio: 0.2 num_train_steps: 10000
现在假设我有一个 data_engineering
管道,其 nodes.py
中有一个函数,大致如下:
def some_pipeline_step(num_train_steps): """ 接受 `num_train_steps` 参数作为参数。 """ pass
我想知道如何在 data_engineering/pipeline.py
中直接将这些嵌套参数传递给这个函数?我尝试过以下方法但没有成功:
from kedro.pipeline import Pipeline, nodefrom .nodes import split_datadef create_pipeline(**kwargs): return Pipeline( [ node( some_pipeline_step, ["params:model_params.num_train_steps"], dict( train_x="train_x", train_y="train_y", ), ) ] )
我知道可以通过使用 ['parameters']
将所有参数传递给函数,或者通过 ['params:model_params']
传递所有 model_params
参数,但这看起来不够优雅,我觉得应该有更好的方法。欢迎任何建议!
回答:
(免责声明:我是 Kedro 团队的一员)
感谢你的提问。遗憾的是,Kedro 的当前版本不支持嵌套参数。临时的解决方案是如你所指出的那样在节点内使用顶级键,或者使用某种参数过滤器来装饰你的节点函数,但这同样不够优雅。
可能最可行的解决方案是通过重写 _get_feed_dict
方法来自定义你的 ProjectContext
类(在 src/<package_name>/run.py
中),如下所示:
class ProjectContext(KedroContext): # ... def _get_feed_dict(self) -> Dict[str, Any]: """获取参数并返回 feed 字典。""" params = self.params feed_dict = {"parameters": params} def _add_param_to_feed_dict(param_name, param_value): """此函数递归地将参数路径添加到 `feed_dict` 中, 当 `param_value` 本身是一个字典时,以便用户能够 在节点输入中指定特定的嵌套参数。 示例: >>> param_name = "a" >>> param_value = {"b": 1} >>> _add_param_to_feed_dict(param_name, param_value) >>> assert feed_dict["params:a"] == {"b": 1} >>> assert feed_dict["params:a.b"] == 1 """ key = "params:{}".format(param_name) feed_dict[key] = param_value if isinstance(param_value, dict): for key, val in param_value.items(): _add_param_to_feed_dict("{}.{}".format(param_name, key), val) for param_name, param_value in params.items(): _add_param_to_feed_dict(param_name, param_value) return feed_dict
请注意,这个问题已经在 开发分支上得到解决,并将在下一个版本中可用。修复使用了上述代码片段中的方法。