Kedro – 如何直接将嵌套参数传递给节点

kedro 建议将参数存储在 conf/base/parameters.yml 中。假设它的内容如下:

step_size: 1model_params:    learning_rate: 0.01    test_data_ratio: 0.2    num_train_steps: 10000

现在假设我有一个 data_engineering 管道,其 nodes.py 中有一个函数,大致如下:

def some_pipeline_step(num_train_steps):    """    接受 `num_train_steps` 参数作为参数。    """    pass

我想知道如何在 data_engineering/pipeline.py 中直接将这些嵌套参数传递给这个函数?我尝试过以下方法但没有成功:

from kedro.pipeline import Pipeline, nodefrom .nodes import split_datadef create_pipeline(**kwargs):    return Pipeline(        [            node(                some_pipeline_step,                ["params:model_params.num_train_steps"],                dict(                    train_x="train_x",                    train_y="train_y",                ),            )        ]    )

我知道可以通过使用 ['parameters'] 将所有参数传递给函数,或者通过 ['params:model_params'] 传递所有 model_params 参数,但这看起来不够优雅,我觉得应该有更好的方法。欢迎任何建议!


回答:

(免责声明:我是 Kedro 团队的一员)

感谢你的提问。遗憾的是,Kedro 的当前版本不支持嵌套参数。临时的解决方案是如你所指出的那样在节点内使用顶级键,或者使用某种参数过滤器来装饰你的节点函数,但这同样不够优雅。

可能最可行的解决方案是通过重写 _get_feed_dict 方法来自定义你的 ProjectContext 类(在 src/<package_name>/run.py 中),如下所示:

class ProjectContext(KedroContext):    # ...    def _get_feed_dict(self) -> Dict[str, Any]:        """获取参数并返回 feed 字典。"""        params = self.params        feed_dict = {"parameters": params}        def _add_param_to_feed_dict(param_name, param_value):            """此函数递归地将参数路径添加到 `feed_dict` 中,            当 `param_value` 本身是一个字典时,以便用户能够            在节点输入中指定特定的嵌套参数。            示例:                >>> param_name = "a"                >>> param_value = {"b": 1}                >>> _add_param_to_feed_dict(param_name, param_value)                >>> assert feed_dict["params:a"] == {"b": 1}                >>> assert feed_dict["params:a.b"] == 1            """            key = "params:{}".format(param_name)            feed_dict[key] = param_value            if isinstance(param_value, dict):                for key, val in param_value.items():                    _add_param_to_feed_dict("{}.{}".format(param_name, key), val)        for param_name, param_value in params.items():            _add_param_to_feed_dict(param_name, param_value)        return feed_dict

请注意,这个问题已经在 开发分支上得到解决,并将在下一个版本中可用。修复使用了上述代码片段中的方法。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注