Back to PyTorch Mastery Series

Decision Trees from Scratch

May 29, 2026 Wasil Zafar 25 min read

Build decision trees from first principles using PyTorch tensors — entropy calculation, information gain optimization, recursive partitioning, and a complete tree inference pipeline.

Table of Contents

  1. Decision Tree Intuition
  2. Entropy & Impurity
  3. Information Gain
  4. Recursive Tree Construction
  5. Tree Inference Pipeline
  6. Limitations & Extensions
  7. Related Articles

Decision Tree Intuition

A decision tree learns a hierarchy of if-then rules that recursively partition the feature space. At each internal node, it asks a question about one feature (“Is $x_i \leq t$?”) and routes the data left or right. Leaf nodes contain the final class predictions.

Key Insight: Decision trees are greedy, recursive partitioners. They don’t search for the globally optimal tree (finding it is NP-hard) — instead they make the locally best split at each step. Despite this, they’re remarkably effective and form the basis for Random Forests and Gradient Boosting.
Decision Tree Structure
flowchart TD
    A["x₁ ≤ 3.5?"] -->|Yes| B["x₂ ≤ 1.2?"]
    A -->|No| C["Class B"]
    B -->|Yes| D["Class A"]
    B -->|No| E["Class B"]
                            

Entropy & Impurity

Entropy measures the disorder (uncertainty) in a set of labels. A pure node (all samples in the same class) has entropy 0; entropy is maximal — $\log_2 K$ bits for $K$ classes — when the classes are equally distributed:

$$H(S) = -\sum_{k=1}^{K} p_k \log_2 p_k$$

import torch


def entropy(labels):
    """Compute the Shannon entropy (in bits) of a 1-D label tensor.

    Args:
        labels: 1-D tensor of class labels.

    Returns:
        0-dim float tensor ``H = sum_k p_k * log2(1 / p_k)``; ``0.0`` for an
        empty tensor. A pure tensor yields exactly ``+0.0``.
    """
    if len(labels) == 0:
        return torch.tensor(0.0)
    _, counts = torch.unique(labels, return_counts=True)
    probs = counts.float() / len(labels)
    # torch.unique only reports classes that actually occur, so every p_k > 0
    # and log2 is always finite -- no zero-probability filtering is needed.
    # Writing the sum as p * log2(1/p) instead of -(p * log2(p)) keeps a pure
    # node at +0.0; negating 0.0 would give -0.0, which prints as "-0.0000".
    return (probs * torch.log2(1.0 / probs)).sum()


def gini_impurity(labels):
    """Return the Gini impurity ``1 - sum_k p_k**2`` of a label tensor.

    An empty tensor is treated as perfectly pure and scores ``tensor(0.0)``.
    """
    n = len(labels)
    if not n:
        return torch.tensor(0.0)
    counts = torch.unique(labels, return_counts=True)[1]
    p = counts.float() / n
    return 1.0 - p.pow(2).sum()


# Examples: a pure node, a 2:3 two-class mix, and four equally likely classes
pure = torch.tensor([0, 0, 0, 0, 0])
mixed = torch.tensor([0, 0, 1, 1, 1])
uniform = torch.tensor([0, 1, 2, 3])

for title, sample in [("Pure set:", pure), ("Mixed set:", mixed), ("Uniform set:", uniform)]:
    print(f"{title:<12} entropy={entropy(sample):.4f}, gini={gini_impurity(sample):.4f}")

Information Gain

Information gain measures the reduction in entropy achieved by a split $A$ of the set $S$, where $S_v$ is the subset of $S$ routed to branch $v$. The best split maximizes this reduction:

$$IG(S, A) = H(S) - \sum_{v} \frac{|S_v|}{|S|} H(S_v)$$

import torch


