Decision Tree Intuition
A decision tree learns a hierarchy of if-then rules that recursively partition the feature space. At each internal node, it asks a question about one feature (“Is $x_i \leq t$?”) and routes the data left or right. Leaf nodes contain the final class predictions.
flowchart TD
A["x₁ ≤ 3.5?"] -->|Yes| B["x₂ ≤ 1.2?"]
A -->|No| C["Class B"]
B -->|Yes| D["Class A"]
B -->|No| E["Class B"]
Entropy & Impurity
Entropy measures the disorder (uncertainty) in a set of labels. A pure node (all samples in the same class) has entropy 0; entropy is maximal, $\log_2 K$ bits for $K$ classes, when the classes are equally frequent:
$$H(S) = -\sum_{k=1}^{K} p_k \log_2 p_k$$
import torch
def entropy(labels):
    """Compute Shannon entropy (in bits) of a label tensor.

    Args:
        labels: Tensor of class labels.

    Returns:
        A 0-dim float tensor. It is 0.0 for an empty or single-class input
        and grows as the class frequencies become more even.
    """
    if len(labels) == 0:
        return torch.tensor(0.0)
    _, counts = torch.unique(labels, return_counts=True)
    # torch.unique only reports classes that actually occur, so every
    # probability is strictly positive and log2 never sees 0.
    probs = counts.float() / len(labels)
    # Each term p * log2(p) is <= 0, so the entropy is the magnitude of the
    # sum. Taking abs() instead of negating avoids returning -0.0 for a pure
    # set, which would otherwise be formatted as "-0.0000".
    return (probs * torch.log2(probs)).sum().abs()
def gini_impurity(labels):
    """Return the Gini impurity of a label tensor.

    The impurity is one minus the sum of squared class frequencies. An empty
    input is treated as perfectly pure and yields 0.0.
    """
    n_labels = len(labels)
    if not n_labels:
        return torch.tensor(0.0)
    _, class_counts = torch.unique(labels, return_counts=True)
    class_probs = class_counts.float() / n_labels
    return 1.0 - class_probs.pow(2).sum()
# Examples: a single-class set, an uneven two-class set, an even four-class set
pure, mixed, uniform = (
    torch.tensor([0, 0, 0, 0, 0]),
    torch.tensor([0, 0, 1, 1, 1]),
    torch.tensor([0, 1, 2, 3]),
)
# One print call, one report line per label set.
print(
    f"Pure set: entropy={entropy(pure):.4f}, gini={gini_impurity(pure):.4f}",
    f"Mixed set: entropy={entropy(mixed):.4f}, gini={gini_impurity(mixed):.4f}",
    f"Uniform set: entropy={entropy(uniform):.4f}, gini={gini_impurity(uniform):.4f}",
    sep="\n",
)
Information Gain
Information gain measures the reduction in entropy from a split. The best split maximizes this reduction:
$$IG(S, A) = H(S) - \sum_{v} \frac{|S_v|}{|S|} H(S_v)$$
import torch
def entropy(labels):
    """Compute Shannon entropy (in bits) of a label tensor.

    Args:
        labels: Tensor of class labels.

    Returns:
        A 0-dim float tensor; 0.0 for an empty or single-class input.
    """
    if len(labels) == 0:
        return torch.tensor(0.0)
    _, counts = torch.unique(labels, return_counts=True)
    # Only classes that occur are counted, so no probability is 0 and no
    # extra filtering is needed before taking log2.
    probs = counts.float() / len(labels)
    # The sum of p * log2(p) is never positive; abs() yields the entropy and,
    # unlike unary minus, returns +0.0 (not -0.0) for a pure set.
    return (probs * torch.log2(probs)).sum().abs()
def information_gain(X_column, y, threshold):
    """Return the entropy reduction from splitting ``y`` at ``threshold``.

    Samples with ``X_column <= threshold`` form the left child and all other
    samples form the right child. A split that leaves either child empty
    yields a gain of 0.
    """
    entropy_before = entropy(y)

    goes_left = X_column <= threshold
    left_labels, right_labels = y[goes_left], y[~goes_left]
    if not (len(left_labels) and len(right_labels)):
        return torch.tensor(0.0)

    # Each child contributes in proportion to the share of samples it holds.
    total = len(y)
    left_term = (len(left_labels) / total) * entropy(left_labels)
    right_term = (len(right_labels) / total) * entropy(right_labels)
    return entropy_before - (left_term + right_term)
