
Count Complete Tree Nodes LeetCode 222

Exploit the structure of a complete binary tree by comparing left and right subtree heights to detect perfect subtrees — enabling O(log^2 n) node counting without visiting every node.

Problem Overview

LeetCode 222 asks you to count the total number of nodes in a complete binary tree given its root. The brute-force approach visits every node in O(n), but the problem explicitly asks for an algorithm that runs in less than O(n) time, and a complete binary tree has a highly predictable structure that we can exploit.

A complete binary tree is one where all levels are fully filled except possibly the last level, which is filled from left to right. This definition is the key to the entire solution: the structure guarantees that for any node, at least one of its two subtrees is a perfect binary tree.

The challenge is to count nodes faster than O(n). The expected complexity is O(log^2 n), which means we should be skipping large subtrees entirely rather than visiting each node individually.

  • Given root of a complete binary tree — count total nodes
  • Must be faster than O(n) naive traversal
  • Complete tree: all levels full except last, which fills left to right
  • Last level nodes always fill from left to right — no gaps in the middle
  • Key property: at every node, at least one of the two subtrees is a perfect binary tree
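The key property in the last bullet can be checked by brute force on small trees. The sketch below is illustrative only: `TreeNode`, `build_complete_tree`, `is_perfect`, and the other helpers are hypothetical names chosen here, and each tree is built from level-order positions (children of index i at 2i+1 and 2i+2), which always yields a complete tree.

```python
from typing import Iterator, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_complete_tree(n: int) -> Optional[TreeNode]:
    """Hypothetical helper: complete tree with values 1..n in level order."""
    if n == 0:
        return None
    nodes = [TreeNode(i) for i in range(1, n + 1)]
    for i in range(n):
        # In level order, the children of position i sit at 2i+1 and 2i+2
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def size(node: Optional[TreeNode]) -> int:
    return 0 if node is None else 1 + size(node.left) + size(node.right)


def depth(node: Optional[TreeNode]) -> int:
    return 0 if node is None else 1 + max(depth(node.left), depth(node.right))


def is_perfect(node: Optional[TreeNode]) -> bool:
    # Brute force: perfect means exactly 2^depth - 1 nodes (an empty tree qualifies)
    return size(node) == (1 << depth(node)) - 1


def all_nodes(root: Optional[TreeNode]) -> Iterator[TreeNode]:
    # Iterative DFS over every node in the tree
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()
        yield node
        for child in (node.left, node.right):
            if child is not None:
                stack.append(child)


def every_node_has_perfect_child(max_n: int) -> bool:
    """Check the property for every complete tree with 1..max_n nodes."""
    for n in range(1, max_n + 1):
        for node in all_nodes(build_complete_tree(n)):
            if not (is_perfect(node.left) or is_perfect(node.right)):
                return False
    return True


print(every_node_has_perfect_child(128))  # True
```

This check deliberately uses slow, obviously-correct definitions of "perfect" (full node count versus depth) so that it does not rely on the height-comparison shortcut it is meant to justify.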

O(n) vs O(log^2 n)

The naive solution traverses every node and increments a counter — this is O(n) and ignores everything special about the complete binary tree structure. For a tree with one million nodes, this means one million recursive calls.
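For comparison, the naive count is a plain recursive traversal. A minimal sketch, assuming a bare `TreeNode` class with `val`, `left`, and `right` fields (the shape LeetCode uses); `count_nodes_naive` is a name chosen here for illustration.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def count_nodes_naive(root: Optional[TreeNode]) -> int:
    # Visits every node once (plus one call per null child): O(n) time
    if root is None:
        return 0
    return 1 + count_nodes_naive(root.left) + count_nodes_naive(root.right)


# Same 6-node complete tree as the walk-through example
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
print(count_nodes_naive(root))  # 6
```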

The O(log^2 n) approach exploits the fact that in a complete binary tree, at least one child subtree of every node is a perfect binary tree. A perfect binary tree with height h (counting levels, so a single node has height 1) has exactly 2^h - 1 nodes. If we can detect that a subtree is perfect in O(log n) time, we can count its nodes in O(1) and skip traversing it entirely.
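The 2^h - 1 formula follows from summing the level sizes of a perfect tree, where level i (counting from 0 at the root) holds 2^i nodes:

```latex
N(h) = \sum_{i=0}^{h-1} 2^{i} = 2^{h} - 1,
\qquad \text{e.g. } N(3) = 1 + 2 + 4 = 7 = 2^{3} - 1
```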

The detection mechanism is elegant: compute the left height (follow left children until null) and the right height (follow right children until null). If they are equal, the subtree is perfect. If they differ by one, it is not perfect and we recurse into both children.
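The detection step can be sketched as two loops that walk the outer edges of the tree. This assumes a hypothetical `TreeNode` class and `build_complete_tree` helper as stand-ins for real input, and counts height as the number of nodes on the path (so an empty tree has height 0). The equality test is only a valid perfection check because the tree is complete; on an arbitrary binary tree, equal edge heights prove nothing.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_complete_tree(n: int) -> Optional[TreeNode]:
    """Hypothetical helper: complete tree with values 1..n in level order."""
    if n == 0:
        return None
    nodes = [TreeNode(i) for i in range(1, n + 1)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def left_height(node: Optional[TreeNode]) -> int:
    # Follow left children until null, counting nodes on the path
    h = 0
    while node is not None:
        h += 1
        node = node.left
    return h


def right_height(node: Optional[TreeNode]) -> int:
    # Follow right children until null, counting nodes on the path
    h = 0
    while node is not None:
        h += 1
        node = node.right
    return h


def is_perfect_complete(node: Optional[TreeNode]) -> bool:
    # Valid only when node roots a complete tree
    return left_height(node) == right_height(node)


for n in (6, 7):
    root = build_complete_tree(n)
    print(n, left_height(root), right_height(root), is_perfect_complete(root))
# 6 3 2 False
# 7 3 3 True
```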

💡

The key insight: equal heights mean a perfect subtree

In a complete binary tree, if the leftmost height equals the rightmost height, the subtree rooted at that node is a PERFECT binary tree with exactly 2^h - 1 nodes. You do not need to traverse it — return 2^h - 1 directly. This single check is what transforms the O(n) traversal into O(log^2 n).

Height Comparison

The algorithm computes two heights at each node: the left height by following left children to the bottom, and the right height by following right children to the bottom. Both computations take O(log n) time because the tree has O(log n) levels.

If left height equals right height, the tree rooted at this node is a perfect binary tree. Return (1 << h) - 1, which is 2^h - 1 using a bit shift. This is an O(1) count for potentially millions of nodes.

If the heights differ (in a complete tree the left height is always >= the right height, and they differ by at most one), the subtree is not perfect. Return 1 (for the current node) plus countNodes(left) plus countNodes(right). At least one of these recursive calls lands on a perfect subtree and returns right after its own height check, without recursing any further.

  1. Compute left height: start at root, follow left children until null, counting nodes along the way
  2. Compute right height: start at root, follow right children until null, counting nodes along the way
  3. If left height == right height: return (1 << leftHeight) - 1, a perfect tree counted in O(1)
  4. Else: return 1 + countNodes(node.left) + countNodes(node.right)
  5. Recurse: at each level, at least one subtree resolves immediately as a perfect tree
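Put together, the five steps give a short recursive function. A minimal sketch in Python, reusing a hypothetical `build_complete_tree` helper only to generate test inputs; `count_nodes` and `edge_height` are illustrative names, not a fixed API.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_complete_tree(n: int) -> Optional[TreeNode]:
    """Hypothetical helper: complete tree with values 1..n in level order."""
    if n == 0:
        return None
    nodes = [TreeNode(i) for i in range(1, n + 1)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def edge_height(node: Optional[TreeNode], go_left: bool) -> int:
    # Walk one outer edge (all-left or all-right) and count the nodes on it
    h = 0
    while node is not None:
        h += 1
        node = node.left if go_left else node.right
    return h


def count_nodes(root: Optional[TreeNode]) -> int:
    lh = edge_height(root, True)   # step 1
    rh = edge_height(root, False)  # step 2
    if lh == rh:
        # Step 3: perfect subtree (also covers root is None, where 0 == 0)
        return (1 << lh) - 1
    # Step 4: count this node and recurse; at least one child is perfect
    return 1 + count_nodes(root.left) + count_nodes(root.right)


# Compare against the true size for every complete tree up to 300 nodes
print(all(count_nodes(build_complete_tree(n)) == n for n in range(301)))  # True
```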

Why O(log^2 n)

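At each recursive call, computing the left and right heights costs O(log n): we walk two paths from the current node to the bottom, one down the left edge and one down the right edge, and each is at most as long as the tree is tall. The question is how many recursive calls are made before we hit a perfect-subtree base case.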

At every level of recursion, at least one of the two subtrees is a perfect binary tree (a consequence of the tree being complete). That subtree is resolved by its own height check plus the O(1) bit-shift formula, with no further recursion. Only the other subtree recurses further, and each step down moves one level closer to the bottom of the tree, so we recurse at most O(log n) levels deep.

Total cost: O(log n) recursive levels times O(log n) height computation per level equals O(log^2 n). For n = 1,000,000 nodes, log^2 n is approximately 400 — a massive improvement over 1,000,000 for O(n).
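To see the bound in practice, the sketch below instruments the algorithm to count pointer steps in the height walks, recursive calls, and recursion depth. It is an illustrative measurement that assumes a hypothetical `build_complete_tree` helper, and it runs at LeetCode's maximum of 5×10^4 nodes rather than one million to keep it quick in Python; `count_with_stats` is likewise a name chosen here.

```python
import math
from typing import Dict, Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_complete_tree(n: int) -> Optional[TreeNode]:
    """Hypothetical helper: complete tree with values 1..n in level order."""
    if n == 0:
        return None
    nodes = [TreeNode(i) for i in range(1, n + 1)]
    for i in range(n):
        if 2 * i + 1 < n:
            nodes[i].left = nodes[2 * i + 1]
        if 2 * i + 2 < n:
            nodes[i].right = nodes[2 * i + 2]
    return nodes[0]


def count_with_stats(root: Optional[TreeNode]) -> Tuple[int, Dict[str, int]]:
    stats = {"steps": 0, "calls": 0, "max_depth": 0}

    def edge_height(node: Optional[TreeNode], go_left: bool) -> int:
        h = 0
        while node is not None:
            stats["steps"] += 1  # one pointer step down an outer edge
            h += 1
            node = node.left if go_left else node.right
        return h

    def count(node: Optional[TreeNode], depth: int) -> int:
        stats["calls"] += 1
        stats["max_depth"] = max(stats["max_depth"], depth)
        lh = edge_height(node, True)
        rh = edge_height(node, False)
        if lh == rh:
            return (1 << lh) - 1
        return 1 + count(node.left, depth + 1) + count(node.right, depth + 1)

    return count(root, 1), stats


n = 50_000
total, stats = count_with_stats(build_complete_tree(n))
print(total, stats, "log2(n)^2 ~", round(math.log2(n) ** 2))
```

The step count should come out to a few hundred, within a small constant factor of (log2 n)^2, while the naive traversal touches all 50,000 nodes.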

ℹ️

Recursion depth is O(log n) because at least one subtree is always perfect

At each recursive level, at least one subtree has equal left and right heights (perfect, counted in O(1) after its height check) and at most one subtree recurses. Each recursive call moves one level down the tree, and a complete tree with n nodes has O(log n) levels, so the recursion depth is O(log n). Each level does O(log n) work for height computation, giving O(log n) × O(log n) = O(log^2 n) total.

Walk-Through Example

Consider a complete binary tree with 6 nodes arranged as: root (1), left child (2), right child (3), left-left (4), left-right (5), right-left (6). Levels 1 and 2 are full; level 3 holds nodes 4, 5, and 6, filled from the left, so only the right child of node 3 is missing.

At the root: left height = 3 (follow 1→2→4), right height = 2 (follow 1→3; node 3 has no right child). Heights differ, so we recurse into both subtrees.

Left subtree rooted at 2: left height = 2 (follow 2→4), right height = 2 (follow 2→5). Equal, so it is a perfect tree with h=2; return (1 << 2) - 1 = 3. Right subtree rooted at 3: left height = 2 (follow 3→6), right height = 1 (node 3 alone, since its right child is null). They differ, so return 1 + countNodes(6) + countNodes(null) = 1 + 1 + 0 = 2. Total = 1 + 3 + 2 = 6. Correct.

  1. Root (1): leftH=3, rightH=2, not equal, recurse both
  2. Node 2 (left): leftH=2, rightH=2, equal, perfect, return (1<<2)-1 = 3
  3. Node 3 (right): leftH=2, rightH=1, not equal, recurse both
  4. Node 6 (left of 3): leftH=1, rightH=1, equal, return (1<<1)-1 = 1; the null right child of 3 returns (1<<0)-1 = 0
  5. Total: 1 (root) + 3 (left subtree) + 2 (right subtree) = 6
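The walk-through can be reproduced by logging each call's two edge heights. A small sketch, assuming the same `TreeNode` shape as LeetCode; `count_traced`, `edge_height`, and the `trace` list are hypothetical additions for illustration only.

```python
from typing import List, Optional, Tuple


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def edge_height(node: Optional[TreeNode], go_left: bool) -> int:
    # Count nodes along the all-left or all-right edge
    h = 0
    while node is not None:
        h += 1
        node = node.left if go_left else node.right
    return h


def count_traced(node: Optional[TreeNode],
                 trace: List[Tuple[Optional[int], int, int]]) -> int:
    lh = edge_height(node, True)
    rh = edge_height(node, False)
    # Record (node value or None, leftH, rightH) for every call
    trace.append((node.val if node is not None else None, lh, rh))
    if lh == rh:
        return (1 << lh) - 1
    return 1 + count_traced(node.left, trace) + count_traced(node.right, trace)


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
trace: List[Tuple[Optional[int], int, int]] = []
total = count_traced(root, trace)
for val, lh, rh in trace:
    print(f"node {val}: leftH={lh}, rightH={rh}")
print("total:", total)
# node 1: leftH=3, rightH=2
# node 2: leftH=2, rightH=2
# node 3: leftH=2, rightH=1
# node 6: leftH=1, rightH=1
# node None: leftH=0, rightH=0
# total: 6
```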

Code Walkthrough: Python and Java

Python implementation: define a helper getHeight(node, goLeft) that follows either left or right children until null, returning the number of nodes on that path. In countNodes, compute leftH = getHeight(root, True) and rightH = getHeight(root, False). If equal, return (1 << leftH) - 1. Else return 1 + countNodes(root.left) + countNodes(root.right). Time O(log^2 n), space O(log n) for the recursion stack.
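That description translates into roughly the following LeetCode-style submission. This is a sketch, not an official solution: the `TreeNode` class is defined here only so the snippet runs on its own (LeetCode supplies its own), and `getHeight` keeps the helper name used in the paragraph.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def getHeight(self, node: Optional[TreeNode], goLeft: bool) -> int:
        # Follow one outer edge down to null, counting nodes on the path
        height = 0
        while node is not None:
            height += 1
            node = node.left if goLeft else node.right
        return height

    def countNodes(self, root: Optional[TreeNode]) -> int:
        leftH = self.getHeight(root, True)
        rightH = self.getHeight(root, False)
        if leftH == rightH:
            # Perfect subtree; a null root also lands here (0 == 0 -> 0 nodes)
            return (1 << leftH) - 1
        return 1 + self.countNodes(root.left) + self.countNodes(root.right)


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3, TreeNode(6)))
print(Solution().countNodes(root))  # 6
```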

The Java implementation mirrors the Python approach. Use a private int height(TreeNode node, boolean left) helper. The bit shift (1 << leftH) - 1 is written identically in Java and stays within a 32-bit int for heights up to 30; LeetCode's constraint of at most 5×10^4 nodes keeps the height at 16 or less, so there is plenty of headroom. Always use the integer bit shift and avoid Math.pow, which returns a double.

Both solutions handle the base case naturally: for a null root the height helper returns 0 on both sides, so leftH == rightH and (1 << 0) - 1 = 0 is returned without an explicit null check. The recursion terminates because each recursive call moves to a child, so the height shrinks by one per call until the call lands on a perfect subtree (an empty subtree counts as perfect).

⚠️

Use (1 << h) - 1 not pow(2, h) - 1 for node count

Always use the bit shift (1 << h) - 1 to compute the 2^h - 1 node count for a perfect subtree. In Python, pow(2, h) - 1 and 2 ** h - 1 return exact integers and are safe; the shift is simply the idiomatic way to write a power of two. In Java and C++, pow returns a double, which has to be cast back to an integer type and brings floating-point arithmetic into what should be pure integer math. The bit shift is exact and cheap in all three languages, so make it your default for powers of 2.