def entropy(labels):
    """Compute the Shannon entropy (in bits) of a 1-D label tensor.

    Args:
        labels: 1-D tensor of class labels.

    Returns:
        0-dim float tensor ``H = sum_k p_k * log2(1 / p_k)``; ``0.0`` for an
        empty tensor. A pure tensor yields exactly ``+0.0``.
    """
    if len(labels) == 0:
        return torch.tensor(0.0)
    _, counts = torch.unique(labels, return_counts=True)
    probs = counts.float() / len(labels)
    # torch.unique only reports classes that actually occur, so every p_k > 0
    # and log2 is always finite -- no zero-probability filtering is needed.
    # Writing the sum as p * log2(1/p) instead of -(p * log2(p)) keeps a pure
    # node at +0.0; negating 0.0 would give -0.0, which prints as "-0.0000".
    return (probs * torch.log2(1.0 / probs)).sum()


def information_gain(X_column, y, threshold):
    """Score a binary split by how much it lowers the entropy of ``y``.

    Samples with ``X_column <= threshold`` go left and the rest go right. The
    score is ``entropy(y)`` minus the size-weighted entropies of the two
    sides. A split that leaves either side empty scores ``tensor(0.0)``.
    """
    goes_left = X_column <= threshold
    y_left, y_right = y[goes_left], y[~goes_left]
    n_left, n_right = len(y_left), len(y_right)

    # Degenerate split: nothing was actually separated.
    if not (n_left and n_right):
        return torch.tensor(0.0)

    n = len(y)
    left_term = (n_left / n) * entropy(y_left)
    right_term = (n_right / n) * entropy(y_right)
    return entropy(y) - (left_term + right_term)


def find_best_split(X, y):
    """Find the best feature and threshold to split on.

    Every distinct value of every column is tried as a threshold ``t`` for
    the rule ``X[:, j] <= t``. Returns ``(feature_index, threshold, gain)``,
    where ``gain`` is a 0-dim tensor. On ties the first candidate found is
    kept: the lowest feature index, then the smallest threshold, because
    ``torch.unique`` returns sorted values.
    """
    best_feature, best_threshold, best_gain = 0, 0.0, torch.tensor(-1.0)
    for j, column in enumerate(X.unbind(dim=1)):
        for candidate in torch.unique(column):
            score = information_gain(column, y, candidate)
            if score > best_gain:
                best_feature, best_threshold, best_gain = j, candidate.item(), score
    return best_feature, best_threshold, best_gain


# Demo: 100 random 3-D points labelled by the linear boundary x0 + x1 > 0
torch.manual_seed(42)
X = torch.randn(100, 3)
y = (X[:, 0] + X[:, 1]).gt(0).long()  # Simple linear boundary

feat, thresh, gain = find_best_split(X, y)
print(f"Best split: feature {feat}, threshold {thresh:.3f}, gain {gain:.4f}")

Recursive Tree Construction

import torch


def entropy(labels):
    """Compute the Shannon entropy (in bits) of a 1-D label tensor.

    Args:
        labels: 1-D tensor of class labels.

    Returns:
        0-dim float tensor ``H = sum_k p_k * log2(1 / p_k)``; ``0.0`` for an
        empty tensor. A pure tensor yields exactly ``+0.0``.
    """
    if len(labels) == 0:
        return torch.tensor(0.0)
    _, counts = torch.unique(labels, return_counts=True)
    probs = counts.float() / len(labels)
    # torch.unique only reports classes that actually occur, so every p_k > 0
    # and log2 is always finite -- no zero-probability filtering is needed.
    # Writing the sum as p * log2(1/p) instead of -(p * log2(p)) keeps a pure
    # node at +0.0; negating 0.0 would give -0.0, which prints as "-0.0000".
    return (probs * torch.log2(1.0 / probs)).sum()


def information_gain(X_col, y, threshold):
    """Entropy reduction obtained by sending ``X_col <= threshold`` left.

    Returns ``entropy(y)`` minus the size-weighted entropies of the two
    sides, or ``tensor(0.0)`` when one side would be empty.
    """
    mask = X_col <= threshold
    left, right = y[mask], y[~mask]
    if len(left) == 0 or len(right) == 0:
        return torch.tensor(0.0)

    total = len(y)
    weighted_children = (len(left) / total) * entropy(left) + \
                        (len(right) / total) * entropy(right)
    return entropy(y) - weighted_children


