Introduction
In the CodeChef GEOMMEAN problem, we need to count the subarrays [l, r] of an array A with the following property:
A_l * ... * A_r = X^(r-l+1)
This article is about a first approach that is bound to fail, but is on the right track.
It is an application of prefix sums + caching. This method usually applies to a one-dimensional array of integers, but it extends easily to a multi-dimensional case, where several arrays of integers are required.
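In this technique, the condition on a subarray becomes "two prefix values are equal", so counting subarrays reduces to counting equal pairs of prefixes with a hash map. Here is a minimal one-dimensional sketch, assuming the goal is to count subarrays whose elements sum to zero. The class name ZeroSumSubarrays is made up for illustration.

```java
import java.util.HashMap;
import java.util.Map;

public class ZeroSumSubarrays {
    // Counts subarrays [l, r] whose elements sum to zero.
    // With prefix sums P_0 = 0 and P_i = d_0 + ... + d_{i-1},
    // [l, r] sums to zero exactly when P_{r+1} == P_l.
    static long count(int[] d) {
        Map<Long, Integer> seen = new HashMap<>(); // prefix value -> times seen so far
        seen.put(0L, 1);                           // the empty prefix P_0
        long prefix = 0;
        long result = 0;
        for (int v : d) {
            prefix += v;
            int earlier = seen.getOrDefault(prefix, 0);
            result += earlier;                     // each earlier equal prefix closes one subarray
            seen.put(prefix, earlier + 1);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[]{1, -1, 1, -1})); // 4
    }
}
```

The map stores counts rather than positions, which keeps the whole pass linear in the array length.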
Input sample
3
3 3
3 3 3
4 4
1 2 3 4
4 54
36 81 54 54
Output sample
6
1
6
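A hand check of the expected counts, using only the definition above: all 6 subarrays of [3, 3, 3] qualify, the second test has exactly 1 match and the third has 6. The matches for the last two tests are:

```latex
\begin{aligned}
&X = 4,\ A = [1, 2, 3, 4]: && 4 = 4^1 \quad (\text{the only match})\\
&X = 54,\ A = [36, 81, 54, 54]: && 36 \cdot 81 = 2916 = 54^2,\quad 36 \cdot 81 \cdot 54 = 54^3,\quad 36 \cdot 81 \cdot 54 \cdot 54 = 54^4,\\
& && 54 = 54^1 \ (\text{twice}),\quad 54 \cdot 54 = 54^2
\end{aligned}
```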
Input parsing
import java.io.*;

public static void main(String[] args) throws Exception {
    // The judge provides the input on standard input; while testing locally,
    // the sample can be read from a file instead:
    // InputStream inputStream = new FileInputStream("GEOMMEAN");
    InputStream inputStream = System.in;
    BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
    PrintWriter writer = new PrintWriter(new BufferedOutputStream(System.out));
    String[] tokens;
    tokens = bufferedReader.readLine().trim().split("\\s+");
    int T = Integer.parseInt(tokens[0]);
    KostomukshaAndAESCMSU kostomukshaAndAESCMSU = new KostomukshaAndAESCMSU();
    while (T > 0) {
        // first line of a test case: N and X
        tokens = bufferedReader.readLine().trim().split("\\s+");
        int N = Integer.parseInt(tokens[0]);
        int X = Integer.parseInt(tokens[1]);
        // second line: the N elements of A
        tokens = bufferedReader.readLine().trim().split("\\s+");
        int[] A = new int[N];
        for (int i = 0; i < N; i++) {
            A[i] = Integer.parseInt(tokens[i]);
        }
        writer.println(kostomukshaAndAESCMSU.subArrays(N, X, A));
        T--;
    }
    writer.close();
    inputStream.close();
}
Brute force solutions
As benchmarks, the following solutions perform O(N^2) multiplications.
With BigInteger
Use BigInteger to work with arbitrary-size numbers:
import java.math.BigInteger;

public long subArrays4(int N, int X, int[] A) {
    long subArrays = 0;
    BigInteger x = BigInteger.valueOf(X);
    for (int L = 0; L < N; L++) {
        BigInteger product = BigInteger.ONE; // A[L] * ... * A[R], exact
        BigInteger power = BigInteger.ONE;   // X^(R-L+1), exact
        for (int R = L; R < N; R++) {
            product = product.multiply(BigInteger.valueOf(A[R]));
            power = power.multiply(x);
            if (product.equals(power)) {
                subArrays++;
            }
        }
    }
    return subArrays;
}
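With modular arithmetic
Multiplying ints modulo a prime has much less overhead than BigInteger, so it should run faster. It is not strictly exact, though: a product that is congruent to X^(r-l+1) modulo MOD is not necessarily equal to it, so a collision would be counted as a false positive. With a large prime modulus such collisions should be rare, which makes this version a practical reference for random tests.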
private static final int MOD = 1_000_000_007;

// (a * b) mod MOD, computed in long to avoid int overflow
private static int multiply(int a, int b) {
    return (int) ((long) a * b % MOD);
}

public long subArrays3(int N, int X, int[] A) {
    long subArrays = 0;
    for (int L = 0; L < N; L++) {
        int product = 1; // A[L] * ... * A[R] mod MOD
        int power = 1;   // X^(R-L+1) mod MOD
        for (int R = L; R < N; R++) {
            product = multiply(product, A[R]);
            power = multiply(power, X);
            if (product == power) {
                subArrays++;
            }
        }
    }
    return subArrays;
}
Log transform, a failed approach
We apply a logarithmic transform to turn the product into a sum:
Sum(log(A_i), i = l ... r) = (r-l+1) log(X)
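The step from this sum to a map lookup can be written out with prefix sums. This is a sketch of the reasoning; S_i and y_i are names introduced here, and k replaces the summation index i to keep it free for the prefix position.

```latex
\begin{aligned}
S_i &= \sum_{k=1}^{i} \log(A_k), \qquad S_0 = 0\\
\sum_{k=l}^{r} \log(A_k) = (r-l+1)\log(X)
  &\iff S_r - S_{l-1} = r\log(X) - (l-1)\log(X)\\
  &\iff S_r - r\log(X) = S_{l-1} - (l-1)\log(X)\\
  &\iff y_r = y_{l-1}, \qquad \text{where } y_i = S_i - i\log(X)
\end{aligned}
```

So every pair of equal normalized prefixes y_{l-1} = y_r is one solution, and counting them is the prefix sum + caching technique. In the 0-based Java loop, the value computed at index i corresponds to y_{i+1}.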
We call Math.log and count the normalized prefix values in a Map<Double, Integer>:
public long subArrays1(int N, int X, int[] A) {
    long subArrays = 0;
    // normalized prefix value -> number of prefixes seen with that value
    Map<Double, Integer> y_1 = new HashMap<>();
    y_1.put(0D, 1); // the empty prefix
    double s = 0;   // log(A[0]) + ... + log(A[i])
    for (int i = 0; i < N; i++) {
        s += Math.log(A[i]);
        double y = s - (i+1) * Math.log(X); // highlighted line: multiplication by i+1
        int subArrays_i = y_1.getOrDefault(y, 0);
        subArrays += subArrays_i;
        y_1.put(y, subArrays_i+1);
    }
    return subArrays;
}
In the above implementation, I projected the values onto the x = -1 axis as a normalization. On the sample input, it produces an invalid output.
Looking at the map,
we see two keys that should be one and the same, because of a precision error. In subArrays1, the highlighted line, where Math.log(X) is multiplied by i+1, causes the discrepancy: (i+1) * Math.log(X) does not round the same way as the running sum s.
Instead of computing the prefix sum alone and normalizing it afterwards, we accumulate the differences log(A_i) - log(X) directly:
public long subArrays2(int N, int X, int[] A) {
    long subArrays = 0;
    // prefix value -> number of prefixes seen with that value
    Map<Double, Integer> y = new HashMap<>();
    y.put(0D, 1); // the empty prefix
    double s = 0; // sum of (log(A[k]) - log(X)) for k <= i
    for (int i = 0; i < N; i++) {
        s += Math.log(A[i]) - Math.log(X);
        int subArrays_i = y.getOrDefault(s, 0);
        subArrays += subArrays_i;
        y.put(s, subArrays_i+1);
    }
    return subArrays;
}
This produces the same map, this time without the precision error.
But this variant eventually fails as well. Here is a random search for a counterexample, using the modular brute force subArrays3 as the reference:
public void generateTests() {
    Random random = new Random();
    int N = 3;
    int MAX = 10;
    while (true) {
        // random array of N values in [1, MAX], random X in [1, MAX]
        int[] A = new int[N];
        for (int i = 0; i < N; i++) {
            A[i] = 1 + random.nextInt(MAX);
        }
        int X = 1 + random.nextInt(MAX);
        long subArrays2 = subArrays2(N, X, A);
        long subArrays3 = subArrays3(N, X, A);
        // stop at the first input where the log variant disagrees with the reference
        if (subArrays2 != subArrays3) {
            System.err.println(String.format("%d != %d %d %s", subArrays2, subArrays3, X, Arrays.toString(A)));
            break;
        }
    }
}
The search finds the following input, on which subArrays2 returns 0 subarrays while there are 2 matches, [9, 4] and [4, 9] (9 * 4 = 36 = 6^2):
1
3 6
9 4 9
The log approach is exposed to two risks:
- false negatives: two values that should be equal do not end up in the same bucket. For example, they evaluate to 0D and -0D, which have different hash codes and are not equal according to Double.equals.
- false positives: two values end up in the same bucket while they should not. They are mathematically different, but finite precision rounds them to the same double.
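Both risks can be reproduced in isolation with a few lines of standard Java. This sketch does not use logarithms; it only shows the Double key behaviors listed above. DoubleKeyPitfalls and its methods are hypothetical names.

```java
import java.util.HashMap;
import java.util.Map;

public class DoubleKeyPitfalls {
    // False negative: 0.0 == -0.0 as primitives, but Double.equals and
    // Double.hashCode tell them apart, so the lookup misses the bucket.
    static boolean negativeZeroLookupMisses() {
        Map<Double, Integer> counts = new HashMap<>();
        counts.put(0.0, 1);
        return counts.get(-0.0) == null;
    }

    // False negative: mathematically equal values computed along different
    // paths can differ in the last bit.
    static boolean roundingSplitsEqualValues() {
        double a = 0.1, b = 0.2;
        return a + b != 0.3;
    }

    // False positive: mathematically different values can round to the same double.
    static boolean roundingMergesDifferentValues() {
        double tiny = 1e-17;
        return 1.0 + tiny == 1.0;
    }

    public static void main(String[] args) {
        System.out.println(negativeZeroLookupMisses());      // true
        System.out.println(roundingSplitsEqualValues());     // true
        System.out.println(roundingMergesDifferentValues()); // true
    }
}
```

Integer exponents avoid all three cases, because int keys compare exactly.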
Prime Factorization, an exact approach
Reduced problem set
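Assume first that X and all A_i are powers of a single prime p: X = p^e and A_i = p^(e_i). The exponent then plays the role of an exact, integer-valued logarithm in base p: X = exp(log(X)) becomes X = p^(log_p(X)) = p^e.
The subarray [l, r] is a solution if and only if
Sum(e_i, i = l ... r) = (r-l+1) e
Generalization to vector
In general, write X = p_1^(e_1) * ... * p_n^(e_n) and A_i = p_1^(e_i1) * ... * p_n^(e_in). An A_i with a prime factor outside {p_1, ..., p_n} cannot belong to any solution, since that prime would divide the product but not X^(r-l+1). For the remaining elements, the equation must hold for every prime p_j:
Sum(e_ij, i = l ... r) = (r-l+1) e_j, for j = 1 ... n
This is the same equation as above, but with vectors instead of scalars, where e_i = (e_i1, ..., e_in) and e = (e_1, ..., e_n):
Sum(e_i, i = l ... r) = (r-l+1) e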
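A worked example on the third sample test (X = 54, A = [36, 81, 54, 54]), where S_i denotes the prefix sum of e_k - e (a name introduced here). As with logarithms, [l, r] is a solution exactly when S_r = S_{l-1}:

```latex
\begin{aligned}
X &= 54 = 2^1 \cdot 3^3, \qquad e = (1, 3)\\
A &= [36, 81, 54, 54] = [2^2 3^2,\ 3^4,\ 2^1 3^3,\ 2^1 3^3]\\
e_i &= (2,2),\ (0,4),\ (1,3),\ (1,3)\\
e_i - e &= (1,-1),\ (-1,1),\ (0,0),\ (0,0)\\
S_0, \ldots, S_4 &= (0,0),\ (1,-1),\ (0,0),\ (0,0),\ (0,0)
\end{aligned}
```

The vector (0,0) appears 4 times, giving C(4,2) = 6 equal pairs, while (1,-1) appears once and contributes nothing. This gives the expected answer of 6 with exact integer comparisons.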
Implementation details
class ExponentVector {
    // exponents[j] is the exponent of the j-th prime factor of X
    private final int[] exponents;

    public ExponentVector(int n) {
        exponents = new int[n];
    }

    public ExponentVector(int[] exponents) {
        this(exponents.length);
        System.arraycopy(exponents, 0, this.exponents, 0, exponents.length); // defensive copy
    }

    // value equality, so that vectors can be used as HashMap keys
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (!(o instanceof ExponentVector)) return false;
        ExponentVector other = (ExponentVector) o;
        return Arrays.equals(exponents, other.exponents);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(exponents);
    }

    public int size() {
        return exponents.length;
    }
}
PrimeFactorization primeFactorization = new DivisorPrimeFactorization();

public long subArrays5(int N, int X, int[] A) {
    List<PrimeFactor> xFactors = primeFactorization.factors(X);
    ExponentVector x = exponents(xFactors); // exponent vector e of X
    int n = xFactors.size();
    ExponentVector zero = new ExponentVector(n);
    // prefix vector -> number of prefixes seen with that value
    Map<ExponentVector, Integer> y_1 = new HashMap<>();
    y_1.put(zero, 1);
    ExponentVector s = zero; // sum of (e_k - e) since the last reset
    long subArrays = 0;
    for (int i = 0; i < N; i++) {
        int a_i = A[i]; // work on a copy so that A is left unchanged
        int[] exponents = new int[n];
        for (int j = 0; j < n; j++) {
            int prime = xFactors.get(j).getPrime();
            int exponent = 0;
            while (a_i % prime == 0) {
                a_i /= prime;
                exponent++;
            }
            exponents[j] = exponent;
        }
        if (a_i != 1) {
            // A[i] has a prime factor that X lacks: no solution contains i,
            // so the prefix sums restart after i
            y_1.clear();
            y_1.put(zero, 1);
            s = zero;
            continue;
        }
        ExponentVector a = new ExponentVector(exponents);
        s = subtract(add(s, a), x); // add, subtract and exponents are helpers not shown here
        int subArrays_i = y_1.getOrDefault(s, 0);
        subArrays += subArrays_i;
        y_1.put(s, subArrays_i+1);
    }
    return subArrays;
}
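For readers who want to run the exact approach without the PrimeFactorization, PrimeFactor and add/subtract/exponents helpers, here is a self-contained sketch of the same prefix-vector idea. It assumes that trial division is fast enough to factor X once per test case. GeomMeanSketch, Key and distinctPrimes are hypothetical names, and Key mirrors ExponentVector.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class GeomMeanSketch {
    // Exponent vector with value equality, usable as a HashMap key
    // (a bare int[] would compare by identity).
    static final class Key {
        final int[] e;
        Key(int[] e) { this.e = e; }
        @Override public boolean equals(Object o) {
            return o instanceof Key && Arrays.equals(e, ((Key) o).e);
        }
        @Override public int hashCode() { return Arrays.hashCode(e); }
    }

    // Distinct primes of x by trial division (stands in for PrimeFactorization).
    static List<Integer> distinctPrimes(int x) {
        List<Integer> primes = new ArrayList<>();
        for (int p = 2; (long) p * p <= x; p++) {
            if (x % p == 0) {
                primes.add(p);
                while (x % p == 0) x /= p;
            }
        }
        if (x > 1) primes.add(x);
        return primes;
    }

    static long subArrays(int[] A, int X) {
        List<Integer> primes = distinctPrimes(X);
        int n = primes.size();
        int[] xExp = new int[n]; // exponent vector e of X
        for (int j = 0; j < n; j++) {
            int p = primes.get(j);
            for (int x = X; x % p == 0; x /= p) xExp[j]++;
        }
        Map<Key, Integer> seen = new HashMap<>();
        int[] s = new int[n]; // running sum of (e_k - e); each prefix gets its own array
        seen.put(new Key(s), 1);
        long result = 0;
        for (int value : A) {
            int a = value; // local copy: A is not modified
            int[] next = s.clone();
            for (int j = 0; j < n; j++) {
                int p = primes.get(j);
                while (a % p == 0) { a /= p; next[j]++; }
                next[j] -= xExp[j];
            }
            if (a != 1) { // prime outside X: restart after this index
                seen.clear();
                s = new int[n];
                seen.put(new Key(s), 1);
                continue;
            }
            s = next;
            Key key = new Key(s);
            int earlier = seen.getOrDefault(key, 0);
            result += earlier;
            seen.put(key, earlier + 1);
        }
        return result;
    }

    public static void main(String[] args) {
        System.out.println(subArrays(new int[]{3, 3, 3}, 3));          // 6
        System.out.println(subArrays(new int[]{1, 2, 3, 4}, 4));       // 1
        System.out.println(subArrays(new int[]{36, 81, 54, 54}, 54));  // 6
        System.out.println(subArrays(new int[]{9, 4, 9}, 6));          // 2
    }
}
```

On the counterexample found by the random search, it returns the 2 matches that the log variant missed. Working on a local copy of each element, rather than dividing A[i] in place, keeps the input intact for reuse in benchmarks.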
Conclusion
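- simplify the problem: turn the product equation into an equality between prefix values
- reduce the scope of the input: only the prime factors of X matter, and any other prime factor splits the array
- reduce runtime complexity, from O(N^2) down to O(N) map operations, each on a vector with one coordinate per distinct prime factor of X.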