编辑:我在底部添加了自己的熵/信息增益/最佳分割方法,以防有人想帮助我调试,这样他们就不必自己编写这些方法了!
作为一项挑战,在观看了这个视频后,我想在Python中创建一个解决方案(基于类以练习面向对象编程)。由于这里的数据是分类数据,我想创建一棵树,使其视觉上与视频中的表示相似(即可能有>2个子节点,分割和值上的标签等)。
我已经尽力调试代码以使其运行。在当前状态下,它有两个大问题:
(1)实例化类的子节点列表显示了一些None值和一些Tree对象。由于None值破坏了.predict函数,我临时添加了一个if语句来忽略它们
(2)值属性显示”Yes”和”No”,这些是目标值而不是特征值。例如:
In: dt.tree_.children[0].valueOut: "Yes"
特征分割完全不在标签属性中。无论是节点还是子节点,标签字段都是None,预测也返回None。我已经花了几个小时检查代码,但找不出原因。我已经包含了数据(从视频中复制)和DataFrame设置,以便于访问,并展示我如何尝试运行程序。
提前为这篇很长的帖子道歉!我在试图解释我的逻辑(即使这可能会阻止大多数人帮助我),我不想让人们认为我只是在请别人帮我写代码!
注意:决策树类使用嵌套函数,以便.fit总是以一个新的Tree()实例开始,因此.predict将使用在.fit之后应该被填充的dt.tree_属性
Tree()类:
class Tree(): def __init__(self, children = [], label = None, value = None): self.children = children #用于替代二进制解决方案中的左右节点 self.label = label #标记这个节点的子节点是基于哪个特征进行分割的 self.value = value #上一个节点分割特征的值。对于头节点,这应该始终为None
dt.fit的伪代码:
def fit(self, data, target, features) def run_id3(data, target, features, tree): (基本情况) 检查目标列是否只有一个唯一值。 如果是,将当前树的标签设置为目标列,在当前树下添加一个子节点,带有目标值 返回(结束递归) 找到最佳特征来分割数据 将当前节点的标签设置为该特征 对于分割特征中的每个唯一值: 创建一个节点并将值设置为唯一值 将新节点添加到当前树的子节点列表中 使用当前唯一特征值(分割)过滤数据,并以子树为头进行递归 run_id3(data, target, features, self.tree_)
dt.fit的代码:
class DecisionTree(): tree_: Tree def __init__(self): self.tree_ = Tree() pass def fit(self, data, target, features): def run_id3(data, target, features, tree): unique_targets = pd.unique(data[target]) if len(unique_targets) == 1: tree.label = target tree.children.append(Tree(value=unique_targets[0])) return best_split = find_best(data, target, features) tree.label = best_split for unique_val in np.unique(data[best_split]): new_tree = Tree() new_tree.value = unique_val tree.children.append(run_id3(data[data[best_split] == unique_val], target, features, new_tree)) run_id3(data, target, features, self.tree_)
dt.predict的伪代码:
def predict(self, row): def get_prediction(tree, row): 检查当前节点是否没有子节点 返回节点标签(应该是目标预测) 将当前列(特征分割)设置为当前节点标签 对于当前节点的每个子节点 如果子节点不为null(这不好,存在是为了阻止程序停止) 如果子节点的值等于我们在测试行中该列的值 递归(向下走树),将当前子树设置为参数中的头 tree = self.tree_(所以树从实例化树的头部开始,应该在dt.fit之后被填充) 返回 get_prediction(tree, row)
dt.predict的代码:
def predict(self, row): def get_prediction(tree, row): if len(tree.children) == 0: return tree.label column = tree.label for child in tree.children:# 下面的条件完全是为了阻止程序停止,因为我还没弄清楚为什么子节点属性不断添加NoneType对象 if child is not None: if child.value == row[column]: return get_prediction(child, row) tree = self.tree_ return get_prediction(tree, row)
数据设置:
outlook = ['Sunny', 'Sunny', 'Overcast', 'Rain', 'Rain', 'Rain', 'Overcast', 'Sunny', 'Sunny', 'Rain', 'Sunny', 'Overcast', 'Overcast', 'Rain', 'Rain']humidity = ['High', 'High', 'High', 'High', 'Normal', 'Normal', 'Normal', 'High', 'Normal', 'Normal', 'Normal', 'High', 'Normal', 'High', 'High']wind = ['Weak', 'Strong', 'Weak', 'Weak', 'Weak', 'Strong', 'Strong', 'Weak', 'Weak', 'Weak', 'Strong', 'Strong', 'Weak', 'Strong', 'Weak']play = ['No', 'No', 'Yes', 'Yes', 'Yes', 'No', 'Yes', 'No', 'Yes', 'Yes', 'Yes', 'Yes', 'Yes', 'No', '?']columns = ["Outlook", "Humidity", "Wind", "Play"]data = pd.DataFrame([outlook, humidity, wind, play]).Tdata.columns = columnstrain = data.iloc[:-1, :]test = data.iloc[-1, :3]features = columns.copy()features.remove("Play")target = "Play"dt = DecisionTree()dt.fit(train, target, features)pred = dt.predict(test)
信息增益方法:
import numpy as npimport pandas as pddef entropy(column): elements, counts = np.unique(column, return_counts=True) # 如果语句在理解中停止nan结果,因为0*log2(x)未定义,返回0.在这种情况下, # 1*log2(1) + 0*log2(0) = 0.零熵结果,零不确定性与理论一致 entropy = np.sum( [-(counts[i] / np.sum(counts)) * np.log2(counts[i] / np.sum(counts)) if counts[i] > 0 else 0 for i in range(len(counts))]) return entropydef information_gain(data, split_name, target_name): target_entropy = entropy(data[target_name]) vals, counts = np.unique(data[split_name], return_counts=True) weighted_entropy = np.sum( [counts[i] / np.sum(counts) * entropy(data.loc[data[split_name] == vals[i], target_name]) for i in range(len(counts))]) return target_entropy - weighted_entropydef find_best(data, target_name, features): max_gain = 0 best_col = "" for col in features: gain = information_gain(data, col, target_name) if gain > max_gain: max_gain = gain best_col = col return best_col
回答: