Explore Decision Trees and how they split data into meaningful decision rules. This lesson teaches tree-building, visualization, and practical classification applications.
A decision tree is a flowchart-like structure where each internal node represents a test on a feature, each branch represents an outcome, and each leaf node represents a class label. The path from root to leaf represents classification rules.
Think of playing "20 Questions." You ask yes/no questions to narrow down possibilities until you reach an answer. Decision trees work similarly, asking questions about features to arrive at predictions.
Real-world applications include credit approval (should a loan be granted?), medical diagnosis (does a patient's symptom profile indicate a condition?), customer churn prediction, and fraud detection.
```
                [Root: Is Age > 30?]
                 /               \
               Yes                No
                |                  |
     [Internal: Income > 50K?]   [Leaf: Class B]
         /             \
       Yes              No
        |                |
  [Leaf: Class A]   [Leaf: Class B]
```
Components: the root node (the first, most informative test), internal nodes (feature tests), branches (test outcomes), and leaf nodes (final class predictions).
At each node, the algorithm considers every candidate feature/threshold split, scores each by how much it reduces impurity, greedily keeps the best split, and then recurses on the resulting child nodes.
Decision trees split nodes to reduce "impurity" - the mixing of classes. Two common measures are Gini Impurity and Entropy.
Gini = 1 - Σ(pᵢ)²
Where pᵢ is the probability of class i in the node.
Entropy = -Σ pᵢ × log₂(pᵢ)
# Compare the two impurity measures over the full range of class probabilities.
import numpy as np
import matplotlib.pyplot as plt

# Avoid p = 0 and p = 1 exactly, where log2(p) is undefined.
prob = np.linspace(0.01, 0.99, 100)
# Binary Gini: 1 - p^2 - (1-p)^2, which simplifies to 2p(1-p).
gini_curve = 2 * prob * (1 - prob)
# Binary entropy in bits.
entropy_curve = -prob * np.log2(prob) - (1 - prob) * np.log2(1 - prob)

plt.figure(figsize=(8, 5))
plt.plot(prob, gini_curve, label='Gini Impurity', linewidth=2)
plt.plot(prob, entropy_curve, label='Entropy', linewidth=2)
plt.xlabel('Probability of Class 1')
plt.ylabel('Impurity')
plt.title('Gini vs Entropy for Binary Classification')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
Both measures are highest when classes are equally mixed and lowest when one class dominates.
# --- Train and visualize a depth-limited decision tree on the Iris data ---
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.datasets import load_iris

# Load the Iris dataset (150 samples, 4 features, 3 classes).
iris = load_iris()
X, y = iris.data, iris.target

# Keep only the first two features (sepal length and sepal width) so the
# decision boundaries can be drawn in 2-D later on.
X_2d = X[:, :2]
feature_names = iris.feature_names[:2]

# Stratified 80/20 split keeps the class proportions identical in both sets.
X_train, X_test, y_train, y_test = train_test_split(
    X_2d, y, test_size=0.2, random_state=42, stratify=y
)

# A shallow tree (max_depth=3) stays interpretable and resists overfitting;
# 'entropy' would be an equally valid criterion here.
tree_clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
tree_clf.fit(X_train, y_train)

# Render the fitted tree: each box shows the split, impurity, sample counts,
# and majority class of that node.
plt.figure(figsize=(20, 10))
plot_tree(
    tree_clf,
    feature_names=feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    fontsize=10,
)
plt.title('Decision Tree for Iris Classification')
plt.tight_layout()
plt.show()
This visualization shows exactly how the tree makes decisions, making it highly interpretable.
# Dump the learned splits as indented if/else-style text rules.
from sklearn.tree import export_text

rules_text = export_text(tree_clf, feature_names=list(feature_names))
print("Decision Tree Rules:")
print(rules_text)
Output:
Decision Tree Rules:
|--- sepal width (cm) <= 2.80
| |--- sepal length (cm) <= 5.45
| | |--- class: versicolor
| |--- sepal length (cm) > 5.45
| | |--- class: virginica
|--- sepal width (cm) > 2.80
| |--- sepal length (cm) <= 5.45
| | |--- class: setosa
| |--- sepal length (cm) > 5.45
| | |--- class: versicolor
# Score the fitted tree on the held-out test set.
predictions = tree_clf.predict(X_test)
acc = accuracy_score(y_test, predictions)
print(f"Accuracy: {acc:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, predictions, target_names=iris.target_names))
def plot_decision_boundary_tree(clf, X, y, feature_names, class_names):
    """Visualize decision tree boundaries.

    Classifies a dense 2-D grid of points with *clf* and shades each
    region by predicted class, overlaying the actual samples.
    (*class_names* is accepted for API symmetry but not used here.)
    """
    # Pad the feature ranges slightly so no sample sits on the plot border.
    pad = 0.5
    f0_lo, f0_hi = X[:, 0].min() - pad, X[:, 0].max() + pad
    f1_lo, f1_hi = X[:, 1].min() - pad, X[:, 1].max() + pad
    # 200x200 grid covering the feature plane.
    grid_x, grid_y = np.meshgrid(np.linspace(f0_lo, f0_hi, 200),
                                 np.linspace(f1_lo, f1_hi, 200))
    grid_points = np.c_[grid_x.ravel(), grid_y.ravel()]
    labels = clf.predict(grid_points).reshape(grid_x.shape)

    plt.figure(figsize=(10, 6))
    plt.contourf(grid_x, grid_y, labels, alpha=0.3, cmap='viridis')
    sample_points = plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis',
                                edgecolors='black', s=50)
    plt.xlabel(feature_names[0])
    plt.ylabel(feature_names[1])
    plt.title('Decision Tree Decision Boundaries')
    plt.colorbar(sample_points, label='Class')
    plt.show()

plot_decision_boundary_tree(tree_clf, X_2d, y, feature_names, iris.target_names)
Notice that decision tree boundaries are always parallel to the feature axes (horizontal or vertical), reflecting the axis-aligned splits.
Decision trees provide feature importance scores based on how much each feature contributes to reducing impurity.
# Use all four features to rank them by importance.
# FIX: the original snippet referenced an undefined `X_train_full`. Create it
# here by re-splitting the full-feature matrix with the same settings used
# earlier (same random_state and stratify), so the row indices match the
# 2-D split exactly.
X_train_full, X_test_full, y_train_full, y_test_full = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
tree_full = DecisionTreeClassifier(max_depth=4, random_state=42)
tree_full.fit(X_train_full, y_train_full)

# Importance = total impurity reduction contributed by each feature,
# normalized to sum to 1.
importance = pd.DataFrame({
    'Feature': iris.feature_names,
    'Importance': tree_full.feature_importances_
}).sort_values('Importance', ascending=False)
print("Feature Importance:")
print(importance)

# Horizontal bars; invert the y-axis so the most important feature is on top.
plt.figure(figsize=(8, 5))
plt.barh(importance['Feature'], importance['Importance'])
plt.xlabel('Importance')
plt.title('Decision Tree Feature Importance')
plt.gca().invert_yaxis()
plt.show()
Features that appear higher in the tree and are used more frequently have higher importance.
Without constraints, decision trees grow until every leaf is pure, perfectly fitting training data but performing poorly on new data.
# Show how tree depth controls under/over-fitting on the same split.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
depths = [2, 5, None]  # None = grow until every leaf is pure

# The plotting grid is identical for every depth, so build it once.
x_min, x_max = X_2d[:, 0].min() - 0.5, X_2d[:, 0].max() + 0.5
y_min, y_max = X_2d[:, 1].min() - 0.5, X_2d[:, 1].max() + 0.5
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                     np.linspace(y_min, y_max, 100))