def find_best_split(X, y):
    """Exhaustively search for the split with the highest information gain.

    Every unique value of every column of ``X`` is tried as a threshold.

    Returns:
        A ``(feature_index, threshold, gain)`` tuple. ``gain`` is a tensor;
        ``threshold`` is a Python float.
    """
    # (feature, threshold, gain); the -1.0 gain guarantees that the first
    # candidate evaluated replaces these placeholders.
    winner = (0, 0.0, torch.tensor(-1.0))
    for feature_idx in range(X.shape[1]):
        column = X[:, feature_idx]
        for candidate in torch.unique(column):
            gain = information_gain(column, y, candidate)
            # Strict comparison keeps the earliest candidate on ties.
            if gain > winner[2]:
                winner = (feature_idx, candidate.item(), gain)
    return winner
# Demo: labels depend on the sum of the first two features
torch.manual_seed(42)
X = torch.randn(100, 3)
y = torch.gt(X[:, 0] + X[:, 1], 0).long()  # Simple linear boundary

feat, thresh, gain = find_best_split(X, y)
print(f"Best split: feature {feat}, threshold {thresh:.3f}, gain {gain:.4f}")
Recursive Tree Construction
import torch
def entropy(labels):
    """Compute Shannon entropy (in bits) of a label tensor.

    Returns a 0-dim float tensor; an empty or single-class input gives 0.0.
    """
    if len(labels) == 0:
        return torch.tensor(0.0)
    _, counts = torch.unique(labels, return_counts=True)
    # Counts from torch.unique are always >= 1, so log2 is never applied to 0.
    probs = counts.float() / len(labels)
    # abs() instead of unary minus: same value (the sum is never positive),
    # but a pure set yields +0.0 rather than -0.0.
    return (probs * torch.log2(probs)).sum().abs()
def information_gain(X_col, y, threshold):
    """Return the entropy reduction from splitting ``y`` at ``threshold``.

    Rows with ``X_col <= threshold`` go left, the rest go right. If either
    side ends up empty the gain is 0.
    """
    goes_left = X_col <= threshold
    left_labels = y[goes_left]
    right_labels = y[~goes_left]
    if not (len(left_labels) and len(right_labels)):
        return torch.tensor(0.0)

    # Weight each child's entropy by the fraction of samples it receives.
    total = len(y)
    left_term = (len(left_labels) / total) * entropy(left_labels)
    right_term = (len(right_labels) / total) * entropy(right_labels)
    return entropy(y) - (left_term + right_term)
