使用递归的机器学习算法

我目前正在开发ID3机器学习算法的一个非常初级的版本。我在如何递归调用我的build_tree函数来实际构建决策树的其余部分并以一种好的格式输出它上遇到了困难。我已经计算了增益、熵、增益比等,但我不知道如何将递归整合到我的函数中。

我有一个数据集,在完成上述所有计算后,我已经将其分成了两个数据集。现在我需要能够递归调用它,直到左右两个数据集都变得纯净(这可以通过我编写的名为dataset.is_pure()的函数轻松检查),同时跟踪每个节点的阈值。我知道我的所有计算和分割方法都是有效的,因为我已经对它们进行了单独的测试。只是递归部分让我感到困惑。

这是我正在与递归噩梦作斗争的build_tree函数。我目前在Linux环境下使用g++编译器工作。我现在的代码可以编译,但运行时会给我一个分段错误。任何和所有的帮助将不胜感激!

   struct node    {            vector<vector<string>> data;            double atrb;            node* parent;            node* left = NULL;            node* right = NULL;            node(node* parent) : parent(parent) {}    };    node* root = new node(NULL);    void build_tree(node* current, dataset data_set)    {            vector<vector<string>> l_d;            vector<vector<string>> r_d;            double global_entropy = calc_entropy(data_set.get_col(data_set.n_col()-1));            int best_col = this->get_best_col(data_set, global_entropy);            hash_map selected_atrb(data_set.n_row(), data_set.truncate(best_col));            double threshold = get_threshold(selected_atrb, global_entropy);            cout << threshold << "\n";            split_data(threshold, best_col, data_set, l_d, r_d);            dataset right_data(r_d);            dataset left_data(l_d);            right_data.delete_col(best_col);            left_data.delete_col(best_col);            if(left_data.is_pure())                    return;            else            {                    node* new_left = new node(current);                    new_left->atrb = threshold;                    current->left = new_left;                    new_left->data = l_d;                    return build_tree(new_left, left_data);            }            if(right_data.is_pure())                    return;            else            {                    node* new_right = new node(current);                    new_right->atrb = threshold;                    current->right = new_right;                    new_right->data = r_d;                    return build_tree(new_right, right_data);            }    }    id3(dataset data)    {            build_tree(root, data);    }

};

这只是我类的一部分。如果您想查看其他任何代码,请告诉我!


回答:

此致,

我将用伪代码向您解释递归函数的工作原理,我还将为您提供在JavaScript中实现该算法的代码。

在深入细节之前,我将提到您使用的某些概念和类别。

  • 属性: 数据集的特征,通常是数据集列的名称。
  • 类别: 决策特征,通常是二进制值,通常总是数据集的最后一列。
  • 值: 属性在数据集中的可能值,例如(晴天、多云、雨天)
  • 树: 具有彼此关联的多个节点的类别。
  • 节点: 负责存储属性(问题)的实体,也有一系列与之相关的弧线。

  • 弧线: 包含属性的值,并有一个属性将包含下一个子节点。

  • 叶子: 包含一个类别。这个节点是决策的结果,例如(是或否)。

  • 最佳特征: 信息增益最高的属性。

从数据集创建树的函数:

  • 获取类别的值。
  • 评估数据集中是否只有一种类型的类别,例如(是)。
  • 如果为真,则创建一个叶子对象并返回该对象。
  • 获取当前每个属性的信息增益。
  • 选择信息增益最高的属性。
  • 使用最佳特征创建一个节点。
  • 获取最佳特征的值。
  • 迭代这些值的列表。

    • 过滤列表,以便仅包含我们正在迭代的值的记录(将其保存到临时变量中)。

    • 创建一个具有此值的弧线。- 为弧线分配下一个属性:(这里是递归)再次调用相同的函数,发送(过滤后的记录列表、类别、没有最佳特征的属性列表、没有最佳特征的属性的一般列表)。

    • 将弧线添加到节点中。
  • 返回节点。

这将是负责创建树的代码段

let crearArbol = (ejemplosLista, clase, atributos, valores) => {        let valoresClase = obtenerValoresAtributo(ejemplosLista, clase);        if (valoresClase.length == 1) {            autoIncremental++;            return new Hoja(valoresClase[0], autoIncremental);        }        if (atributos.length == 0) {            let claseDominante = claseMayoritaria(ejemplosLista);            return new Atributo();        }        let gananciaAtributos = obtenerGananciaAtributos(ejemplosLista, valores, atributos);        let atributoMaximo = atributos[maximaGanancia(gananciaAtributos)];        autoIncremental++;        let nodo = new Atributo(atributoMaximo, [], autoIncremental);        let valoresLista = obtenerValoresAtributo(ejemplosLista, atributoMaximo);        valoresLista.forEach((valor) => {            let ejemplosFiltrados = arrayDistincAtributos(ejemplosLista, atributoMaximo, valor);            let arco = new Arco(valor);            arco.sigNodo = crearArbol(ejemplosFiltrados, clase, [...eliminarAtributo(atributoMaximo, atributos)], [...eliminarValores(atributoMaximo, valores)]);            nodo.hijos.push(arco);        });        return nodo;    };

不幸的是,代码仅提供西班牙语。这是包含我项目实现的存储库 ID3算法源代码

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注