决策树是一种用于分类和回归的监督学习方法。决策树目标是创建一个模型,通过学习从数据特征推断出的简单决策规则来预测目标变量的值。
决策树优缺点
决策树的一些优点是:
- 易于理解和解释。树可以被可视化。
- 需要很少的训练数据
- 能处理数值和类别数据
- 能够处理多输出问题
决策树的一些缺点是:
- 深度太深,很容易过拟合
- 决策树可能不稳定
- 决策树的预测结果不是连续的
- 决策树节点分裂过程是贪心的
sklearn 决策树API
DecisionTreeClassifier
https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.htm
二分类或多分类、多标签分类
fromsklearn.datasetsimportload_irisfromsklearn.model_selectionimportcross_val_scorefromsklearn.treeimportDecisionTreeClassifier clf = DecisionTreeClassifier(random_state=0) iris = load_iris() cross_val_score(clf, iris.data, iris.target, cv=10)
DecisionTreeRegressor
https://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeRegressor.html
回归、多标签回归
fromsklearn.datasetsimportload_diabetesfromsklearn.model_selectionimportcross_val_scorefromsklearn.treeimportDecisionTreeRegressor X, y = load_diabetes(return_X_y=True) regressor = DecisionTreeRegressor(random_state=0) cross_val_score(regressor, X, y, cv=10)
sklearn 底层树结构
树结构
决策分类器有一个名为的属性tree_,它允许访问低级属性,例如node_count(节点总数),和max_depth(树的最大深度),它还存储整个二叉树结构,表示为多个并行数组。
- children_left[i]: 节点的左子节点的 idi,如果是叶节点则为 -1
- children_right[i]: 节点右子节点的idi,如果是叶节点则为-1
- feature[i]: 用于分裂节点的特征i
- threshold[i]:节点的阈值i
- n_node_samples[i]:到达节点的训练样本数i
- impurity[i]:节点处的杂质i
iris = load_iris() X = iris.data y = iris.target X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
n_nodes = clf.tree_.node_count
children_left = clf.tree_.children_left
children_right = clf.tree_.children_right
feature = clf.tree_.feature
threshold = clf.tree_.threshold
决策路径
decision_path方法输出一个指示矩阵,允许检索感兴趣的样本遍历的节点。位置处的指示矩阵中的非零元素表示样本经过节点。
apply方法返回样本所达到的叶id,得到样本所达到的叶子节点 ID 的数组,这可以用对样本进行编码,也可以用于特征工程。
iris = load_iris() X = iris.data y = iris.target X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
clf.fit(X_train, y_train)
node_indicator = clf.decision_path(X_test)
leaf_id = clf.apply(X_test)
参考资料
- https://scikit-learn.org/stable/modules/tree.html
- https://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html
- https://scikit-learn.org/stable/auto_examples/tree/