Introduction
In the TREEDIVS problem from the CodeChef Starters 57 contest, you need to compute, for each node of the tree, the number of divisors of the product of all node values within that node's subtree.
While my output was correct, my submission got Time Limit Exceeded.
In this article, we go over multiple optimization strategies to keep the runtime short.
Refer to the successful submission from another community member, from which I got the hints for these optimizations. See also the CodeChef editorial and the USACO article on Small-to-Large merging.
Setup
Sample input
The input consists of
- N, the tree size
- A, the list of node values
- (u_i, v_i), the list of edges used to build the adjacency list
3
4
100 101 102 103
1 2
1 3
1 4
4
2 2 2 2
1 2
2 3
3 4
5
43 525 524 12 289
1 2
1 3
3 4
4 5
Expected output
We output, for each node, the number of divisors of its subtree product:
192 2 8 2
5 4 3 2
1080 12 60 18 3
Input Parsing
package codechef.starters57.TREEDIVS;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class TreeAndDivisorsMain {
    public static void main(String[] args) throws IOException {
        TreeAndDivisorsFactory factory = TreeAndDivisors3::new;
        //InputStream inputStream = System.in;
        InputStream inputStream = new FileInputStream("TREEDIVS");
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
        PrintWriter writer = new PrintWriter(new BufferedOutputStream(System.out));
        String[] tokens;
        tokens = bufferedReader.readLine().split(" ");
        int T = Integer.parseInt(tokens[0]);
        while (T > 0) {
            tokens = bufferedReader.readLine().split(" ");
            int N = Integer.parseInt(tokens[0]);
            tokens = bufferedReader.readLine().split(" ");
            int[] A = new int[N];
            for (int i = 0; i < N; i++) {
                A[i] = Integer.parseInt(tokens[i]);
            }
            TreeAndDivisors treeAndDivisors = factory.createTreeAndDivisors(N, A);
            for (int i = 0; i < N - 1; i++) {
                tokens = bufferedReader.readLine().split(" ");
                int u = Integer.parseInt(tokens[0]);
                int v = Integer.parseInt(tokens[1]);
                treeAndDivisors.addEdge(u - 1, v - 1);
            }
            int[] divisorCount = treeAndDivisors.divisors();
            String divisorLine = AbstractTreeAndDivisors.listToString(divisorCount);
            writer.println(divisorLine);
            T--;
        }
        writer.close();
        inputStream.close();
    }
}
Benchmark
We define an interface for the solution contract:
public interface TreeAndDivisors {
    void addEdge(int u, int v);
    int[] divisors();
}
We define a factory to decouple running a solution on input data from the implementation itself.
public interface TreeAndDivisorsFactory {
    TreeAndDivisors createTreeAndDivisors(int n, int[] A);
}
We define a tree generator to compare the effect of the tree structure shape on runtime:
public interface TreeGenerator {
    void generate(List<List<Integer>> adjacency);
}
Unit Test
Here's a JUnit 5 parameterized unit test:
@ParameterizedTest
@MethodSource
public void correctness(int[] A, int[][] edges, int[] expectedDivisors) {
    TreeAndDivisorsFactory factory = TreeAndDivisors3::new;
    TreeAndDivisors treeAndDivisors = factory.createTreeAndDivisors(A.length, A);
    Arrays.stream(edges)
            .forEach(edge -> treeAndDivisors.addEdge(edge[0], edge[1]));
    Assertions.assertArrayEquals(expectedDivisors, treeAndDivisors.divisors());
}

static Object[][] correctness() {
    return new Object[][] {
            new Object[] {
                    new int[] { 100, 101, 102, 103 },
                    new int[][] { { 0, 1 }, { 0, 2 }, { 0, 3 } },
                    new int[] { 192, 2, 8, 2 } }
    };
}
Input Generation
We define 3 input generators:
static void setThickAdjacency(List<List<Integer>> adjacency) {
    for (int i = 1; i < adjacency.size(); i++) {
        adjacency.get(0).add(i);
    }
}

static void setSlimAdjacency(List<List<Integer>> adjacency) {
    for (int i = 1; i < adjacency.size(); i++) {
        adjacency.get(i - 1).add(i);
    }
}

static void setRandomAdjacency(List<List<Integer>> adjacency) {
    for (int i = 1; i < adjacency.size(); i++) {
        int parent = random.nextInt(i);
        adjacency.get(parent).add(i);
    }
}
With setRandomAdjacency, we generate a random tree structure: for each node i, 0 < i < N, its parent is selected uniformly at random from [0, ..., i-1].
With setSlimAdjacency, we generate a list-like tree where each node has exactly one child.
With setThickAdjacency, we generate a star-shaped tree where every node other than the root is a direct child of the root.
Analysis
Depth First Search traversal is more practical here than Breadth First Search: the recursion lets us compute a parent's value from its children's values once they have all been visited.
Algorithm: DFS
private void dfs() {
    Traversal traversal = new AdjacencyListDFSTraversal(adjacency);
    traversal.traverse(this::updateNode);
}

protected abstract void updateNode(int current, int parent, List<Integer> children);
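Traversal and AdjacencyListDFSTraversal are not listed in this post; a minimal recursive sketch of what they could look like (the class shape and the root-at-node-0 assumption are mine, not the actual implementation):

```java
import java.util.List;

// Hypothetical post-order DFS over an adjacency list: the visitor is called
// once per node, after all of its children have been processed.
class AdjacencyListDFSTraversal {
    interface Visitor {
        void visit(int current, int parent, List<Integer> children);
    }

    private final List<List<Integer>> adjacency;

    AdjacencyListDFSTraversal(List<List<Integer>> adjacency) {
        this.adjacency = adjacency;
    }

    void traverse(Visitor visitor) {
        dfs(0, -1, visitor); // assumes node 0 is the root
    }

    private void dfs(int current, int parent, Visitor visitor) {
        for (int child : adjacency.get(current)) {
            if (child != parent) {
                dfs(child, current, visitor);
            }
        }
        // Post-order: children are fully processed before the parent.
        visitor.visit(current, parent, adjacency.get(current));
    }
}
```

Note that on the slim, list-like tree the recursion depth equals N, so very large inputs may require an explicit stack instead of recursion.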
Data structure: Hash Map
When computing the current node's product value, we need the prime factorization of the product value of each child node.
Each prime factorization is stored as a Map<Integer,Integer> where
- keys are prime numbers
- values are associated exponents
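The helpers primeFactors and divisorCount used below can be sketched as follows (a minimal trial-division version; for a number with factorization p1^e1 * ... * pk^ek, the divisor count is (e1+1) * ... * (ek+1)):

```java
import java.util.HashMap;
import java.util.Map;

class PrimeHelpers {
    // Trial-division factorization of v into a prime -> exponent map.
    static Map<Integer, Integer> primeFactors(int v) {
        Map<Integer, Integer> exponents = new HashMap<>();
        for (int p = 2; (long) p * p <= v; p++) {
            while (v % p == 0) {
                exponents.merge(p, 1, Integer::sum);
                v /= p;
            }
        }
        if (v > 1) { // leftover factor is itself prime
            exponents.merge(v, 1, Integer::sum);
        }
        return exponents;
    }

    // The divisor count is the product of (exponent + 1) over all primes.
    static int divisorCount(Map<Integer, Integer> exponents) {
        int count = 1;
        for (int exponent : exponents.values()) {
            count *= exponent + 1;
        }
        return count;
    }
}
```

For example, 100 = 2^2 * 5^2 has (2+1) * (2+1) = 9 divisors.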
Version 1
private void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> parentExponents = primeFactors(A[current]);
    children.stream()
            .filter(child -> !(child == parent))
            .map(child -> primeExponents[child])
            .forEach(childExponents -> mergeExponents(parentExponents, childExponents));
    primeExponents[current] = parentExponents;
    divisorCount[current] = divisorCount(parentExponents);
}
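mergeExponents itself is not shown in the post; a straightforward sketch (assuming both maps are prime -> exponent) adds the child's exponents into the parent's map, since multiplying values means adding exponents per prime:

```java
import java.util.Map;

class ExponentMerger {
    // Add each of the child's prime exponents into the parent's map.
    static void mergeExponents(Map<Integer, Integer> parentExponents,
                               Map<Integer, Integer> childExponents) {
        childExponents.forEach((prime, exponent) ->
                parentExponents.merge(prime, exponent, Integer::sum));
    }
}
```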
On the slim adjacency use case, assuming the node values are distinct primes,
- node n-1 requires 0 merges
- node n-2 requires 1 merge
- ...
- node 0 requires n-1 merges
This sums to 0 + 1 + ... + (n-1) = n(n-1)/2, i.e. O(N^2) total merges.
Yet we see an exponential trend in the runtime as N grows:
We start getting an OutOfMemoryError as we increase N; it is thrown while growing the hash table during a HashMap resize.
To reduce the memory footprint, we can null out the children Map references, marking those objects as candidates for garbage collection. We no longer need them once the parent Map has been computed.
@Override
protected void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> parentExponents = primeFactors(A[current]);
    children.stream()
            .filter(child -> !(child == parent))
            .forEach(child -> {
                mergeExponents(parentExponents, primeExponents[child]);
                primeExponents[child] = null;
            });
    primeExponents[current] = parentExponents;
    divisorCount[current] = divisorCount(parentExponents);
}
We can now see the O(N^2) trend:
Version 2
private void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> valueExponents = primeFactors(A[current]);
    Optional<Integer> maxOptional = children.stream()
            .filter(child -> child != parent)
            .max(Comparator.comparing(child -> primeExponents[child].size()));
    Map<Integer, Integer> parentExponents;
    if (maxOptional.isPresent()) {
        int maxChild = maxOptional.get();
        parentExponents = primeExponents[maxChild];
        mergeExponents(parentExponents, valueExponents);
        children.stream()
                .filter(child -> !(child == parent || child == maxChild))
                .map(child -> primeExponents[child])
                .forEach(childExponents -> mergeExponents(parentExponents, childExponents));
    } else {
        parentExponents = valueExponents;
    }
    primeExponents[current] = parentExponents;
    divisorCount[current] = divisorCount(parentExponents);
}
We can skip a large fraction of the total merges for free by reusing a child object. We get the most bang for the buck by extracting the child with the largest Map, merging everything else into it, and reusing that same object for the current node.
It turns out there is still room for improvement in the HashMap iterations. We are iterating through the hash table twice:
- once to merge the prime factor exponents
- once to compute the divisor count
Version 3
We can reduce the hash table scans by 50% by performing both actions in a single pass: mergeMultiplyExponents replaces mergeExponents and updates the divisor count while merging.
private void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> valueExponents = primeFactors(A[current]);
    Optional<Integer> maxOptional = children.stream()
            .filter(child -> child != parent)
            .max(Comparator.comparing(child -> primeExponents[child].size()));
    Map<Integer, Integer> parentExponents;
    int dc;
    if (maxOptional.isPresent()) {
        int maxChild = maxOptional.get();
        parentExponents = primeExponents[maxChild];
        dc = divisorCount[maxChild];
        dc = mergeMultiplyExponents(dc, parentExponents, valueExponents);
        for (int child : children) {
            if (child == parent || child == maxChild) {
                continue;
            }
            dc = mergeMultiplyExponents(dc, parentExponents, primeExponents[child]);
        }
    } else {
        parentExponents = valueExponents;
        dc = divisorCount(valueExponents);
    }
    primeExponents[current] = parentExponents;
    divisorCount[current] = dc;
}
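mergeMultiplyExponents is not listed either; a plausible sketch (the exact contract is my assumption) maintains the divisor count incrementally while merging, dividing out the old (exponent + 1) factor and multiplying in the new one for each touched prime:

```java
import java.util.Map;

class MergeMultiply {
    // Merge child exponents into the parent map while keeping the divisor
    // count dc up to date; only the (smaller) child map is scanned.
    static int mergeMultiplyExponents(int dc, Map<Integer, Integer> parentExponents,
                                      Map<Integer, Integer> childExponents) {
        for (Map.Entry<Integer, Integer> entry : childExponents.entrySet()) {
            int oldExponent = parentExponents.getOrDefault(entry.getKey(), 0);
            int newExponent = oldExponent + entry.getValue();
            parentExponents.put(entry.getKey(), newExponent);
            // Replace this prime's old (exponent + 1) factor with the new one.
            dc = dc / (oldExponent + 1) * (newExponent + 1);
        }
        return dc;
    }
}
```

Because dc is updated while merging, the full parent map never needs to be rescanned just to recompute the divisor count.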
Gradle Report
See this commit for the full application code.
We run multiple versions against the random and slim input types.
Compare the execution times for the slim input across the versions at N=4000 nodes:
- [2][27.915s] Original version
- [4][17.442s] Null out children maps
- [6][4.536s] Reuse max child Map
- [8][0.312s] Reuse Map scan for both merging and divisor count computation
Test | Method name | Duration | Result
---|---|---|---
random, Original, N=30000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[1] | 3.562s | passed |
slim, Original, N=4000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[2] | 27.915s | passed |
random, Null out children maps, N=30000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[3] | 3.324s | passed |
slim, Null out children maps, N=4000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[4] | 17.442s | passed |
random, Reuse max child Map, N=30000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[5] | 3.368s | passed |
slim, Reuse max child Map, N=4000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[6] | 4.536s | passed |
random, 50% less Map scans, N=30000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[7] | 3.881s | passed |
slim, 50% less Map scans, N=4000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[8] | 0.312s | passed |
Conclusion
As we merge hash tables bottom-up towards the root, hash table grows significantly. Iterating through the entries is the bottleneck in the code execution. We should optimize towards reducing full hash table scans as much as possible.
Appendix
Plotting code
#!/usr/bin/env python3
import matplotlib.pyplot as plt
import numpy as np
x = [ 100, 500, 1000, 2000, 3000, 4000 ]
y_random = [ 0.054, 0.082, 0.117, 0.204, 0.309, 0.417 ]
y_slim = [ 0.079, 0.394, 1.143, 4.167, 9.741, 26.487 ]
y_slim_null = [ 0.051, 0.371, 1.125, 4.076, 8.883, 16.031 ]
p_random = np.polyfit(x, y_random, 1)
p_slim = np.polyfit(x, np.log(y_slim), 1, w=np.sqrt(y_slim))
p_slim_null = np.polyfit(x, y_slim_null, 2)
x_interpolated = np.arange(0, 4100, 100)
yi_random = np.polyval(p_random, x_interpolated)
yi_slim = np.exp(np.polyval(p_slim, x_interpolated))
yi_slim_null = np.polyval(p_slim_null, x_interpolated)
fig, ax = plt.subplots()
ax.plot(x, y_random, 'r+', label='Random adjacency list')
ax.plot(x, y_slim, 'go', label='Slim adjacency list')
ax.plot(x, y_slim_null, 'bx', label='Slim adjacency list, less memory')
ax.plot(x_interpolated, yi_random, 'r', label='Linear interpolation')
ax.plot(x_interpolated, yi_slim, 'g', label='Exponential interpolation')
ax.plot(x_interpolated, yi_slim_null, 'b', label='Square interpolation')
plt.title('TREEDIVS runtime')
plt.xlabel('N')
plt.ylabel('Time (seconds)')
plt.legend()
plt.show()
Dot graphs
Slim
digraph G {
    0 -> 1
    1 -> 2
    2 -> 3
    3 -> 4
    4 -> 5
    5 -> 6
    6 -> 7
    7 -> 8
    8 -> 9
}
Thick
digraph G {
    0 -> 1
    0 -> 2
    0 -> 3
    0 -> 4
    0 -> 5
    0 -> 6
    0 -> 7
    0 -> 8
    0 -> 9
}