VIDHYAI
Lesson 3

Decision Trees

Explore Decision Trees and how they split data into meaningful decision rules. This lesson teaches tree-building, visualization, and practical classification applications.

10 min read

Introduction to Decision Trees

A decision tree is a flowchart-like structure where each internal node represents a test on a feature, each branch represents an outcome, and each leaf node represents a class label. The path from root to leaf represents classification rules.

The Decision Tree Intuition

Think of playing "20 Questions." You ask yes/no questions to narrow down possibilities until you reach an answer. Decision trees work similarly, asking questions about features to arrive at predictions.

Real-world applications:

  • Medical diagnosis systems
  • Credit risk assessment
  • Customer segmentation
  • Fraud detection
  • Game AI and strategy decisions

How Decision Trees Work

The Tree Structure

                    [Root Node]
                   Is Age > 30?
                   /          \
                Yes            No
                /                \
        [Internal Node]      [Leaf: Class B]
       Is Income > 50K?
          /        \
        Yes         No
        /             \
   [Leaf: Class A]  [Leaf: Class B]

Components:

  • Root Node: The topmost node, first decision
  • Internal Nodes: Intermediate decisions
  • Leaf Nodes: Final predictions
  • Branches: Outcomes of decisions
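The example tree above can be expressed as nested conditionals. This is a hand-written sketch using the illustrative thresholds from the diagram, not a tree learned from data:

```python
def classify(age, income):
    """Mimic the example tree: the root tests age, an internal node tests income."""
    if age > 30:                    # root node decision
        if income > 50_000:         # internal node decision
            return "Class A"        # leaf
        return "Class B"            # leaf
    return "Class B"                # leaf

# Each path from root to leaf corresponds to one classification rule
print(classify(age=45, income=60_000))
```

Reading the function top to bottom traces exactly the root-to-leaf paths shown in the diagram.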

The Splitting Process

At each node, the algorithm:

  1. Evaluates all possible splits for all features
  2. Selects the split that best separates classes
  3. Creates child nodes
  4. Repeats until stopping criteria are met
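The greedy split search in steps 1-2 can be sketched for a single numeric feature: try every midpoint between consecutive values and keep the threshold with the lowest weighted Gini impurity. This is a simplified illustration (real implementations handle many features and more stopping rules):

```python
import numpy as np

def best_split(x, y):
    """Evaluate every candidate threshold on one feature and
    return the one minimizing weighted Gini impurity."""
    def gini(labels):
        if len(labels) == 0:
            return 0.0
        p = np.bincount(labels) / len(labels)
        return 1.0 - np.sum(p ** 2)

    vals = np.sort(np.unique(x))
    best_t, best_score = None, np.inf
    for t in (vals[:-1] + vals[1:]) / 2:          # midpoints between values
        left, right = y[x <= t], y[x > t]
        # Weighted average impurity of the two child nodes
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
        if score < best_score:
            best_t, best_score = t, score
    return best_t, best_score

x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([0, 0, 0, 1, 1, 1])
t, score = best_split(x, y)
print(f"best threshold = {t}, weighted Gini = {score:.3f}")
```

Here the two classes are perfectly separable, so the search finds a threshold between 3 and 10 that yields pure children (impurity 0).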

Splitting Criteria: Measuring Impurity

Decision trees split nodes to reduce "impurity," the degree to which classes are mixed together within a node. Two common measures are Gini impurity and entropy.

Gini Impurity

Gini = 1 - Σ(pᵢ)²

Where pᵢ is the probability of class i in the node.

  • Gini = 0: Pure node (all samples same class)
  • Gini = 0.5: Maximum impurity for binary classification

Entropy (Information Gain)

Entropy = -Σ pᵢ × log₂(pᵢ)
  • Entropy = 0: Pure node
  • Entropy = 1: Maximum impurity for binary classification
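As a quick sanity check on the two formulas, consider an illustrative node holding 8 samples of class 0 and 2 of class 1:

```python
import numpy as np

counts = np.array([8, 2])          # class counts in the node
p = counts / counts.sum()          # class probabilities [0.8, 0.2]

gini = 1.0 - np.sum(p ** 2)        # Gini = 1 - Σ pᵢ²
entropy = -np.sum(p * np.log2(p))  # Entropy = -Σ pᵢ log₂(pᵢ)

print(f"Gini: {gini:.3f}")         # 1 - (0.64 + 0.04) = 0.32
print(f"Entropy: {entropy:.3f}")
```

Both values sit well below their maxima (0.5 and 1), reflecting that one class dominates this node.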

Comparing Gini and Entropy

import numpy as np
import matplotlib.pyplot as plt

p = np.linspace(0.01, 0.99, 100)

gini = 2 * p * (1 - p)
entropy = -p * np.log2(p) - (1 - p) * np.log2(1 - p)

plt.figure(figsize=(8, 5))
plt.plot(p, gini, label='Gini Impurity', linewidth=2)
plt.plot(p, entropy, label='Entropy', linewidth=2)
plt.xlabel('Probability of Class 1')
plt.ylabel('Impurity')
plt.title('Gini vs Entropy for Binary Classification')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

Both measures are highest when classes are equally mixed and lowest when one class dominates.


Implementing Decision Trees in Python

Step 1: Import Libraries and Load Data

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.datasets import load_iris

# Load data
iris = load_iris()
X = iris.data
y = iris.target

# Use only 2 features for visualization
X_2d = X[:, :2]  # sepal length and width
feature_names = iris.feature_names[:2]

X_train, X_test, y_train, y_test = train_test_split(
    X_2d, y, test_size=0.2, random_state=42, stratify=y
)

Step 2: Train the Decision Tree

# Create and train decision tree
tree_clf = DecisionTreeClassifier(
    criterion='gini',      # or 'entropy'
    max_depth=3,           # limit tree depth
    random_state=42
)

tree_clf.fit(X_train, y_train)

Step 3: Visualize the Decision Tree

plt.figure(figsize=(20, 10))
plot_tree(
    tree_clf,
    feature_names=feature_names,
    class_names=iris.target_names,
    filled=True,
    rounded=True,
    fontsize=10
)
plt.title('Decision Tree for Iris Classification')
plt.tight_layout()
plt.show()

This visualization shows exactly how the tree makes decisions, making it highly interpretable.

Step 4: Understand Tree Output

# Print tree rules as text
from sklearn.tree import export_text

tree_rules = export_text(tree_clf, feature_names=list(feature_names))
print("Decision Tree Rules:")
print(tree_rules)

Output:

Decision Tree Rules:
|--- sepal width (cm) <= 2.80
|   |--- sepal length (cm) <= 5.45
|   |   |--- class: versicolor
|   |--- sepal length (cm) > 5.45
|   |   |--- class: virginica
|--- sepal width (cm) > 2.80
|   |--- sepal length (cm) <= 5.45
|   |   |--- class: setosa
|   |--- sepal length (cm) > 5.45
|   |   |--- class: versicolor

Step 5: Evaluate the Model

y_pred = tree_clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)

print(f"Accuracy: {accuracy:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

Visualizing Decision Boundaries

def plot_decision_boundary_tree(clf, X, y, feature_names, class_names):
    """Visualize decision tree boundaries."""
    x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
    y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
    
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 200),
                         np.linspace(y_min, y_max, 200))
    
    Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    
    plt.figure(figsize=(10, 6))
    plt.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
    
    scatter = plt.scatter(X[:, 0], X[:, 1], c=y, cmap='viridis', 
                         edgecolors='black', s=50)
    plt.xlabel(feature_names[0])
    plt.ylabel(feature_names[1])
    plt.title('Decision Tree Decision Boundaries')
    plt.colorbar(scatter, label='Class')
    plt.show()

plot_decision_boundary_tree(tree_clf, X_2d, y, feature_names, iris.target_names)

Notice that decision tree boundaries are always parallel to the feature axes (horizontal or vertical), reflecting the axis-aligned splits.


Feature Importance

Decision trees provide feature importance scores based on how much each feature contributes to reducing impurity.

# Use all four features this time (a separate split from the 2-feature one above)
X_train_full, X_test_full, y_train_full, y_test_full = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

tree_full = DecisionTreeClassifier(max_depth=4, random_state=42)
tree_full.fit(X_train_full, y_train_full)

# Get feature importance
importance = pd.DataFrame({
    'Feature': iris.feature_names,
    'Importance': tree_full.feature_importances_
}).sort_values('Importance', ascending=False)

print("Feature Importance:")
print(importance)

# Visualize
plt.figure(figsize=(8, 5))
plt.barh(importance['Feature'], importance['Importance'])
plt.xlabel('Importance')
plt.title('Decision Tree Feature Importance')
plt.gca().invert_yaxis()
plt.show()

Features that appear higher in the tree and are used more frequently have higher importance.


Controlling Tree Complexity

The Overfitting Problem

Without constraints, decision trees grow until every leaf is pure, perfectly fitting training data but performing poorly on new data.

# Compare trees of different depths
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
depths = [2, 5, None]  # None = unlimited

for ax, depth in zip(axes, depths):
    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)
    
    # Calculate scores
    train_score = tree.score(X_train, y_train)
    test_score = tree.score(X_test, y_test)
    
    # Plot boundaries
    x_min, x_max = X_2d[:, 0].min() - 0.5, X_2d[:, 0].max() + 0.5
    y_min, y_max = X_2d[:, 1].min() - 0.5, X_2d[:, 1].max() + 0.5
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))
    Z = tree.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    
    ax.contourf(xx, yy, Z, alpha=0.3, cmap='viridis')
    ax.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap='viridis', edgecolors='black')
    ax.set_title(f'Depth={depth}\nTrain: {train_score:.2f}, Test: {test_score:.2f}')
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[1])

plt.tight_layout()
plt.show()

Hyperparameters for Controlling Complexity

tree_controlled = DecisionTreeClassifier(
    max_depth=5,           # Maximum depth of tree
    min_samples_split=10,  # Minimum samples to split a node
    min_samples_leaf=5,    # Minimum samples in a leaf
    max_features='sqrt',   # Features to consider for split
    max_leaf_nodes=20,     # Maximum number of leaves
    random_state=42
)

tree_controlled.fit(X_train, y_train)
print(f"Controlled tree accuracy: {tree_controlled.score(X_test, y_test):.4f}")

Key hyperparameters:

  • max_depth: Limits how deep the tree grows
  • min_samples_split: Minimum samples needed to split a node
  • min_samples_leaf: Minimum samples required in each leaf
  • max_leaf_nodes: Limits the total number of leaves
  • max_features: Number of features considered at each split
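A common way to choose values for these hyperparameters is a cross-validated grid search. The grid below is illustrative, not prescriptive; it reuses the Iris data in a self-contained sketch:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Small illustrative grid over two complexity hyperparameters
param_grid = {
    "max_depth": [2, 3, 5, None],
    "min_samples_leaf": [1, 5, 10],
}
grid = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid, cv=5, scoring="accuracy"
)
grid.fit(X_train, y_train)

print("Best params:", grid.best_params_)
print(f"Test accuracy: {grid.best_estimator_.score(X_test, y_test):.4f}")
```

`GridSearchCV` refits the best combination on the full training set, so `best_estimator_` is ready to evaluate directly.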

Pruning Decision Trees

Pruning reduces tree complexity after training. Scikit-learn supports cost complexity pruning.

Cost Complexity Pruning

# Find optimal alpha (pruning parameter)
path = tree_clf.cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas

# Train trees with different alphas
train_scores = []
test_scores = []

for alpha in alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    tree.fit(X_train, y_train)
    train_scores.append(tree.score(X_train, y_train))
    test_scores.append(tree.score(X_test, y_test))

# Plot
plt.figure(figsize=(10, 5))
plt.plot(alphas, train_scores, label='Train', marker='o')
plt.plot(alphas, test_scores, label='Test', marker='s')
plt.xlabel('Alpha (Complexity Parameter)')
plt.ylabel('Accuracy')
plt.title('Pruning: Accuracy vs Complexity Parameter')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Find best alpha
best_alpha = alphas[np.argmax(test_scores)]
print(f"Best alpha: {best_alpha:.4f}")

Apply Optimal Pruning

pruned_tree = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
pruned_tree.fit(X_train, y_train)

print(f"Pruned tree depth: {pruned_tree.get_depth()}")
print(f"Pruned tree leaves: {pruned_tree.get_n_leaves()}")
print(f"Test accuracy: {pruned_tree.score(X_test, y_test):.4f}")

Decision Trees for Regression

Decision trees can also predict continuous values by averaging target values in each leaf.

from sklearn.tree import DecisionTreeRegressor
from sklearn.datasets import make_regression
from sklearn.metrics import mean_squared_error, r2_score

# Generate regression data
X_reg, y_reg = make_regression(n_samples=200, n_features=1, noise=20, random_state=42)

# Train regression tree
tree_reg = DecisionTreeRegressor(max_depth=4, random_state=42)
tree_reg.fit(X_reg, y_reg)

# Predict
y_pred_reg = tree_reg.predict(X_reg)

# Visualize
plt.figure(figsize=(10, 5))
sort_idx = X_reg.flatten().argsort()
plt.scatter(X_reg, y_reg, alpha=0.5, label='Data')
plt.plot(X_reg[sort_idx], y_pred_reg[sort_idx], 'r-', linewidth=2, label='Prediction')
plt.xlabel('Feature')
plt.ylabel('Target')
plt.title('Decision Tree Regression')
plt.legend()
plt.show()

print(f"R² Score: {r2_score(y_reg, y_pred_reg):.4f}")

Notice the step-like predictions characteristic of tree-based regression.


Advantages and Disadvantages

Advantages

  • Highly interpretable: Easy to visualize and explain
  • No feature scaling required: Handles raw values
  • Handles mixed data types: Works with numerical and categorical features
  • Captures non-linear relationships: Through hierarchical splits
  • Feature importance: Built-in measure of feature relevance

Disadvantages

  • Prone to overfitting: Without proper constraints
  • Unstable: Small data changes can produce different trees
  • Biased toward high-cardinality features: Categorical features with many distinct levels can look artificially attractive to split on
  • Axis-aligned boundaries: Cannot capture diagonal relationships well
  • Greedy algorithm: May not find globally optimal tree

Summary

Decision trees provide an intuitive, interpretable approach to classification and regression through hierarchical decision rules.

Key takeaways:

  • Decision trees split data based on feature thresholds
  • Gini impurity and entropy measure node purity
  • Feature importance indicates which features drive predictions
  • Control complexity with max_depth, min_samples_leaf, etc.
  • Pruning reduces overfitting by simplifying the tree
  • Trees are interpretable but prone to overfitting
  • Decision boundaries are always axis-aligned
  • Trees form the basis for powerful ensemble methods