我在Matlab中编写了一个决策树分类器。据我所知,一切应该都能正常工作,逻辑上也说得通。但当我尝试调用fit方法时,它在我的一个函数上崩溃了,提示我输入参数不对,但我确定我已经正确输入了!这个问题以及类似关于函数和输入参数的问题已经困扰了我一两天了。我怀疑这可能与在构造函数中调用这些函数有关,但即使在主脚本中调用它们也仍然不起作用。请帮帮我!
classdef my_ClassificationTree < handle properties X % 训练样本 Y % 训练标签 MinParentSize % 最小父节点大小 MaxNumSplits % 最大分割数 Verbose % 是否在过程中输出调试信息 % MinLeafSize CutPoint CutPredictorIndex Children numSplits root end methods % 构造函数:实现拟合阶段 function obj = my_ClassificationTree(X, Y, MinParentSize, MaxNumSplits, Verbose) obj.X = X; obj.Y = Y; obj.MinParentSize = MinParentSize; obj.MaxNumSplits = MaxNumSplits; obj.Verbose = Verbose;% obj.Children = zeros(1, 2);% obj.CutPoint = 0;% obj.CutPredictorIndex = 0; % obj.MinLeafSize = MinLeafSize; obj.numSplits = 0; obj.root = Node(1, size(obj.X,1)); root = Node(1, size(obj.X,1)); fit(obj,root); end function node = Node(sIndex,eIndex) node.startIndex = sIndex; node.endIndex = eIndex; node.leaf = false; node.Children = 0; node.size = eIndex - sIndex + 1; node.CutPoint = 0; node.CutPredictorIndex = 0; node.NodeClass = 0; end function fit(obj,node) if node.size < obj.MinParentSize || obj.numSplits >= obj.MaxNumSplits % 将节点标记为叶节点 node.Leaf = true; % 计算该节点处样本的主要类别标签 labels = obj.Y(node.startIndex:node.endIndex); % 收集节点范围内数据的所有标签 node.NodeClass = mode(labels); % 找到最频繁的标签并将节点分类为此类 return; end bestCutPoint = findBestCutPoint(node, obj.X, obj.Y); leftChild = Node(node.startIndex, bestCutPoint.CutIndex - 1); rightChild = Node(bestSplit.splitIndex, node.endIndex); obj.numSplits = obj.numSplits + 1; node.CutPoint = bestSplit.CutPoint; node.CutPredictorIndex = bestSplit.CutPredictorIndex; % 将子节点附加到父节点 node.Children = [leftChild, rightChild]; % 递归构建左子节点和右子节点的树 fit(obj, leftChild); fit(obj, rightChild); end function bestCutPoint = findBestCutPoint(node, X, labels) bestCutPoint.CutPoint = 0; bestCutPoint.CutPredictorIndex = 0; bestCutPoint.CutIndex = 0; bestGDI = Inf; % 将最佳GDI初始化为一个大值 % 循环遍历所有特征 for i = 1:size(X, 2) % 循环遍历特征的所有唯一值 values = unique(X(node.startIndex:node.endIndex, i)); for j = 1:length(values) % 计算两个结果的加权不纯度 % 切割 leftLabels = labels(node.startIndex:node.endIndex, 1); rightLabels = labels(node.startIndex:node.endIndex, 1); leftLabels = leftLabels(X(node.startIndex:node.endIndex, i) < values(j)); rightLabels = rightLabels(X(node.startIndex:node.endIndex, i) >= values(j)); leftGDI = weightedGDI(leftLabels, labels); rightGDI = weightedGDI(rightLabels, labels); % 计算分割的加权不纯度 cutGDI = leftGDI + rightGDI; % 如果当前分割的GDI更低,则更新最佳分割 if cutGDI < bestGDI bestGDI = cutGDI; bestCutPoint.CutPoint = values(j); bestCutPoint.CutPredictorIndex = i; bestCutPoint.CutIndex = find(X(:, i) == values(j), 1, 'first'); end end end end% 预测阶段: function predictions = predict(obj, test_examples) % 准备存储我们的预测类别标签: predictions = categorical; % 遍历X中的每个样本 for i = 1:size(test_examples, 1) % 将当前节点设置为根节点 currentNode = obj.root; % 当当前节点不是叶节点时 while ~currentNode.leaf % 检查当前节点的CutPredictorIndex属性指定的预测特征值 value = test_examples(i, currentNode.CutPredictorIndex); % 如果值小于当前节点的CutPoint,则将当前节点设置为当前节点的左子节点 if value < currentNode.CutPoint currentNode = currentNode.Children(1); % 如果值大于或等于当前节点的CutPoint,则将当前节点设置为当前节点的右子节点 else currentNode = currentNode.Children(2); end end % 一旦当前节点成为叶节点,将当前节点的NodeClass添加到predictions向量中 predictions(i) = currentNode.NodeClass; end end % 在下面的行中添加你想要的其他方法... end end
这是调用myClassificationTree的函数
function m = my_fitctree(train_examples, train_labels, varargin) % 接受一个额外的名称-值对参数,允许我们开启调试: p = inputParser; addParameter(p, 'Verbose', false); %addParameter(p, 'MinLeafSize', false); % 接受一个额外的名称-值对参数,允许我们设置最小 % 父节点大小(默认为10): addParameter(p, 'MinParentSize', 10); % 接受一个额外的名称-值对参数,允许我们设置最大 % 分割数(默认为训练样本数-1): addParameter(p, 'MaxNumSplits', size(train_examples,1) - 1); p.parse(varargin{:}); % 使用提供的参数创建一个新的my_ClassificationTree % 对象: m = my_ClassificationTree(train_examples, train_labels, ... p.Results.MinParentSize, p.Results.MaxNumSplits, p.Results.Verbose); end
这是我在主代码块中的代码
mym2_dt = my_fitctree(train_examples, train_labels, 'MinParentSize', 10)
这些是错误 这些是错误
我期望它能构建一个决策树并填充它。然而,它在findBestCutPoint函数上崩溃了,我无法修复它
回答:
类方法的第一个参数(构造函数除外)应该是类的实例(即obj
)。您对Node
和findBestCutPoint
的定义应该将obj
作为第一个参数。
此外,从其他方法内部调用类方法时,应使用obj.theMethod
的语法,这在您的代码中似乎并非如此。
因此,例如,对Node
的调用应该是:
obj.root = obj.Node(1, size(obj.X,1));
而Node
应定义如下:
function node = Node(obj,sIndex,eIndex)
对findBestCutPoint
同样适用。请注意,在调用时,类实例的引用是隐式传递的,因此您不需要在调用中实际包含它。