class DecisionNode:
    """A node in the decision tree.

    Internal nodes hold a split (``feature``, ``threshold``) and two children;
    leaf nodes hold only ``value``, the predicted class. ``value`` is ``None``
    on internal nodes, which is how a traversal tells the two kinds apart.
    """

    def __init__(self, feature=None, threshold=None, left=None, right=None, value=None):
        self.feature = feature      # column index tested at this node (internal nodes)
        self.threshold = threshold  # split point for that column (internal nodes)
        self.left = left            # subtree for samples with x[feature] <= threshold
        self.right = right          # subtree for samples with x[feature] > threshold
        self.value = value          # leaf class prediction; None on internal nodes

    def __repr__(self):
        # Compact, debuggable form: leaves show their prediction, internal
        # nodes show their split (children are omitted to keep it one line).
        if self.value is not None:
            return f"DecisionNode(value={self.value!r})"
        return f"DecisionNode(feature={self.feature!r}, threshold={self.threshold!r})"


class DecisionTreeClassifier:
    """Decision tree classifier built with PyTorch tensors.

    Splits are chosen greedily by information gain over axis-aligned rules of
    the form ``x[feature] <= threshold``.

    Args:
        max_depth: maximum depth of the tree (the root is depth 0). ``None``
            lets the tree grow until another stopping rule fires.
        min_samples_split: a node with fewer samples than this becomes a leaf.
        max_thresholds: cap on the candidate thresholds tried per feature.
            When a feature has more distinct values than this, that many are
            taken at evenly spaced positions in the sorted distinct values.
            ``None`` tries every distinct value.
    """

    def __init__(self, max_depth=10, min_samples_split=2, max_thresholds=20):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.max_thresholds = max_thresholds
        self.root = None

    def fit(self, X, y):
        """Grow the tree on ``X`` (n_samples, n_features) and labels ``y``; return ``self``.

        Raises:
            ValueError: if ``y`` is empty (there is no class to predict).
        """
        y = y.long()
        if len(y) == 0:
            raise ValueError("fit() requires at least one sample")
        self.root = self._build_tree(X.float(), y, depth=0)
        return self

    def _build_tree(self, X, y, depth):
        """Recursively grow the subtree for the samples ``(X, y)`` at ``depth``."""
        n_samples = len(y)
        n_classes = len(torch.unique(y))

        # Stopping conditions
        depth_reached = self.max_depth is not None and depth >= self.max_depth
        if depth_reached or n_classes == 1 or n_samples < self.min_samples_split:
            return self._leaf(y)

        best_feat, best_thresh, best_gain = self._best_split(X, y)
        # No candidate split reduces entropy, so splitting further is pointless.
        if best_gain <= 0:
            return self._leaf(y)

        # Split and recurse. A positive gain guarantees both sides are
        # non-empty, because information_gain scores empty sides as 0.
        left_mask = X[:, best_feat] <= best_thresh
        left_node = self._build_tree(X[left_mask], y[left_mask], depth + 1)
        right_node = self._build_tree(X[~left_mask], y[~left_mask], depth + 1)
        return DecisionNode(feature=best_feat, threshold=best_thresh,
                            left=left_node, right=right_node)

    @staticmethod
    def _leaf(y):
        """Make a leaf predicting the mode of ``y`` (via ``torch.mode``)."""
        return DecisionNode(value=torch.mode(y).values.item())

    def _candidate_thresholds(self, col):
        """Return the distinct values of ``col``, thinned to at most ``max_thresholds``."""
        thresholds = torch.unique(col)
        limit = self.max_thresholds
        if limit is not None and len(thresholds) > limit:
            # Sample thresholds for efficiency: evenly spaced positions in the
            # sorted distinct values (both extremes included).
            indices = torch.linspace(0, len(thresholds) - 1, limit).long()
            thresholds = thresholds[indices]
        return thresholds

    def _best_split(self, X, y):
        """Return ``(feature, threshold, gain)`` with the highest information gain.

        ``gain`` is a Python float. It stays at -1.0 when no candidate exists.
        On ties the earliest candidate wins.
        """
        best_gain = -1.0
        best_feat, best_thresh = 0, 0.0
        for feat_idx in range(X.shape[1]):
            col = X[:, feat_idx]
            for t in self._candidate_thresholds(col):
                gain = information_gain(col, y, t).item()
                if gain > best_gain:
                    best_gain, best_feat, best_thresh = gain, feat_idx, t.item()
        return best_feat, best_thresh, best_gain

    def _predict_one(self, x, node):
        """Walk sample ``x`` from ``node`` down to a leaf and return its class."""
        # Iterative rather than recursive, so an unbounded tree
        # (max_depth=None) cannot hit Python's recursion limit here.
        while node.value is None:
            node = node.left if x[node.feature] <= node.threshold else node.right
        return node.value

    def predict(self, X):
        """Predict a class for every row of ``X``; return a 1-D ``torch.long`` tensor.

        Raises:
            RuntimeError: if called before ``fit``.
        """
        if self.root is None:
            raise RuntimeError("DecisionTreeClassifier must be fit() before predict()")
        X = X.float()
        # Explicit dtype keeps the result integer even when X has no rows.
        return torch.tensor([self._predict_one(x, self.root) for x in X], dtype=torch.long)