for ax, depth in zip(axes, depths):
    model = DecisionTreeClassifier(max_depth=depth, random_state=42)
    model.fit(X_train, y_train)
    # Train vs test accuracy exposes the overfitting gap at each depth.
    train_score = model.score(X_train, y_train)
    test_score = model.score(X_test, y_test)
    # Shade the feature plane by predicted class.
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    ax.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
    ax.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap='viridis', edgecolors='black')
    ax.set_title(f'Depth={depth}\nTrain: {train_score:.2f}, Test: {test_score:.2f}')
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])

plt.tight_layout()
plt.show()
# Constrain growth with several regularizing hyperparameters at once.
tree_controlled = DecisionTreeClassifier(
    max_depth=5,           # hard cap on tree depth
    min_samples_split=10,  # a node needs >= 10 samples before it may split
    min_samples_leaf=5,    # every leaf must keep >= 5 samples
    max_features='sqrt',   # random subset of features tried at each split
    max_leaf_nodes=20,     # hard cap on the total leaf count
    random_state=42,
)
tree_controlled.fit(X_train, y_train)
print(f"Controlled tree accuracy: {tree_controlled.score(X_test, y_test):.4f}")
Key hyperparameters:
| Parameter | Effect |
|---|---|
| `max_depth` | Limits how deep the tree grows |
| `min_samples_split` | Minimum samples needed to split |
| `min_samples_leaf` | Minimum samples required in each leaf |
| `max_leaf_nodes` | Limits total number of leaves |
| `max_features` | Features considered at each split |
Pruning reduces tree complexity after training. Scikit-learn supports cost complexity pruning.
# --- Cost-complexity (post-)pruning ---
# Each candidate alpha corresponds to collapsing one more weak subtree.
path = tree_clf.cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas

