Decision Trees
Learn how machines make decisions by asking a series of yes/no questions — and why the right questions make all the difference.
Every doctor, detective, and loan officer has a secret weapon: a mental checklist of questions they work through to reach a decision. "Is the patient over 60? Does she have chest pain? Is her blood pressure above 140?" Each answer narrows the possibilities until a conclusion clicks into place.
A decision tree teaches a computer to do the same thing. It learns a flowchart of yes/no questions from your data, then uses that flowchart to classify new examples or predict numeric values. Decision trees power spam filters, medical diagnosis tools, and credit scoring systems. They are also the building block of powerful ensemble methods such as random forests and gradient-boosted trees.
#Nodes, Branches, and Leaves
A decision tree has three kinds of parts:
- Root node: the first question asked, chosen because it best separates the full training set
- Internal nodes: follow-up questions that split the data further
- Leaf nodes: the endpoints where a final prediction is made, with no more questions needed
Each node asks one question about one feature: "Is age > 30?" or "Is income < $50k?". Each answer sends you down a branch — left for yes, right for no — until you land on a leaf that says "Approved" or "Denied". Under the hood, a trained tree is literally just nested if/else statements:
```python
def predict(age, likes_music):
    if age >= 18:
        if likes_music:
            return "Will buy"
        else:
            return "Won't buy"
    else:
        return "Won't buy"  # under 18, no ticket

print(predict(25, True))   # adult who likes music -> Will buy
print(predict(25, False))  # adult who doesn't like music -> Won't buy
print(predict(15, True))   # teen who likes music -> Won't buy
```

Twenty Questions, but Smarter
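Think of the childhood game Twenty Questions. A great player picks questions that eliminate as many possibilities as possible with each answer. "Is it bigger than a bread box?" beats "Is it a golden retriever?" because either answer to the first question rules out a large share of the possibilities (ideally about half). The second question almost always gets a "no", and that rules out just one.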
A decision tree algorithm plays Twenty Questions on your dataset — automatically finding the question at each step that eliminates the most confusion.
#Choosing the Best Split: Gini Impurity
At every node, the tree tries every possible question on every feature and picks the one that creates the purest groups. Purity means: after the split, each group should contain mostly one class.
The most common measure is Gini Impurity. In plain English: "If I randomly pick two examples from this bucket (putting the first back before drawing the second), how often would they be different classes?"
- All examples the same class → Gini = 0 (perfectly pure)
- Half one class, half another → Gini = 0.5 (the maximum possible with two classes)
The algorithm picks whichever split produces the lowest weighted-average Gini across both child groups. Each group's Gini is weighted by its share of the examples.
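The same ideas can be written as formulas. The notation is standard but introduced here, not defined earlier in this lesson: $p_k$ is the fraction of examples in a group $S$ that belong to class $k$, and $n$, $n_L$, $n_R$ count the examples in the parent group and in the left and right child groups. The sum $\sum_k p_k^2$ is the chance that two draws with replacement land in the same class, so one minus it is the chance they differ.

$$
\begin{aligned}
G(S) &= 1 - \sum_{k} p_k^{2} \\
G_{\text{split}} &= \frac{n_L}{n}\,G(S_L) + \frac{n_R}{n}\,G(S_R)
\end{aligned}
$$

Here is a worked example with made-up counts. A group of 4 spam and 4 ham has $G = 1 - (0.5^2 + 0.5^2) = 0.5$. A question that splits it into (3 spam, 1 ham) and (1 spam, 3 ham) gives each child $G = 1 - (0.75^2 + 0.25^2) = 0.375$, so $G_{\text{split}} = \tfrac{4}{8}(0.375) + \tfrac{4}{8}(0.375) = 0.375$. That is lower than 0.5, so this question makes the groups purer.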
```python
def gini(labels):
    total = len(labels)
    if total == 0:
        return 0.0  # an empty group counts as pure
    counts = {}
    for label in labels:
        counts[label] = counts.get(label, 0) + 1
    return 1 - sum((c / total) ** 2 for c in counts.values())

print(gini(["spam", "spam", "spam", "spam"]))  # 0.0 (pure)
print(gini(["spam", "ham", "spam", "ham"]))    # 0.5 (impure)
print(gini(["spam", "ham", "ham", "ham"]))     # 0.375 (mostly pure)
```

#How the Algorithm Builds the Tree
The building process is elegantly recursive — it applies the same logic over and over at each node:
- Start with all training examples at the root.
- Try every feature and every possible threshold ("age > 18?", "age > 25?", "income > 40k?", ...).
- Compute the weighted Gini of the two groups each split would create.
- Pick the split with the lowest weighted Gini — this is the best question.
- Recurse: repeat the whole process on each child group independently.
- Stop when a group is pure, too small, or you've hit a maximum depth.
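The steps above can be sketched as a short recursive Python program. This is a minimal illustration under simplifying assumptions, not a production algorithm. Each candidate question has the form "feature >= t?" where t is a value actually seen in the data; many libraries place thresholds halfway between neighboring values instead. The helper names `best_split`, `build_tree`, `classify`, and the `max_depth`/`min_size` parameters are hypothetical, chosen for this sketch. The toy data is also made up.

```python
from collections import Counter

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def best_split(rows, labels):
    """Try every feature and every observed value t as the question 'feature >= t?'."""
    n = len(labels)
    best = None  # (weighted_gini, feature_index, threshold)
    for f in range(len(rows[0])):
        for t in sorted({row[f] for row in rows}):
            yes = [y for row, y in zip(rows, labels) if row[f] >= t]
            no = [y for row, y in zip(rows, labels) if row[f] < t]
            if not yes or not no:
                continue  # a useful question must send examples both ways
            score = len(yes) / n * gini(yes) + len(no) / n * gini(no)
            if best is None or score < best[0]:
                best = (score, f, t)
    return best

def build_tree(rows, labels, depth=0, max_depth=3, min_size=2):
    """Recursively grow the tree; a leaf is just the majority label."""
    majority = Counter(labels).most_common(1)[0][0]
    # Stop when the group is pure, too small, or the depth limit is reached
    if gini(labels) == 0 or len(labels) < min_size or depth >= max_depth:
        return majority
    split = best_split(rows, labels)
    if split is None:  # all rows identical: no question can separate them
        return majority
    _, f, t = split
    yes = [i for i, row in enumerate(rows) if row[f] >= t]
    no = [i for i, row in enumerate(rows) if row[f] < t]
    return {
        "feature": f,
        "threshold": t,
        "yes": build_tree([rows[i] for i in yes], [labels[i] for i in yes],
                          depth + 1, max_depth, min_size),
        "no": build_tree([rows[i] for i in no], [labels[i] for i in no],
                         depth + 1, max_depth, min_size),
    }

def classify(tree, row):
    """Follow the questions down to a leaf."""
    while isinstance(tree, dict):
        tree = tree["yes"] if row[tree["feature"]] >= tree["threshold"] else tree["no"]
    return tree

# Made-up toy data: [age, likes_music (1/0)] -> concert ticket decision
rows = [[25, 1], [30, 1], [45, 1], [22, 0], [35, 0], [15, 1], [16, 0], [12, 1]]
labels = ["Will buy", "Will buy", "Will buy", "Won't buy", "Won't buy",
          "Won't buy", "Won't buy", "Won't buy"]

tree = build_tree(rows, labels)
print(tree)
print(classify(tree, [40, 1]))  # Will buy
print(classify(tree, [14, 1]))  # Won't buy
```

On these eight rows the learned tree first asks "age >= 25?" and then "likes_music >= 1?". That is the same shape as the hand-written `predict()` earlier. The age cut lands at 25 rather than 18 only because 25 is the best observed value in this tiny dataset.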
Left unchecked, this process keeps splitting until nearly every leaf is pure, producing a tree that fits the training data very closely. That brings us to an important danger.
Deep Trees Memorize Instead of Generalizing
If you let a tree grow without limits, it will memorize the training data perfectly — including noise and flukes. This is overfitting.
Imagine memorizing every specific question on a practice exam rather than understanding the subject. You'd ace the practice test but bomb the real one.
The fix: limit the tree's depth (e.g., max_depth=5) or prune branches that don't improve accuracy on a held-out validation set. Shallow trees also have a bonus: you can actually print them out and explain every decision to a human.
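Here is a minimal, self-contained sketch of the memorization problem, built on made-up assumptions. There is one numeric feature x from 0 to 19, the true rule is "label is 1 when x >= 10", and two training labels are deliberately flipped to play the role of noise. `fit`, `classify`, `accuracy`, and this sketch's own `max_depth` argument are hypothetical helpers, not a library API. Each tree is scored twice: against the noisy training labels, and against the clean labels, which stand in for the real pattern that new data would follow.

```python
from collections import Counter

def gini(labels):
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())

def fit(xs, ys, depth=0, max_depth=None):
    """Grow a tree on one numeric feature; max_depth=None means no limit."""
    majority = Counter(ys).most_common(1)[0][0]
    if gini(ys) == 0 or (max_depth is not None and depth >= max_depth):
        return majority
    best = None  # (weighted_gini, threshold)
    for t in sorted(set(xs))[1:]:  # question "x >= t?"; skipping the min keeps both sides non-empty
        left = [y for x, y in zip(xs, ys) if x < t]
        right = [y for x, y in zip(xs, ys) if x >= t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(ys)
        if best is None or score < best[0]:
            best = (score, t)
    if best is None:  # every x is identical, nothing to split on
        return majority
    t = best[1]
    lo = [(x, y) for x, y in zip(xs, ys) if x < t]
    hi = [(x, y) for x, y in zip(xs, ys) if x >= t]
    return (t,
            fit([x for x, _ in lo], [y for _, y in lo], depth + 1, max_depth),
            fit([x for x, _ in hi], [y for _, y in hi], depth + 1, max_depth))

def classify(tree, x):
    while isinstance(tree, tuple):  # internal node: (threshold, below, at_or_above)
        t, below, above = tree
        tree = above if x >= t else below
    return tree

def accuracy(tree, xs, ys):
    return sum(classify(tree, x) == y for x, y in zip(xs, ys)) / len(xs)

xs = list(range(20))
clean = [1 if x >= 10 else 0 for x in xs]  # the real pattern
noisy = clean[:]
noisy[3], noisy[15] = 1, 0                 # two mislabeled training examples

deep = fit(xs, noisy)                      # no depth limit
shallow = fit(xs, noisy, max_depth=1)      # a single question

print("deep:    train", accuracy(deep, xs, noisy), "real", accuracy(deep, xs, clean))
# deep:    train 1.0 real 0.9
print("shallow: train", accuracy(shallow, xs, noisy), "real", accuracy(shallow, xs, clean))
# shallow: train 0.9 real 1.0
```

The unlimited tree scores perfectly on the data it memorized but reproduces both flukes. The one-question tree accepts two training mistakes and recovers the real rule exactly.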
#Teaser: Random Forests
One tree can be fragile — a small change in the data might produce a very different tree. The solution? Grow many trees and let them vote.
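A Random Forest trains many decision trees (often a hundred or more). Each tree is trained on a bootstrap sample of the data, which is a random sample drawn with replacement, and considers only a random subset of the features at each split. When predicting, every tree casts a vote and the majority wins. The combined model is usually more accurate than any single tree and much less sensitive to small changes in the training data.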
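The three random-forest ingredients described above are bootstrap sampling, random feature subsets, and voting. Each can be sketched in a few lines of plain Python. `bootstrap_sample`, `random_features`, `majority_vote`, `forest_predict`, and the three lambda "trees" are hypothetical names for this sketch, not scikit-learn's API. The lambdas stand in for trees that would really be learned, each from its own bootstrap sample. Note also that scikit-learn's documentation describes its `RandomForestClassifier` as averaging the trees' predicted probabilities rather than counting hard votes; the majority vote here follows the simpler description above.

```python
import random
from collections import Counter

def bootstrap_sample(rows, rng):
    """Draw len(rows) examples with replacement: some repeat, some are left out."""
    return [rng.choice(rows) for _ in rows]

def random_features(n_features, k, rng):
    """Pick k distinct feature indices for a tree to consider at one split."""
    return rng.sample(range(n_features), k)

def majority_vote(predictions):
    """Each tree casts one vote; the most common label wins."""
    return Counter(predictions).most_common(1)[0][0]

# Hypothetical stand-ins for trained trees: three slightly different
# concert-ticket rules, as if each had been learned from its own sample.
trees = [
    lambda age, likes_music: "Will buy" if age >= 18 and likes_music else "Won't buy",
    lambda age, likes_music: "Will buy" if age >= 21 else "Won't buy",
    lambda age, likes_music: "Will buy" if likes_music else "Won't buy",
]

def forest_predict(age, likes_music):
    return majority_vote([tree(age, likes_music) for tree in trees])

rng = random.Random(0)  # seeded so the random draws are repeatable
print(bootstrap_sample(list(range(10)), rng))
print(random_features(5, 2, rng))
print(forest_predict(19, True))   # votes: Will, Won't, Will -> Will buy
print(forest_predict(30, False))  # votes: Won't, Will, Won't -> Won't buy
```

No single rule here is perfect. The vote still smooths over each tree's individual quirks, which is the intuition behind the forest's robustness.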
Real-world tools like scikit-learn make this one line: RandomForestClassifier(n_estimators=100). Under the hood, though, it is still just decision trees, exactly what you've learned here.
Decision trees also shine when you need a model you can explain to non-technical stakeholders, when your data has mixed feature types (numbers, categories, booleans), or when you just want a quick baseline before trying something fancier.
A decision tree node splits 20 examples into Group A (8 spam, 2 ham) and Group B (1 spam, 9 ham). Compared to the original mixed group of 9 spam and 11 ham, is this split purer or less pure?
Key takeaways
- A decision tree is a flowchart of yes/no questions — nodes ask questions, branches carry answers, leaves make predictions.
- The best split at each node is the question that creates the purest groups, measured by Gini impurity or information gain.
- Deep trees overfit by memorizing training data; limit depth or prune to make the tree generalize to new examples.
- A Random Forest trains many trees on random subsets and combines their votes, which usually improves accuracy and robustness.
- Under the hood, a decision tree is just nested if/else logic — straightforward, explainable, and surprisingly powerful.
This hand-written decision tree classifies concert ticket buyers. What does it print?
```python
def predict(age, likes_music):
    if age >= 18:
        if likes_music:
            return "Will buy"
        else:
            return "Won't buy"
    else:
        return "Won't buy"

print(predict(30, False))
print(predict(16, True))
```

This code should compute the weighted-average Gini of a split (used to compare candidate questions). It has a bug. What's wrong?
```python
def gini_split_score(left, right):
    total = len(left) + len(right)
    weight_left = len(left) / total
    weight_right = len(right) / total
    # combine the two child groups' impurity
    return weight_left * gini(left) - weight_right * gini(right)
```

Complete the weighted-Gini split score. The size of each group divided by the total gives that group's weight. Fill in the denominator.

```python
def gini_split_score(left, right):
    total = len(left) + len(right)
    weight_left = len(left) / ____
    weight_right = len(right) / total
    return weight_left * gini(left) + weight_right * gini(right)
```
Put these steps in the order the algorithm follows to choose the best question at a single node.
- Pick the candidate split with the lowest weighted Gini
- Split the examples into two groups for each candidate
- Try every feature and every possible threshold as a candidate split
- Recurse on each child group to build the next level
- Compute the weighted Gini of the two groups for each candidate
Implement gini_split_score(left, right) that takes two lists of class labels (the two groups after a split) and returns the weighted Gini impurity of the split — the average Gini of the two groups, weighted by their sizes.
For example, if left has 3 examples and right has 3 examples: weighted_gini = (3/6) * gini(left) + (3/6) * gini(right)
Then compare two candidate splits and print which one is better.