class DecisionNode:
    """One node of the decision tree.

    A node with a non-``None`` ``value`` is used as a leaf and ``value`` is
    its class prediction. Otherwise the node is a test on ``feature`` against
    ``threshold`` with ``left`` and ``right`` children.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        # Fields describing the split at an internal node
        self.feature, self.threshold = feature, threshold
        # Child nodes
        self.left, self.right = left, right
        # Leaf class prediction
        self.value = value
class DecisionTreeClassifier:
    """Decision tree classifier built with PyTorch tensors.

    The tree is grown greedily: each internal node uses the
    (feature, threshold) pair with the highest information gain, and samples
    with ``x[feature] <= threshold`` are routed to the left child.

    Args:
        max_depth: Nodes at this depth are always turned into leaves.
        min_samples_split: Nodes with fewer samples than this become leaves.
    """

    # Upper bound on the number of candidate thresholds tried per feature.
    _MAX_THRESHOLDS = 20

    def __init__(self, max_depth=10, min_samples_split=2):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.root = None

    def fit(self, X, y):
        """Build the tree from features ``X`` and labels ``y``; return self."""
        self.root = self._build_tree(X.float(), y.long(), depth=0)
        return self

    @staticmethod
    def _majority_class(y):
        """Return the most frequent label in ``y`` as a Python int."""
        return torch.mode(y).values.item()

    def _find_best_split(self, X, y):
        """Return ``(feature, threshold, gain)`` of the highest-gain split.

        ``gain`` stays at -1.0 when no candidate was evaluated (no features).
        """
        best_gain = -1.0
        best_feat, best_thresh = 0, 0.0
        for feat_idx in range(X.shape[1]):
            col = X[:, feat_idx]
            thresholds = torch.unique(col)
            # For efficiency, keep only evenly spaced picks from the sorted
            # unique values when there are many of them.
            if len(thresholds) > self._MAX_THRESHOLDS:
                indices = torch.linspace(
                    0, len(thresholds) - 1, self._MAX_THRESHOLDS
                ).long()
                thresholds = thresholds[indices]
            for t in thresholds:
                gain = information_gain(col, y, t).item()
                # Strict comparison keeps the earliest candidate on ties.
                if gain > best_gain:
                    best_gain, best_feat, best_thresh = gain, feat_idx, t.item()
        return best_feat, best_thresh, best_gain

    def _build_tree(self, X, y, depth):
        """Recursively build and return the subtree for samples ``(X, y)``."""
        n_samples = len(y)
        n_classes = len(torch.unique(y))

        # Stopping conditions: depth limit, pure node, or too few samples.
        if depth >= self.max_depth or n_classes == 1 or n_samples < self.min_samples_split:
            return DecisionNode(value=self._majority_class(y))

        best_feat, best_thresh, best_gain = self._find_best_split(X, y)
        # No split reduces entropy, so further partitioning is pointless.
        if best_gain <= 0:
            return DecisionNode(value=self._majority_class(y))

        left_mask = X[:, best_feat] <= best_thresh
        right_mask = ~left_mask
        return DecisionNode(
            feature=best_feat,
            threshold=best_thresh,
            left=self._build_tree(X[left_mask], y[left_mask], depth + 1),
            right=self._build_tree(X[right_mask], y[right_mask], depth + 1),
        )

    def _predict_one(self, x, node):
        """Walk from ``node`` down to a leaf for sample ``x``; return its value."""
        # Iterative descent: a node whose value is None is an internal node.
        while node.value is None:
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value

    def predict(self, X):
        """Return an int64 tensor with one predicted class per row of ``X``."""
        X = X.float()
        labels = [self._predict_one(x, self.root) for x in X]
        # Explicit dtype: torch.tensor([]) would otherwise infer float32 for
        # an empty batch, unlike the integer result of a non-empty batch.
        return torch.tensor(labels, dtype=torch.long)
# Demo: the positive class requires both of the first two features to be > 0
torch.manual_seed(42)
X_train = torch.randn(200, 4)
y_train = ((X_train[:, 0] > 0) & (X_train[:, 1] > 0)).long()  # AND boundary

# fit() returns the classifier itself, so construction and training chain.
tree = DecisionTreeClassifier(max_depth=5).fit(X_train, y_train)
preds = tree.predict(X_train)
accuracy = preds.eq(y_train).float().mean()
print(f"Training accuracy: {accuracy:.4f}")

# Test generalization on freshly sampled points labelled by the same rule
X_test = torch.randn(50, 4)
y_test = ((X_test[:, 0] > 0) & (X_test[:, 1] > 0)).long()
test_preds = tree.predict(X_test)
test_acc = test_preds.eq(y_test).float().mean()
print(f"Test accuracy: {test_acc:.4f}")
Tree Inference Pipeline
Tree depth governs the trade-off between underfitting and overfitting: use max_depth and min_samples_split to control complexity. Cross-validation helps find the sweet spot.
import torch
# Quick depth comparison (using DecisionTreeClassifier from above)
torch.manual_seed(42)
X = torch.randn(300, 5)
y = (X[:, 0].pow(2) + X[:, 1].pow(2)).lt(2).long()  # Circular boundary

# Split: first 200 rows for training, the remaining 100 for testing
X_train, y_train = X[:200], y[:200]
X_test, y_test = X[200:], y[200:]

for depth in (1, 3, 5, 10, 20):
    tree = DecisionTreeClassifier(max_depth=depth).fit(X_train, y_train)
    train_acc = tree.predict(X_train).eq(y_train).float().mean()
    test_acc = tree.predict(X_test).eq(y_test).float().mean()
    print(f"Depth={depth:2d}: train={train_acc:.3f}, test={test_acc:.3f}")
Limitations & Extensions
Overfitting
Single decision trees are prone to overfitting — they can perfectly memorize training data by creating one leaf per sample. Techniques to combat this:
- Pruning: Remove branches that don’t improve validation performance
- Max depth: Limit tree growth
- Min samples: Require a minimum number of samples before a node may be split (or, in some implementations, a minimum number per leaf)
Random Forests Preview
The variance problem of individual trees is solved by ensemble methods — Random Forests average many decorrelated trees (via bagging + feature subsampling), while Gradient Boosting sequentially fits trees to residuals.