Decision trees are among the most intuitive machine learning algorithms. They are built by repeatedly splitting the dataset into branches, with each split chosen to maximize **Information Gain** (a reduction in entropy).
In this post, we’ll walk step by step through coding the building blocks of a decision tree:
- Computing entropy
- Splitting the dataset
- Calculating information gain
- Choosing the best split
1. Computing Entropy
**Entropy** is a measure of impurity at a node. The formula is:
$$H(p) = -p \log_2(p) - (1-p)\log_2(1-p)$$

where $p$ is the fraction of positive examples (e.g., “edible” mushrooms).
If the node is pure (all edible or all poisonous), entropy is 0; it peaks at 1 when the two classes are split 50/50.
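To get a feel for the numbers, here is the formula evaluated at a few values of $p$ (by convention $0 \cdot \log_2 0 = 0$, which is why the implementation below guards against $p = 0$ and $p = 1$):

$$H(0) = H(1) = 0, \qquad H(0.5) = 1, \qquad H(0.75) = -0.75\log_2 0.75 - 0.25\log_2 0.25 \approx 0.811$$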
Here’s the implementation:
```python
import numpy as np

def compute_entropy(y):
    """
    Computes the entropy for labels at a node.

    Args:
        y (ndarray): Array indicating labels (1 = edible, 0 = poisonous)

    Returns:
        entropy (float): Entropy at that node
    """
    entropy = 0.
    if len(y) != 0:
        # Fraction of edible examples
        p1 = np.mean(y == 1)
        # Apply entropy formula (avoid 0 * log2(0))
        if p1 != 0 and p1 != 1:
            entropy = -p1 * np.log2(p1) - (1 - p1) * np.log2(1 - p1)
    return entropy
```
✅ Quick test:
```python
print(compute_entropy(np.array([1, 1, 1, 1])))  # 0.0
print(compute_entropy(np.array([0, 0, 0, 0])))  # 0.0
print(compute_entropy(np.array([1, 0, 1, 0])))  # 1.0
```
2. Splitting the Dataset
To build a decision tree, we split the data at each node on a feature's value: examples where the feature equals 1 go to the **left branch**, and the rest go to the **right branch**.
```python
def split_dataset(X, node_indices, feature):
    """
    Splits the dataset based on a feature value.

    Args:
        X (ndarray): Data matrix (n_samples, n_features)
        node_indices (list): Indices of samples at this node
        feature (int): Feature index to split on

    Returns:
        left_indices (list): Indices with feature value == 1
        right_indices (list): Indices with feature value == 0
    """
    left_indices = []
    right_indices = []
    for i in node_indices:
        if X[i][feature] == 1:
            left_indices.append(i)
        else:
            right_indices.append(i)
    return left_indices, right_indices
```
✅ Quick test:
```python
X = np.array([[1, 0, 1],
              [0, 1, 1],
              [1, 1, 0],
              [0, 0, 1]])
node_indices = [0, 1, 2, 3]
print(split_dataset(X, node_indices, 0))  # ([0, 2], [1, 3])
```
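Since `X` is a NumPy array, the same split can also be written with a boolean mask instead of an explicit loop. A minimal sketch, reusing `X` and `node_indices` from the quick test above (the name `split_dataset_vectorized` is mine, not part of the original walkthrough):

```python
import numpy as np

def split_dataset_vectorized(X, node_indices, feature):
    """Same behavior as split_dataset, expressed with a NumPy boolean mask."""
    node_indices = np.asarray(node_indices)
    mask = X[node_indices, feature] == 1           # True where the feature is 1
    left_indices = node_indices[mask].tolist()     # feature == 1 -> left branch
    right_indices = node_indices[~mask].tolist()   # feature == 0 -> right branch
    return left_indices, right_indices

print(split_dataset_vectorized(X, node_indices, 0))  # ([0, 2], [1, 3])
```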
3. Computing Information Gain
Now, we measure how much splitting reduces impurity:
$$IG = H(\text{node}) - \big( w_{\text{left}} \cdot H(\text{left}) + w_{\text{right}} \cdot H(\text{right}) \big)$$

where $w_{\text{left}}$ and $w_{\text{right}}$ are the fractions of the node's examples that end up in each branch.

Implementation:
```python
def compute_information_gain(X, y, node_indices, feature):
    """
    Computes the information gain from splitting the node on a feature.

    Args:
        X (ndarray): Data matrix (n_samples, n_features)
        y (ndarray): Labels for all samples
        node_indices (list): Indices of samples at this node
        feature (int): Feature index to split on

    Returns:
        information_gain (float): Entropy reduction achieved by the split
    """
    left_indices, right_indices = split_dataset(X, node_indices, feature)
    y_node = y[node_indices]
    y_left = y[left_indices]
    y_right = y[right_indices]

    # Entropy before the split
    node_entropy = compute_entropy(y_node)
    # Entropy of each branch after the split
    left_entropy = compute_entropy(y_left)
    right_entropy = compute_entropy(y_right)

    # Fraction of the node's examples sent to each branch
    w_left = len(y_left) / len(y_node) if len(y_node) > 0 else 0
    w_right = len(y_right) / len(y_node) if len(y_node) > 0 else 0

    # Weighted average entropy of the two branches
    weighted_entropy = w_left * left_entropy + w_right * right_entropy
    return node_entropy - weighted_entropy
```
✅ Quick test:
```python
y = np.array([1, 0, 1, 0])
print(compute_information_gain(X, y, node_indices, 0))  # 1.0
```
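For this toy dataset the split on feature 0 is perfect: the left branch receives labels [1, 1] and the right branch [0, 0], so both branch entropies are 0 and the gain equals the node's full entropy:

$$IG = H(0.5) - \left(\tfrac{2}{4} \cdot 0 + \tfrac{2}{4} \cdot 0\right) = 1.0$$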
4. Choosing the Best Split
Finally, we loop through all features and pick the one that maximizes **Information Gain**.
```python
def get_best_split(X, y, node_indices):
    """
    Finds the feature that gives the highest information gain.
    Returns -1 if no feature produces a positive gain.
    """
    num_features = X.shape[1]
    best_feature = -1
    max_info_gain = 0
    for feature in range(num_features):
        info_gain = compute_information_gain(X, y, node_indices, feature)
        if info_gain > max_info_gain:
            max_info_gain = info_gain
            best_feature = feature
    return best_feature
```
✅ Quick test:
```python
print(get_best_split(X, y, node_indices))  # 0 (feature 0 gives the highest information gain)
```
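To see how these building blocks fit together, here is a minimal recursive builder sketch. The `build_tree` function, its dictionary node format, and the `max_depth` stopping rule are my own illustration rather than part of the original exercise, so treat it as one possible way to assemble the pieces:

```python
def build_tree(X, y, node_indices, max_depth, depth=0):
    """Recursively split nodes until no split helps or max_depth is reached."""
    best_feature = get_best_split(X, y, node_indices)

    # Stop if no feature gives a positive gain, or the depth limit is hit
    if best_feature == -1 or depth == max_depth:
        return {"leaf": True, "indices": node_indices}

    left_indices, right_indices = split_dataset(X, node_indices, best_feature)
    return {
        "leaf": False,
        "feature": best_feature,
        "left": build_tree(X, y, left_indices, max_depth, depth + 1),
        "right": build_tree(X, y, right_indices, max_depth, depth + 1),
    }

tree = build_tree(X, y, node_indices, max_depth=2)
print(tree["feature"])  # 0 (the root splits on the best feature found above)
```

On the toy data both branches of the root are pure, so `get_best_split` returns -1 inside them and they immediately become leaves.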
Key Takeaways
- **Entropy** measures impurity.
- **Splitting** separates the data into a left branch (feature = 1) and a right branch (feature = 0).
- **Information Gain** tells us how good a split is.
- **Best Split** is the feature with the highest information gain.