Monday, October 3, 2022

Codechef Starters 57, Tree and Divisors

Introduction


In the TREEDIVS problem from the Codechef Starters 57 contest, you need to compute, for each node of the tree, the number of divisors of the product of all node values within that node's subtree.

While my output was correct, my submission got Time Limit Exceeded.





In this article, we will go over multiple optimization strategies to keep the execution time short.

Refer to the successful submission from another community member, from which I got the hints for the optimizations. See also the Codechef editorial and the Small-to-Large merging USACO article.


Setup


Sample input

The input starts with T, the number of test cases. Each test case then consists of:

  1. N, the tree size
  2. A, the list of N node values
  3. (u_i, v_i), the N-1 edges used to build the adjacency list


3
4
100 101 102 103
1 2
1 3
1 4
4
2 2 2 2
1 2
2 3
3 4
5
43 525 524 12 289
1 2
1 3
3 4
4 5


Expected output


For each test case, we output the number of divisors of the subtree product, for each node:


192 2 8 2 
5 4 3 2
1080 12 60 18 3
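
For example, in the first test case the subtree of node 1 contains all four values: 100 · 101 · 102 · 103 = 2^3 · 3 · 5^2 · 17 · 101 · 103, so the divisor count is (3+1)(1+1)(2+1)(1+1)(1+1)(1+1) = 192.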


Input Parsing


package codechef.starters57.TREEDIVS;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

public class TreeAndDivisorsMain {
    public static void main(String[] args) throws IOException {
        TreeAndDivisorsFactory factory = TreeAndDivisors3::new;

        //InputStream inputStream = System.in;
        InputStream inputStream = new FileInputStream("TREEDIVS");
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
        PrintWriter writer = new PrintWriter(new BufferedOutputStream(System.out));

        String[] tokens;

        tokens = bufferedReader.readLine().split(" ");
        int T = Integer.parseInt(tokens[0]);
        while (T > 0) {
            tokens = bufferedReader.readLine().split(" ");
            int N = Integer.parseInt(tokens[0]);
            tokens = bufferedReader.readLine().split(" ");

            int[] A = new int[N];
            for (int i = 0; i < N; i++) {
                A[i] = Integer.parseInt(tokens[i]);
            }

            TreeAndDivisors treeAndDivisors = factory.createTreeAndDivisors(N, A);
            for (int i = 0; i < N - 1; i++) {
                tokens = bufferedReader.readLine().split(" ");
                int u = Integer.parseInt(tokens[0]);
                int v = Integer.parseInt(tokens[1]);
                treeAndDivisors.addEdge(u - 1, v - 1);
            }

            int[] divisorCount = treeAndDivisors.divisors();
            String divisorLine = AbstractTreeAndDivisors.listToString(divisorCount);
            writer.println(divisorLine);

            T--;
        }

        writer.close();
        inputStream.close();
    }
}


Benchmark


We define an interface for the solution contract:


public interface TreeAndDivisors {
    void addEdge(int u, int v);
    int[] divisors();
}


We define a factory to decouple running a solution on input data from the implementation itself.


public interface TreeAndDivisorsFactory {
    TreeAndDivisors createTreeAndDivisors(int n, int[] A);
}


We define a tree generator to compare the effect of the tree structure shape on runtime:


public interface TreeGenerator {
    void generate(List<List<Integer>> adjacency);
}


Unit Test


Here's a JUnit 5 parameterized unit test:


@ParameterizedTest
@MethodSource
public void correctness(int[] A, int[][] edges, int[] expectedDivisors) {
    TreeAndDivisorsFactory factory = TreeAndDivisors3::new;
    TreeAndDivisors treeAndDivisors = factory.createTreeAndDivisors(A.length, A);
    Arrays.stream(edges)
            .forEach(edge -> treeAndDivisors.addEdge(edge[0], edge[1]));
    Assertions.assertArrayEquals(expectedDivisors, treeAndDivisors.divisors());
}

static Object[][] correctness() {
    return new Object[][] {
            new Object[] {
                    new int[] { 100, 101, 102, 103 },
                    new int[][] { { 0, 1 }, { 0, 2 }, { 0, 3 } },
                    new int[] { 192, 2, 8, 2 } }
    };
}


Input Generation


We define 3 input generators:


static void setThickAdjacency(List<List<Integer>> adjacency) {
    for (int i = 1; i < adjacency.size(); i++) {
        adjacency.get(0).add(i);
    }
}

static void setSlimAdjacency(List<List<Integer>> adjacency) {
    for (int i = 1; i < adjacency.size(); i++) {
        adjacency.get(i - 1).add(i);
    }
}

static void setRandomAdjacency(List<List<Integer>> adjacency) {
    for (int i = 1; i < adjacency.size(); i++) {
        int parent = random.nextInt(i);
        adjacency.get(parent).add(i);
    }
}


With setRandomAdjacency, we generate a random tree structure. For each node i, 0 < i < N, its parent is selected at random in [0, ..., i-1].


With setSlimAdjacency, we generate a list-like tree where each node only has one child.




With setThickAdjacency, we generate a star-shaped tree: every node other than the root is a direct child of the root.





Analysis


Depth First Search traversal is more practical here than Breadth First Search: we only need the recursion to compute a parent's value from its children's values.


Algorithm: DFS


Like in this other problem, we can set up a node visitor for the DFS traversal. It abstracts away the DFS implementation logic, so we can focus only on the node value update logic, located in the updateNode method.

private void dfs() {
    Traversal traversal = new AdjacencyListDFSTraversal(adjacency);
    traversal.traverse(this::updateNode);
}

protected abstract void updateNode(int current, int parent, List<Integer> children);
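
The Traversal implementation itself is not shown here. Below is a minimal sketch of what AdjacencyListDFSTraversal could look like; the Visitor functional interface and the exact method shapes are assumptions. The only property updateNode relies on is the post-order guarantee: every child is visited before its parent.

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;

public class AdjacencyListDFSTraversal {
    public interface Visitor {
        void visit(int current, int parent, List<Integer> children);
    }

    private final List<List<Integer>> adjacency;

    public AdjacencyListDFSTraversal(List<List<Integer>> adjacency) {
        this.adjacency = adjacency;
    }

    public void traverse(Visitor visitor) {
        int n = adjacency.size();
        int[] parent = new int[n];
        int[] order = new int[n];
        parent[0] = -1;

        // Iterative DFS from the root (node 0): record nodes in discovery order.
        // An explicit stack avoids StackOverflowError on deep, slim trees.
        Deque<Integer> stack = new ArrayDeque<>();
        stack.push(0);
        int size = 0;
        while (!stack.isEmpty()) {
            int u = stack.pop();
            order[size++] = u;
            for (int v : adjacency.get(u)) {
                if (v != parent[u]) {
                    parent[v] = u;
                    stack.push(v);
                }
            }
        }

        // Visiting in reverse discovery order guarantees every child is
        // processed before its parent.
        for (int i = n - 1; i >= 0; i--) {
            int u = order[i];
            visitor.visit(u, parent[u], adjacency.get(u));
        }
    }
}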

Data structure: Hash Map


When computing the product value of the current node, we need to keep track of the prime factorization of each child's product value.

Each prime factorization is stored as a Map<Integer,Integer> where

  • keys are prime numbers
  • values are associated exponents
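
This representation relies on the classic identity: if n = p1^e1 · ... · pk^ek, then n has (e1+1) · ... · (ek+1) divisors. Here is a minimal sketch of the two helpers used below; the names primeFactors and divisorCount match the code in this article, but the bodies are my assumption (plain trial division in O(sqrt(n)) per value, overflow handling omitted for brevity).

static Map<Integer, Integer> primeFactors(int n) {
    Map<Integer, Integer> exponents = new HashMap<>();
    for (int p = 2; (long) p * p <= n; p++) {
        while (n % p == 0) {
            exponents.merge(p, 1, Integer::sum);
            n /= p;
        }
    }
    if (n > 1) {
        exponents.merge(n, 1, Integer::sum); // remaining prime factor
    }
    return exponents;
}

// d(p1^e1 * ... * pk^ek) = (e1 + 1) * ... * (ek + 1)
static int divisorCount(Map<Integer, Integer> exponents) {
    int count = 1;
    for (int e : exponents.values()) {
        count *= e + 1;
    }
    return count;
}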


Version 1


private void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> parentExponents = primeFactors(A[current]);

    children.stream()
            .filter(child -> !(child == parent))
            .map(child -> primeExponents[child])
            .forEach(childExponents -> mergeExponents(parentExponents, childExponents));

    primeExponents[current] = parentExponents;
    divisorCount[current] = divisorCount(parentExponents);
}
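
The mergeExponents helper is not listed in the article. Given the HashMap.merge frame in the stack trace further down, a plausible sketch (not the author's exact code) is a per-entry merge of the source exponents into the target map:

// Add each prime exponent from source into target.
static void mergeExponents(Map<Integer, Integer> target, Map<Integer, Integer> source) {
    source.forEach((prime, exponent) -> target.merge(prime, exponent, Integer::sum));
}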

On the slim adjacency input, assuming the node values are distinct primes, and counting one merge per map entry copied:

  • node n-1 requires 0 merges
  • node n-2 requires 1 merge
  • ...
  • node 0 requires n-1 merges

This sums to 0 + 1 + ... + (n-1) = n(n-1)/2, i.e. O(N^2) total merges.

We see an exponential trend in the runtime as N grows:




We start getting an OutOfMemoryError as we increase N. It is thrown while growing the hash table in HashMap.resize:


Caused by: java.lang.OutOfMemoryError: Java heap space
at java.base/java.util.HashMap.resize(HashMap.java:702)
at java.base/java.util.HashMap.merge(HashMap.java:1363)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors.mergeExponents(AbstractTreeAndDivisors.java:85)
at codechef.starters57.TREEDIVS.TreeAndDivisors1.lambda$updateNode$2(TreeAndDivisors1.java:19)
at codechef.starters57.TREEDIVS.TreeAndDivisors1$$Lambda$472/0x0000000800d54248.accept(Unknown Source)
at java.base/java.util.stream.ForEachOps$ForEachOp$OfRef.accept(ForEachOps.java:183)
at java.base/java.util.stream.ReferencePipeline$3$1.accept(ReferencePipeline.java:197)
at java.base/java.util.stream.ReferencePipeline$2$1.accept(ReferencePipeline.java:179)
at java.base/java.util.ArrayList$ArrayListSpliterator.forEachRemaining(ArrayList.java:1625)
at java.base/java.util.stream.AbstractPipeline.copyInto(AbstractPipeline.java:509)
at java.base/java.util.stream.AbstractPipeline.wrapAndCopyInto(AbstractPipeline.java:499)
at java.base/java.util.stream.ForEachOps$ForEachOp.evaluateSequential(ForEachOps.java:150)
at java.base/java.util.stream.ForEachOps$ForEachOp$OfRef.evaluateSequential(ForEachOps.java:173)
at java.base/java.util.stream.AbstractPipeline.evaluate(AbstractPipeline.java:234)
at java.base/java.util.stream.ReferencePipeline.forEach(ReferencePipeline.java:596)
at codechef.starters57.TREEDIVS.TreeAndDivisors1.updateNode(TreeAndDivisors1.java:19)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors$$Lambda$466/0x0000000800d4efa0.visit(Unknown Source)
at graph.tree.traversal.AbstractDFSTraversal.traverse(AbstractDFSTraversal.java:63)
at graph.tree.traversal.AbstractDFSTraversal.traverse(AbstractDFSTraversal.java:36)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors.dfs(AbstractTreeAndDivisors.java:78)
at codechef.starters57.TREEDIVS.AbstractTreeAndDivisors.divisors(AbstractTreeAndDivisors.java:67)
at codechef.starters57.TREEDIVS.TreeAndDivisorsTest.runInput(TreeAndDivisorsTest.java:89)
at codechef.starters57.TREEDIVS.TreeAndDivisorsTest.runBatch(TreeAndDivisorsTest.java:68)
at codechef.starters57.TREEDIVS.TreeAndDivisorsTest.run(TreeAndDivisorsTest.java:24)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.base/java.lang.reflect.Method.invoke(Method.java:568)
at org.junit.platform.commons.util.ReflectionUtils.invokeMethod(ReflectionUtils.java:725)
at org.junit.jupiter.engine.execution.MethodInvocation.proceed(MethodInvocation.java:60)
at org.junit.jupiter.engine.execution.InvocationInterceptorChain$ValidatingInvocation.proceed(InvocationInterceptorChain.java:131)
at org.junit.jupiter.engine.extension.TimeoutExtension.intercept(TimeoutExtension.java:149)


To reduce the memory footprint, we can null out the children's Map references. This marks those maps as candidates for garbage collection; we no longer need them after computing the parent's Map.


@Override
protected void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> parentExponents = primeFactors(A[current]);

    children.stream()
            .filter(child -> !(child == parent))
            .forEach(child -> {
                mergeExponents(parentExponents, primeExponents[child]);
                primeExponents[child] = null;
            });

    primeExponents[current] = parentExponents;
    divisorCount[current] = divisorCount(parentExponents);
}


We can now see the O(N^2) trend:



Version 2


private void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> valueExponents = primeFactors(A[current]);

    Optional<Integer> maxOptional = children.stream()
            .filter(child -> child != parent)
            .max(Comparator.comparing(child -> primeExponents[child].size()));

    Map<Integer, Integer> parentExponents;

    if (maxOptional.isPresent()) {
        int maxChild = maxOptional.get();

        parentExponents = primeExponents[maxChild];
        mergeExponents(parentExponents, valueExponents);
        children.stream()
                .filter(child -> !(child == parent || child == maxChild))
                .map(child -> primeExponents[child])
                .forEach(childExponents -> mergeExponents(parentExponents, childExponents));
    } else {
        parentExponents = valueExponents;
    }

    primeExponents[current] = parentExponents;
    divisorCount[current] = divisorCount(parentExponents);
}


We can skip a large fraction of the total merges for free by reusing a child's Map instead of copying it. We get the most bang for the buck by picking the child with the largest Map, merging everything else into it, and assigning that same object to the current node. This is the small-to-large merging technique from the USACO article referenced above: each entry only ever moves into a Map at least as large as its previous one, so each prime is moved O(log N) times.


It turns out that there's still room for improvement in the HashMap iterations. We are iterating through the hash table twice:

  1. once to merge the prime factor exponents
  2. once to compute the divisor count


Version 3


We can cut the hash table scans in half by performing both actions in a single pass: we switch from mergeExponents to mergeMultiplyExponents, which updates the exponents and the divisor count at the same time.


private void updateNode(int current, int parent, List<Integer> children) {
    Map<Integer, Integer> valueExponents = primeFactors(A[current]);

    Optional<Integer> maxOptional = children.stream()
            .filter(child -> child != parent)
            .max(Comparator.comparing(child -> primeExponents[child].size()));

    Map<Integer, Integer> parentExponents;
    int dc;

    if (maxOptional.isPresent()) {
        int maxChild = maxOptional.get();

        parentExponents = primeExponents[maxChild];
        dc = divisorCount[maxChild];

        dc = mergeMultiplyExponents(dc, parentExponents, valueExponents);

        for (int child : children) {
            if (child == parent || child == maxChild) {
                continue;
            }
            dc = mergeMultiplyExponents(dc, parentExponents, primeExponents[child]);
        }
    } else {
        parentExponents = valueExponents;
        dc = divisorCount(valueExponents);
    }

    primeExponents[current] = parentExponents;
    divisorCount[current] = dc;
}
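
The mergeMultiplyExponents helper is not listed in the article either. A minimal sketch of the idea, assuming plain integer arithmetic: when a prime's exponent grows from e1 to e1+e2, the divisor count gets multiplied by (e1+e2+1)/(e1+1), so the count can be maintained during the merge scan itself. A contest submission computing the answer modulo a prime would replace the division with a multiplication by a modular inverse.

// Merge source's exponents into target while incrementally updating the
// divisor count dc, so no second full scan over target is needed.
static int mergeMultiplyExponents(int dc, Map<Integer, Integer> target, Map<Integer, Integer> source) {
    for (Map.Entry<Integer, Integer> entry : source.entrySet()) {
        int prime = entry.getKey();
        int added = entry.getValue();
        int existing = target.getOrDefault(prime, 0);
        target.put(prime, existing + added);
        dc = dc / (existing + 1) * (existing + added + 1); // (existing + 1) divides dc exactly
    }
    return dc;
}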


Gradle Report


See this commit for the full application code.

We run multiple versions against random and slim input types.

Comparing execution times for the slim input across versions on N=4000 nodes:

  1. [2][27.915s] Original version
  2. [4][17.442s] Null out children maps
  3. [6][4.536s] Reuse max child Map
  4. [8][0.312s] Reuse Map scan for both merging and divisor count computation


Test                                    | Method name                                                        | Duration | Result
random, Original, N=30000               | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[1] | 3.562s   | passed
slim, Original, N=4000                  | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[2] | 27.915s  | passed
random, Null out children maps, N=30000 | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[3] | 3.324s   | passed
slim, Null out children maps, N=4000    | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[4] | 17.442s  | passed
random, Reuse max child Map, N=30000    | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[5] | 3.368s   | passed
slim, Reuse max child Map, N=4000       | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[6] | 4.536s   | passed
random, 50% less Map scans, N=30000     | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[7] | 3.881s   | passed
slim, 50% less Map scans, N=4000        | run(String, String, int, TreeGenerator, TreeAndDivisorsFactory)[8] | 0.312s   | passed


See the Codechef submissions for the 3 versions:

  1. Version 1
  2. Version 2
  3. Version 3





Conclusion


As we merge hash tables bottom-up towards the root, the hash tables grow significantly. Iterating through their entries is the bottleneck of the execution, so we should optimize towards reducing full hash table scans as much as possible.



Appendix


Plotting code


#!/usr/bin/env python3

import matplotlib.pyplot as plt
import numpy as np

x = [ 100, 500, 1000, 2000, 3000, 4000 ]

y_random = [ 0.054, 0.082, 0.117, 0.204, 0.309, 0.417 ]
y_slim = [ 0.079, 0.394, 1.143, 4.167, 9.741, 26.487 ]
y_slim_null = [ 0.051, 0.371, 1.125, 4.076, 8.883, 16.031 ]

p_random = np.polyfit(x, y_random, 1)
p_slim = np.polyfit(x, np.log(y_slim), 1, w=np.sqrt(y_slim))
p_slim_null = np.polyfit(x, y_slim_null, 2)

x_interpolated = np.arange(0, 4100, 100)
yi_random = np.polyval(p_random, x_interpolated)
yi_slim = np.exp(np.polyval(p_slim, x_interpolated))
yi_slim_null = np.polyval(p_slim_null, x_interpolated)

fig, ax = plt.subplots()
ax.plot(x, y_random, 'r+', label='Random adjacency list')
ax.plot(x, y_slim, 'go', label='Slim adjacency list')
ax.plot(x, y_slim_null, 'bx', label='Slim adjacency list, less memory')
ax.plot(x_interpolated, yi_random, 'r', label='Linear interpolation')
ax.plot(x_interpolated, yi_slim, 'g', label='Exponential interpolation')
ax.plot(x_interpolated, yi_slim_null, 'b', label='Square interpolation')

plt.title('TREEDIVS runtime')
plt.xlabel('N')
plt.ylabel('Time (seconds)')
plt.legend()
plt.show()

Dot graphs

Slim

digraph G {
  0 -> 1
  1 -> 2
  2 -> 3
  3 -> 4
  4 -> 5
  5 -> 6
  6 -> 7
  7 -> 8
  8 -> 9
}

Thick

digraph G {
  0 -> 1
  0 -> 2
  0 -> 3
  0 -> 4
  0 -> 5
  0 -> 6
  0 -> 7
  0 -> 8
  0 -> 9
}

