🌳 Building a Decision Tree From Scratch in Python

Decision Trees are one of the most intuitive machine learning algorithms. They are built by repeatedly splitting the dataset into branches, with each split chosen to maximize **Information Gain** (a measure based on Entropy).

*[Figure: A simplified node split in a Decision Tree.]*

In this post, we’ll walk step by step through coding the building blocks of a decision tree:

  1. Computing entropy
  2. Splitting the dataset
  3. Calculating information gain
  4. Choosing the best split

1. Computing Entropy

**Entropy** is a measure of impurity at a node. The formula is:

$$H(p) = -p \log_2(p) - (1-p)\log_2(1-p)$$

where $p$ is the fraction of positive examples (e.g., “edible” mushrooms).

If the node is pure (all edible or all poisonous), entropy = 0.
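
At the other extreme, a 50/50 mix gives the maximum entropy of 1. As a quick worked example, for a node where 80% of the examples are edible ($p = 0.8$):

$$H(0.8) = -0.8 \log_2(0.8) - 0.2 \log_2(0.2) \approx 0.72$$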

Here’s the implementation:

```python
import numpy as np

def compute_entropy(y):
    """
    Computes the entropy for labels at a node.

    Args:
        y (ndarray): Array indicating labels (1 = edible, 0 = poisonous)

    Returns:
        entropy (float): Entropy at that node
    """
    entropy = 0.
    if len(y) != 0:
        # Fraction of edible examples
        p1 = np.mean(y == 1)

        # Apply the entropy formula (guarding against 0 * log2(0))
        if p1 != 0 and p1 != 1:
            entropy = -p1 * np.log2(p1) - (1 - p1) * np.log2(1 - p1)
    return entropy
```

✅ Quick test:

```python
print(compute_entropy(np.array([1, 1, 1, 1])))  # 0.0 (pure node)
print(compute_entropy(np.array([0, 0, 0, 0])))  # 0.0 (pure node)
print(compute_entropy(np.array([1, 0, 1, 0])))  # 1.0 (maximum impurity)
```
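
As a sanity check, the same value can be computed with SciPy's entropy helper, assuming SciPy is installed (it takes the class distribution $[p_1, 1 - p_1]$ and a log base):

```python
from scipy.stats import entropy

p1 = np.mean(np.array([1, 0, 1, 0]) == 1)
print(entropy([p1, 1 - p1], base=2))  # 1.0
```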

2. Splitting the Dataset

To build a decision tree, we split the data on a feature's value. Since our features are binary, samples where the feature equals 1 go to the **left branch**, and the rest go to the **right branch**.

```python
def split_dataset(X, node_indices, feature):
    """
    Splits the dataset based on a feature value.

    Args:
        X (ndarray): Data matrix (n_samples, n_features)
        node_indices (list): Indices of samples at this node
        feature (int): Feature index to split on

    Returns:
        left_indices (list): Indices with feature value == 1
        right_indices (list): Indices with feature value == 0
    """
    left_indices = []
    right_indices = []

    for i in node_indices:
        if X[i, feature] == 1:
            left_indices.append(i)
        else:
            right_indices.append(i)

    return left_indices, right_indices
```

✅ Quick test:

```python
X = np.array([[1, 0, 1],
              [0, 1, 1],
              [1, 1, 0],
              [0, 0, 1]])
node_indices = [0, 1, 2, 3]
print(split_dataset(X, node_indices, 0))  # ([0, 2], [1, 3])
```
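
As an aside, the same split can be written without an explicit loop using NumPy boolean masks. Here's a sketch of a hypothetical `split_dataset_vectorized`, equivalent to the version above for the binary features assumed in this post:

```python
def split_dataset_vectorized(X, node_indices, feature):
    """Vectorized equivalent of split_dataset using boolean masks."""
    idx = np.array(node_indices)
    mask = X[idx, feature] == 1        # True where the feature equals 1
    return idx[mask].tolist(), idx[~mask].tolist()
```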

3. Computing Information Gain

Now, we measure how much splitting reduces impurity:

$$IG = H(\text{node}) - \big( w_{\text{left}} \cdot H(\text{left}) + w_{\text{right}} \cdot H(\text{right}) \big)$$

where $w_{\text{left}}$ and $w_{\text{right}}$ are the fractions of the node's samples sent to the left and right branches.
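
For example, if a node with entropy 1.0 splits into two pure branches of equal size, the split recovers all the information:

$$IG = 1.0 - (0.5 \cdot 0 + 0.5 \cdot 0) = 1.0$$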

Implementation:

```python
def compute_information_gain(X, y, node_indices, feature):
    """
    Computes the information gain from splitting on a feature.

    Args:
        X (ndarray): Data matrix (n_samples, n_features)
        y (ndarray): Labels for all samples
        node_indices (list): Indices of samples at this node
        feature (int): Feature index to split on

    Returns:
        float: Information gain of the split
    """
    left_indices, right_indices = split_dataset(X, node_indices, feature)

    y_node = y[node_indices]
    y_left = y[left_indices]
    y_right = y[right_indices]

    # Entropy before the split
    node_entropy = compute_entropy(y_node)

    # Entropy of each branch after the split
    left_entropy = compute_entropy(y_left)
    right_entropy = compute_entropy(y_right)

    # Fraction of samples sent to each branch
    w_left = len(y_left) / len(y_node) if len(y_node) > 0 else 0
    w_right = len(y_right) / len(y_node) if len(y_node) > 0 else 0

    # Weighted average entropy of the branches
    weighted_entropy = w_left * left_entropy + w_right * right_entropy

    return node_entropy - weighted_entropy
```

✅ Quick test:

```python
y = np.array([1, 0, 1, 0])
print(compute_information_gain(X, y, node_indices, 0))  # 1.0 (a perfect split)
```

4. Choosing the Best Split

Finally, we loop through all features and pick the one that maximizes **Information Gain**.

```python
def get_best_split(X, y, node_indices):
    """
    Finds the best feature to split on.

    Args:
        X (ndarray): Data matrix (n_samples, n_features)
        y (ndarray): Labels for all samples
        node_indices (list): Indices of samples at this node

    Returns:
        int: Index of the feature with the highest information gain,
             or -1 if no feature yields a positive gain
    """
    num_features = X.shape[1]
    best_feature = -1
    max_info_gain = 0

    for feature in range(num_features):
        info_gain = compute_information_gain(X, y, node_indices, feature)

        if info_gain > max_info_gain:
            max_info_gain = info_gain
            best_feature = feature

    return best_feature
```

✅ Quick test:

```python
print(get_best_split(X, y, node_indices))  # 0 (feature 0 has the highest gain)
```
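
With all four building blocks in place, the tree itself can be grown by splitting recursively until a node is pure or a depth limit is reached. Here is a minimal sketch of that recursion; `build_tree`, its printed output, and the `max_depth` stopping rule are illustrative choices, not part of the functions above:

```python
def build_tree(X, y, node_indices, max_depth, depth=0, branch="root"):
    """Recursively splits nodes, printing the tree as it grows."""
    indent = "  " * depth

    # Stop if we hit the depth limit or the node is already pure
    if depth == max_depth or compute_entropy(y[node_indices]) == 0:
        print(f"{indent}{branch} leaf: indices {node_indices}")
        return

    best_feature = get_best_split(X, y, node_indices)
    if best_feature == -1:  # no feature gives a positive gain
        print(f"{indent}{branch} leaf: indices {node_indices}")
        return

    print(f"{indent}{branch}: split on feature {best_feature}")
    left_indices, right_indices = split_dataset(X, node_indices, best_feature)
    build_tree(X, y, left_indices, max_depth, depth + 1, branch="left")
    build_tree(X, y, right_indices, max_depth, depth + 1, branch="right")

build_tree(X, y, node_indices, max_depth=2)
```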

🔑 Key Takeaways

  • **Entropy** measures the impurity of a node.
  • **Splitting** separates the data into a left branch (feature = 1) and a right branch (feature = 0).
  • **Information Gain** measures how much a split reduces impurity.
  • **Best Split** is the feature with the highest information gain.
