Introduction
In the TREEDIVS problem from the CodeChef Starters 57 contest, you need to output a statistic for each node in the tree.
Each node v is associated with an integer value A_v.
The statistic to compute for a node is the number of divisors of the product of the values of all nodes within its subtree.
For v a node in the tree,
Statistics(v) = DivisorCount(TreeProduct(v))
TreeProduct(v) = Product_{u in Tree(v)} A_u
Counting the divisors of a number requires its prime factorization: if n = p_1^e_1 * ... * p_k^e_k, then n has (e_1 + 1) * ... * (e_k + 1) divisors.
While my logic was correct, my submission got Time Limit Exceeded on a few test cases.
In this article, we will go over multiple optimization strategies that bring the runtime within the time limit.
Refer to the successful submission from another community member, from which I got the hints for the optimizations. See also the CodeChef editorial and the USACO article on Small-to-Large merging.
The performance of the 4 versions, run against 2 different types of inputs, is compared in the Gradle Report section below.
Runtime got slightly worse on random input; on slim input, it went down drastically.
Setup
Input format
The input consists of T, the number of test cases, followed for each test case by
- N, the tree size
- A, the list of node values
- (u_i, v_i), the list of edges used to build the adjacency list
3
4
100 101 102 103
1 2
1 3
1 4
4
2 2 2 2
1 2
2 3
3 4
5
43 525 524 12 289
1 2
1 3
3 4
4 5
Expected output
For each test case, we output the list of node statistics on a single line:
192 2 8 2
5 4 3 2
1080 12 60 18 3
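For instance, in the first test case the root's subtree contains all four nodes, so its product is 100 * 101 * 102 * 103 = 2^3 * 3 * 5^2 * 17 * 101 * 103, which has (3+1)(1+1)(2+1)(1+1)(1+1)(1+1) = 192 divisors, the first value of the expected output.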
Input Parsing
package codechef.starters57.TREEDIVS;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
public class TreeAndDivisorsMain {
public static void main(String[] args) throws IOException {
TreeAndDivisorsFactory factory = TreeAndDivisors3::new;
//InputStream inputStream = System.in;
InputStream inputStream = new FileInputStream("TREEDIVS");
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
PrintWriter writer = new PrintWriter(new BufferedOutputStream(System.out));
String[] tokens;
tokens = bufferedReader.readLine().split(" ");
int T = Integer.parseInt(tokens[0]);
while (T > 0) {
tokens = bufferedReader.readLine().split(" ");
int N = Integer.parseInt(tokens[0]);
tokens = bufferedReader.readLine().split(" ");
int[] A = new int[N];
for (int i = 0; i < N; i++) {
A[i] = Integer.parseInt(tokens[i]);
}
TreeAndDivisors treeAndDivisors = factory.createTreeAndDivisors(N, A);
for (int i = 0; i < N-1; i++) {
tokens = bufferedReader.readLine().split(" ");
int u = Integer.parseInt(tokens[0]);
int v = Integer.parseInt(tokens[1]);
treeAndDivisors.addEdge(u-1, v-1);
}
int[] divisorCount = treeAndDivisors.divisors();
String divisorLine = AbstractTreeAndDivisors.listToString(divisorCount);
writer.println(divisorLine);
T--;
}
writer.close();
inputStream.close();
}
}
Load Test
We define an interface for the solution contract:
public interface TreeAndDivisors {
void addEdge(int u, int v);
int[] divisors();
}
We define a factory to decouple running a solution on input data from the implementation itself.
public interface TreeAndDivisorsFactory {
TreeAndDivisors createTreeAndDivisors(int n, int[] A);
}
We define a tree generator to compare the effect of the tree shape on runtime:
public interface TreeGenerator {
void generate(List<List<Integer>> adjacency);
}
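To tie these pieces together, a load-test driver could look like the following sketch. The tree size, the value range of A and the method name runLoadTest are illustrative assumptions; the actual harness lives in TreeAndDivisorsTest.
// Sketch of a load-test driver: generate a tree shape, draw random node
// values, then time a single solution run. Names and ranges are assumptions.
static void runLoadTest(int n, TreeGenerator generator, TreeAndDivisorsFactory factory) {
    Random random = new Random(42);
    int[] A = new int[n];
    for (int i = 0; i < n; i++) {
        A[i] = 1 + random.nextInt(1000);           // assumed value range
    }
    List<List<Integer>> adjacency = new ArrayList<>();
    for (int i = 0; i < n; i++) {
        adjacency.add(new ArrayList<>());
    }
    generator.generate(adjacency);                 // fill parent -> children lists
    TreeAndDivisors treeAndDivisors = factory.createTreeAndDivisors(n, A);
    for (int parent = 0; parent < n; parent++) {
        for (int child : adjacency.get(parent)) {
            treeAndDivisors.addEdge(parent, child);
        }
    }
    long start = System.nanoTime();
    treeAndDivisors.divisors();
    System.out.printf("N=%d: %.3fs%n", n, (System.nanoTime() - start) / 1e9);
}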
Unit Test
Here's a JUnit 5 parameterized unit test:
@ParameterizedTest
@MethodSource
public void testSolution(int[] A, int[][] edges, int[] expectedDivisors) {
TreeAndDivisorsFactory factory = TreeAndDivisors3::new;
TreeAndDivisors treeAndDivisors = factory.createTreeAndDivisors(A.length, A);
Arrays.stream(edges)
.forEach(edge -> treeAndDivisors.addEdge(edge[0], edge[1]));
Assertions.assertArrayEquals(expectedDivisors, treeAndDivisors.divisors());
}
static Object[][] testSolution() {
return new Object[][] {
new Object[] {
new int[] { 100, 101, 102, 103 },
new int[][] { { 0, 1 }, { 0, 2 }, { 0, 3 } },
new int[] { 192, 2, 8, 2 } }
};
}
Input Generation
We define 3 input generators, for a random distribution and 2 extreme cases.
1. With setRandomAdjacency, the node distribution is random. For each node i, 0 < i < N, its parent is selected at random in [0, ..., i-1].
2. With setSlimAdjacency, the tree degenerates into a list.
3. With setThickAdjacency, a single parent contains all the nodes as direct children.
static void setThickAdjacency(List<List<Integer>> adjacency) {
for (int i = 1; i < adjacency.size(); i++) {
adjacency.get(0).add(i);
}
}
static void setSlimAdjacency(List<List<Integer>> adjacency) {
for (int i = 1; i < adjacency.size(); i++) {
adjacency.get(i-1).add(i);
}
}
static void setRandomAdjacency(List<List<Integer>> adjacency) {
for (int i = 1; i < adjacency.size(); i++) {
int parent = random.nextInt(i);
adjacency.get(parent).add(i);
}
}
Analysis
Each node's statistic can be computed from its children's results: a Depth First Search traversal that processes a node after all of its children (post-order) computes all node statistics in a single pass. The traversal runs iteratively with an explicit stack data structure, instead of recursively through the call stack, to avoid a potential StackOverflowError on deep trees.
Like in this other problem, the DFS traversal is abstracted away to focus only on the custom logic. This construct separates the traversal logic from the update logic, which lives in updateNode.
private void dfs() {
Traversal traversal = new AdjacencyListDFSTraversal(adjacency);
traversal.traverse(this::updateNode);
}
protected abstract void updateNode(int current, int parent, List<Integer> children);
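The Traversal and AdjacencyListDFSTraversal classes are not reproduced here; the sketch below shows one possible shape for such an iterative post-order traversal (an assumption, not the actual implementation). The visitor receives the node, its parent and the node's adjacency list, which is why the updateNode implementations below filter out the parent.
// Hypothetical sketch of an iterative post-order DFS; the real
// AdjacencyListDFSTraversal may differ. Each stack frame holds
// {node, parent, expanded flag}; a node is visited only after all of its
// children have been popped, so the children results are already available.
interface NodeVisitor {
    void visit(int current, int parent, List<Integer> children);
}

static void traverse(List<List<Integer>> adjacency, NodeVisitor visitor) {
    Deque<int[]> stack = new ArrayDeque<>();
    stack.push(new int[] { 0, -1, 0 });            // start from root node 0
    while (!stack.isEmpty()) {
        int[] frame = stack.peek();
        int node = frame[0], parent = frame[1];
        if (frame[2] == 0) {
            frame[2] = 1;                          // first pass: expand children
            for (int next : adjacency.get(node)) {
                if (next != parent) {
                    stack.push(new int[] { next, node, 0 });
                }
            }
        } else {
            stack.pop();                           // second pass: children are done
            visitor.visit(node, parent, adjacency.get(node));
        }
    }
}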
Data structure: Hash Map
When computing the product value of a node, we need to keep track of the prime factorization of the product value of each of its children.
Each prime factorization is stored as a Map<Integer,Integer> where
- keys are prime numbers
- values are associated exponents
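The helper methods used by the updateNode implementations below (primeFactors, mergeExponents, divisorCount) are not reproduced in this article; minimal sketches could look like the following, using trial-division factorization and the divisor count formula from the introduction (any modular reduction required by the judge is omitted).
// Sketches of the helpers assumed below; the actual code lives in
// AbstractTreeAndDivisors.
static Map<Integer, Integer> primeFactors(int value) {
    Map<Integer, Integer> exponents = new HashMap<>();
    for (int p = 2; (long) p * p <= value; p++) {
        while (value % p == 0) {                   // divide out each prime factor
            exponents.merge(p, 1, Integer::sum);
            value /= p;
        }
    }
    if (value > 1) {
        exponents.merge(value, 1, Integer::sum);   // remaining prime > sqrt(value)
    }
    return exponents;
}

static void mergeExponents(Map<Integer, Integer> target, Map<Integer, Integer> source) {
    source.forEach((prime, exponent) -> target.merge(prime, exponent, Integer::sum));
}

static int divisorCount(Map<Integer, Integer> exponents) {
    int count = 1;
    for (int exponent : exponents.values()) {
        count *= exponent + 1;                     // (e_1+1) * ... * (e_k+1)
    }
    return count;
}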
Version 1
Computing the parent hash map requires 1 merge operation per child.
@Override
protected void updateNode(int current, int parent, List<Integer> children) {
Map<Integer, Integer> parentExponents = primeFactors(A[current]);
children.stream()
.filter(child -> !(child == parent))
.map(child -> primeExponents[child])
.forEach(childExponents -> mergeExponents(parentExponents, childExponents));
primeExponents[current] = parentExponents;
divisorCount[current] = divisorCount(parentExponents);
}
In the slim case where the tree is a long list,
- node N at the bottom requires 0 merges
- node N-1 requires 1 merge
- ...
- node 1 at the top requires N-1 merges
This results in 0 + 1 + ... + (N-1) = N(N-1)/2 = O(N^2) total merges.
However, the measured runtime shows an exponential trend as N grows:
Then we start getting memory errors as we increase N further. The error is thrown while growing the hash table in HashMap.resize:
Caused by: java.lang.OutOfMemoryError: Java heap space
at java.base/java.util.HashMap.resize(HashMap.java:702)
at java.base/java.util.HashMap.merge(HashMap.java:1363)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors.mergeExponents(AbstractTreeAndDivisors.java:85)
at codechef.starters57.TREEDIVS.TreeAndDivisors1.lambda$updateNode$2(TreeAndDivisors1.java:19)
at codechef.starters57.TREEDIVS.TreeAndDivisors1$$Lambda$472/0x0000000800d54248.accept(Unknown Source)
at java.base/java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:183)
at java.base/java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:197)
at java.base/java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:179)
at java.base/java.util.ArrayList$ArrayListSpliterator.forEachRemaining(ArrayList.java:1625)
at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:509)
at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:499)
at java.base/java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:150)
at java.base/java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:173)
at java.base/java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
at java.base/java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:596)
at codechef.starters57.TREEDIVS.TreeAndDivisors1.updateNode(TreeAndDivisors1.java:19)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors$$Lambda$466/0x0000000800d4efa0.visit(Unknown Source)
at graph.tree.traversal.AbstractDFSTraversal.traverse(AbstractDFSTraversal.java:63)
at graph.tree.traversal.AbstractDFSTraversal.traverse(AbstractDFSTraversal.java:36)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors.dfs(AbstractTreeAndDivisors.java:78)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors.divisors(AbstractTreeAndDivisors.java:67)
at codechef.starters57.TREEDIVS.TreeAndDivisorsTest.runInput(TreeAndDivisorsTest.java:89)
at codechef.starters57.TREEDIVS.TreeAndDivisorsTest.runBatch(TreeAndDivisorsTest.java:68)
at codechef.starters57.TREEDIVS.TreeAndDivisorsTest.run(TreeAndDivisorsTest.java:24)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:568)
at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:725)
at org.junit.jupiter.engine.execution.MethodInvocation.proceed(MethodInvocation.java:60)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain$ValidatingInvocation.proceed(InvocationInterceptorChain.java:131)
at org.junit.jupiter.engine.extension.TimeoutExtension.intercept(TimeoutExtension.java:149)
The original approach had a memory leak: it kept every intermediate hash map alive, while only the current node's map and its children's maps were required.
To reduce the memory footprint, we null out the children Map references once we no longer need them, which makes those objects eligible for Garbage Collection.
@Override
protected void updateNode(int current, int parent, List<Integer> children) {
Map<Integer, Integer> parentExponents = primeFactors(A[current]);
children.stream()
.filter(child -> !(child == parent))
.forEach(child -> {
mergeExponents(parentExponents, primeExponents[child]);
primeExponents[child] = null;
});
primeExponents[current] = parentExponents;
divisorCount[current] = divisorCount(parentExponents);
}
We can now see the O(N^2) trend:
Versions 2 & 3
We can skip a large number of merge operations by reusing a child's map instead of building a new one. Selecting the child with the largest map is the best choice: the remaining children are merged into it, and the same object is then reused as the current node's map. This is the small-to-large merging idea referenced in the introduction, since entries are only ever copied into a map at least as large as the one they come from.
@Override
protected void updateNode(int current, int parent, List<Integer> children) {
Map<Integer, Integer> valueExponents = primeFactors(A[current]);
Optional<Integer> maxOptional = children.stream()
.filter(child -> child != parent)
.max(Comparator.comparing(child -> primeExponents[child].size()));
Map<Integer, Integer> parentExponents;
if (maxOptional.isPresent()) {
int maxChild = maxOptional.get();
parentExponents = primeExponents[maxChild];
mergeExponents(parentExponents, valueExponents);
children.stream()
.filter(child -> !(child == parent || child == maxChild))
.map(child -> primeExponents[child])
.forEach(childExponents -> mergeExponents(parentExponents, childExponents));
} else {
parentExponents = valueExponents;
}
primeExponents[current] = parentExponents;
divisorCount[current] = divisorCount(parentExponents);
}
Version 4
Computing the result requires 2 operations:
- merging the prime factor exponents
- computing the divisor count
Instead of doing them separately, combining them into a single pass reduces the number of iterations over the hash map keys by 50%.
We switch from mergeExponents to mergeMultiplyExponents, which performs both actions at once.
@Override
protected void updateNode(int current, int parent, List<Integer> children) {
Map<Integer, Integer> valueExponents = primeFactors(A[current]);
Optional<Integer> maxOptional = children.stream()
.filter(child -> child != parent)
.max(Comparator.comparing(child -> primeExponents[child].size()));
Map<Integer, Integer> parentExponents;
int dc;
if (maxOptional.isPresent()) {
int maxChild = maxOptional.get();
parentExponents = primeExponents[maxChild];
dc = divisorCount[maxChild];
dc = mergeMultiplyExponents(dc, parentExponents, valueExponents);
for (int child: children) {
if (child == parent || child == maxChild) {
continue;
}
dc = mergeMultiplyExponents(dc, parentExponents, primeExponents[child]);
}
} else {
parentExponents = valueExponents;
dc = divisorCount(valueExponents);
}
primeExponents[current] = parentExponents;
divisorCount[current] = dc;
}
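mergeMultiplyExponents itself is not shown in this article; a minimal sketch of the idea follows, assuming the divisor count fits in an int without the modular reduction the judge may require (with a modulus, dividing by the old exponent + 1 would need a modular inverse).
// Sketch: merge childExponents into parentExponents while incrementally
// maintaining dc, the divisor count of the parent's running product.
static int mergeMultiplyExponents(int dc, Map<Integer, Integer> parentExponents,
        Map<Integer, Integer> childExponents) {
    for (Map.Entry<Integer, Integer> entry : childExponents.entrySet()) {
        int prime = entry.getKey();
        int oldExponent = parentExponents.getOrDefault(prime, 0);
        int newExponent = oldExponent + entry.getValue();
        parentExponents.put(prime, newExponent);
        dc = dc / (oldExponent + 1) * (newExponent + 1);   // exact: dc is a multiple of oldExponent + 1
    }
    return dc;
}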
Gradle Report
See this commit for the full application code.
We compare multiple versions
- Original
- Null out children maps
- Reuse max child Map
- 50% fewer Map scans
against the 2 input shapes (slim and random).
For slim input with N=4000 nodes, the runtimes were
- [27.915s] Original version
- [17.442s] Null out children maps
- [4.536s] Reuse max child Map
- [0.312s] Reuse Map scan for both merging and divisor count computation
For random input with N=30000 nodes, the runtimes were
- [3.562s] Original version
- [3.324s] Null out children maps
- [3.368s] Reuse max child Map
- [3.881s] Reuse Map scan for both merging and divisor count computation
Conclusion
As we merge hash tables bottom-up towards the root, the hash tables grow significantly. Iterating through their entries is the bottleneck of the execution, so we should optimize towards reducing full hash table scans as much as possible.