# Fit one tree per alpha and record train/test accuracy for each.
train_scores, test_scores = [], []
for alpha in alphas:
    candidate = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    candidate.fit(X_train, y_train)
    train_scores.append(candidate.score(X_train, y_train))
    test_scores.append(candidate.score(X_test, y_test))

# Accuracy as a function of pruning strength.
plt.figure(figsize=(10, 5))
plt.plot(alphas, train_scores, label='Train', marker='o')
plt.plot(alphas, test_scores, label='Test', marker='s')
plt.xlabel('Alpha (Complexity Parameter)')
plt.ylabel('Accuracy')
plt.title('Pruning: Accuracy vs Complexity Parameter')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Keep the alpha with the best held-out accuracy (argmax breaks ties in
# favor of the smallest alpha, i.e. the least-pruned tree).
best_alpha = alphas[np.argmax(test_scores)]
print(f"Best alpha: {best_alpha:.4f}")

# Refit with the chosen alpha and report the pruned tree's size and score.
pruned_tree = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
pruned_tree.fit(X_train, y_train)
print(f"Pruned tree depth: {pruned_tree.get_depth()}")
print(f"Pruned tree leaves: {pruned_tree.get_n_leaves()}")
print(f"Test accuracy: {pruned_tree.score(X_test, y_test):.4f}")
Decision trees can also predict continuous values by averaging target values in each leaf.
# --- Decision trees for regression: piecewise-constant predictions ---
from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error, r2_score

# A single noisy 1-D feature keeps the fit easy to visualize.
X_reg, y_reg = make_regression(n_samples=200, n_features=1, noise=20, random_state=42)

# Each leaf predicts the mean target of its training samples.
tree_reg = DecisionTreeRegressor(max_depth=4, random_state=42)
tree_reg.fit(X_reg, y_reg)
y_pred_reg = tree_reg.predict(X_reg)

# Sort by x so the prediction curve is drawn left-to-right.
order = np.argsort(X_reg.flatten())
plt.figure(figsize=(10, 5))
plt.scatter(X_reg, y_reg, alpha=0.5, label='Data')
plt.plot(X_reg[order], y_pred_reg[order], 'r-', linewidth=2, label='Prediction')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.title('Decision Tree Regression')
plt.legend()
plt.show()

print(f"R² Score: {r2_score(y_reg, y_pred_reg):.4f}")
Notice the step-like predictions characteristic of tree-based regression.
Decision trees provide an intuitive, interpretable approach to classification and regression through hierarchical decision rules.
Key takeaways: decision trees are highly interpretable and need little data preprocessing; they split greedily by reducing impurity (Gini or entropy); their boundaries are axis-aligned; unconstrained trees overfit, so depth limits, leaf-size constraints, and cost-complexity pruning are essential; and the same mechanism extends to regression via per-leaf averages.
Learn how Logistic Regression predicts binary outcomes using probability-based decision boundaries. This lesson covers theory, implementation, and practical use cases like spam detection and churn prediction.
Understand how the KNN algorithm classifies data based on similarity. This lesson explains distance metrics, choosing the right K value, and building accurate classification models.
Discover the Naive Bayes classifier, a fast and powerful algorithm based on probability and Bayes’ theorem. This lesson shows how it excels in text classification and other high‑dimensional tasks.