# Demo: the positive class is the quadrant where x0 > 0 AND x1 > 0
torch.manual_seed(42)
X_train = torch.randn(200, 4)
y_train = ((X_train[:, 0] > 0) & (X_train[:, 1] > 0)).long()  # AND boundary

tree = DecisionTreeClassifier(max_depth=5).fit(X_train, y_train)
preds = tree.predict(X_train)
accuracy = preds.eq(y_train).float().mean()
print(f"Training accuracy: {accuracy:.4f}")

# Test generalization on fresh samples from the same distribution
X_test = torch.randn(50, 4)
y_test = ((X_test[:, 0] > 0) & (X_test[:, 1] > 0)).long()
test_preds = tree.predict(X_test)
test_acc = test_preds.eq(y_test).float().mean()
print(f"Test accuracy: {test_acc:.4f}")

Tree Inference Pipeline

Depth vs. Performance: Deeper trees fit the training data more closely, but beyond some depth they typically generalize worse. Use max_depth and min_samples_split to control complexity; cross-validation helps find the sweet spot.
import torch


# Quick depth comparison (using DecisionTreeClassifier from above)
torch.manual_seed(42)
X = torch.randn(300, 5)
y = (X[:, 0].pow(2) + X[:, 1].pow(2)).lt(2).long()  # Circular boundary: x0^2 + x1^2 < 2

# Split: first 200 rows for training, last 100 rows held out
X_train, X_test = X.split([200, 100])
y_train, y_test = y.split([200, 100])

for depth in (1, 3, 5, 10, 20):
    tree = DecisionTreeClassifier(max_depth=depth).fit(X_train, y_train)
    train_acc = tree.predict(X_train).eq(y_train).float().mean()
    test_acc = tree.predict(X_test).eq(y_test).float().mean()
    print(f"Depth={depth:2d}: train={train_acc:.3f}, test={test_acc:.3f}")

Limitations & Extensions

Overfitting

Single decision trees are prone to overfitting — they can perfectly memorize training data by creating one leaf per sample. Techniques to combat this:

  • Pruning: Remove branches that don’t improve validation performance
  • Max depth: Limit tree growth
  • Min samples: Require a minimum number of samples before a node may be split (min_samples_split above), or a minimum per leaf

Random Forests Preview

The variance problem of individual trees is mitigated by ensemble methods — Random Forests average many decorrelated trees (via bagging + feature subsampling), while Gradient Boosting sequentially fits trees to the residuals (more generally, the negative gradients of the loss) of the ensemble built so far.