Thursday, March 24, 2022

Codechef SUB_XOR

Introduction

This article describes an incorrect approach to the Codechef SUB_XOR problem from the March Long Challenge 1.

It failed task #3 in the test suite.

Let's find out why.


Problem statement

We map a string containing only '0' and '1' characters to the number it represents in binary.
Given such a binary string, find the XOR of the numbers represented by all its substrings, modulo 998 244 353.


Constraints

  • 1 <= N <= 10^5, where N is the length of the binary string
  • 1 <= T <= 10^2, where T is the number of test cases


Input Parsing

package codechef.mar22;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;

class SubarrayXOR {
private static final int MOD = 998_244_353;

public static void main(String[] args) throws Exception {
// Use System.in when submitting; the local file "SUB_XOR" is only for testing.
//InputStream inputStream = System.in;
InputStream inputStream = new FileInputStream("SUB_XOR");
BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
PrintWriter writer = new PrintWriter(new BufferedOutputStream(System.out));

String[] tokens;
tokens = bufferedReader.readLine().split(" ");
int T = Integer.parseInt(tokens[0]);
while (T > 0) {
tokens = bufferedReader.readLine().split(" ");
int N = Integer.parseInt(tokens[0]);
tokens = bufferedReader.readLine().split(" ");
String S = tokens[0];

SubarrayXOR subarrayXOR = new SubarrayXOR();
writer.println(subarrayXOR.beauty(S));
T--;
}
writer.close();
inputStream.close();
}
}


Input

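Here's an input sample. The first line holds T; each test case then gives N on one line and the binary string S on the next: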
3
2
10
3
101
4
1111


Output

Here's the output:
3
6
12


Brute-force solution

public int beauty2(String S) {
int N = S.length();
int xor = 0;
for (int n = 1; n <= N; n++) {
for (int i = 0; i < N-n+1; i++) {
int subStringNumber = toNumber(S, i, n);
xor ^= subStringNumber;
}
}
return xor;
}

private int toNumber(String S, int i, int n) {
int number = 0;
for (int j = i+n-1; j >= i; j--) {
if (S.charAt(j) == '1') {
number += 1 << (i+n-1 - j);
}
}
return number;
}

For all substring sizes 1 <= n <= N, convert the substring of size n starting at index i, 0 <= i < N-n+1, to the associated number and XOR it into the result.

Make sure to scan the characters from right to left, starting with the last character: the string is indexed from left to right, but its last character holds the least significant bit.

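Complexity is O(N^3), which times out for N up to 10^5. Moreover, toNumber builds each value in an int and the XOR is never reduced modulo 998 244 353, so beauty2 is only guaranteed to match the expected output on strings of at most 29 characters, whose substring values stay below 2^29 < 998 244 353.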
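A minimal sketch of an exact brute force (the class SubXorBruteForce and its method beautyBig are hypothetical, not from the original post): it parses each substring with java.math.BigInteger, so values of any length are represented exactly, and reduces the XOR modulo 998 244 353 only at the end. It is still at least cubic, so it only serves to check small cases such as the sample above.

```java
import java.math.BigInteger;

class SubXorBruteForce {
    private static final BigInteger MOD = BigInteger.valueOf(998_244_353);

    // XOR of the values of all substrings, computed exactly with BigInteger
    // and reduced modulo MOD only at the end. At least O(N^3): small inputs only.
    static int beautyBig(String s) {
        int n = s.length();
        BigInteger xor = BigInteger.ZERO;
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j <= n; j++) {
                // new BigInteger(text, 2) parses the substring S[i..j-1] as a binary number
                xor = xor.xor(new BigInteger(s.substring(i, j), 2));
            }
        }
        return xor.mod(MOD).intValue();
    }

    public static void main(String[] args) {
        System.out.println(beautyBig("10"));   // 3
        System.out.println(beautyBig("101"));  // 6
        System.out.println(beautyBig("1111")); // 12
    }
}
```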


My solution

private static final ModularArithmetics MODULAR_ARITHMETICS = new ModularArithmetics(MOD);

public int beauty(String S) {
int N = S.length();

int count = 0;
int sum = 0;
int[] c = new int[N];
int[] s = new int[N];
for (int i = 0; i < N; i++) {
s[i] = sum;
boolean bit = S.charAt(i) == '1';
if (bit) {
count++;
}
sum = MODULAR_ARITHMETICS.add(sum, count);
c[i] = count;
}

int beauty = 0;
for (int i = 0; i < N; i++) {
int xor = MODULAR_ARITHMETICS.substract(MODULAR_ARITHMETICS.multiply(N-i, c[N-1-i]), s[N-1-i]);
beauty = MODULAR_ARITHMETICS.add(beauty, (xor % 2) * MODULAR_ARITHMETICS.exponent(2, i));
}
return beauty;
}
ModularArithmetics is a utility to add, subtract, multiply and exponentiate values modulo a prime number. See appendix.

For each bit position i, 0 <= i < N, where N is the input string size and bit 0 is the least significant one, we count the substrings that have a '1' character at bit i. This number is denoted cnt_i in the editorial. If it's odd, we set the i-th bit of the output value, since XOR keeps a bit exactly when it appears an odd number of times.

I applied the following formula:

cnt_i = (N-i) * c[N-1-i] - s[N-1-i]

with

  • c[j] = Sum_{0 <= k <= j} S[k], the number of '1' characters in S[0..j] (reading each S[k] as 0 or 1)
  • s[j] = Sum_{0 <= k < j} c[k]

It relies on the prefix count and prefix sum arrays defined above.
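Why the formula holds, as a short derivation (the position index p and the shorthand J = N-1-i are introduced here for the argument only). Bit i of the substring S[l..r] is the character S[r-i], which exists only if l <= r-i. So a '1' at position p sets bit i in exactly the substrings ending at r = p+i (which requires p+i <= N-1) and starting at any l in 0..p, that is p+1 substrings:

```latex
\begin{aligned}
\mathrm{cnt}_i &= \sum_{p=0}^{J} (p+1)\,S[p], \qquad J = N-1-i,\\
s[J] &= \sum_{k=0}^{J-1} c[k] = \sum_{k=0}^{J-1}\sum_{p=0}^{k} S[p] = \sum_{p=0}^{J} (J-p)\,S[p],\\
(N-i)\,c[J] - s[J] &= \sum_{p=0}^{J} \bigl((J+1) - (J-p)\bigr)\,S[p] = \sum_{p=0}^{J} (p+1)\,S[p] = \mathrm{cnt}_i .
\end{aligned}
```

The editorial's cnt array computes the same sum directly: each '1' at position p adds p+1 to cnt[N-1-p], and the suffix sum then collects every p <= N-1-i.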

Complexity is O(N) for both runtime and space complexities.

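Three lines of the code above are problematic: the declaration of s as an int[] array, the modular update sum = MODULAR_ARITHMETICS.add(sum, count), and the modular computation of xor in the second loop. Compare them with the Fixed Solution section below.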


Editorial solution

See editorial for a description of the solution.
public int beauty3(String S) {
int N = S.length();
long[] cnt = new long[N];
for(int i = 0; i< N; i++)
if(S.charAt(i) == '1')
cnt[N-i-1] += (1+i);//Adding contribution of on bits

for(int i = N-2; i>= 0; i--)
cnt[i] += cnt[i+1]; // Taking suffix sum to recover cnt

//Converting cnt to decimal number
long ans = 0, f = 1;
for(int i = 0; i < N; i++){
cnt[i] %= 2; // Only the parity of count matters
ans += f*cnt[i]%MOD;
if(ans >= MOD)
ans -= MOD;
f = (f*2)%MOD;
}
return (int) ans;
}
One quick win when comparing my approach with the editorial:

To compute 2^i [m], 0 <= i < N, just multiply the previous power 2^(i-1) [m] by 2 and apply the mod operation in O(1), instead of doing modular exponentiation in O(log(N)).
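A minimal sketch of that quick win (the class PowersOfTwo and method powersOfTwo are hypothetical): all powers of two modulo MOD are precomputed in a single O(N) pass, each derived from the previous one.

```java
class PowersOfTwo {
    private static final int MOD = 998_244_353;

    // pow2[i] = 2^i mod MOD for 0 <= i < n, each entry derived from the previous one in O(1).
    static int[] powersOfTwo(int n) {
        int[] pow2 = new int[n];
        long p = 1;
        for (int i = 0; i < n; i++) {
            pow2[i] = (int) p;
            p = p * 2 % MOD; // p < MOD, so p * 2 cannot overflow
        }
        return pow2;
    }

    public static void main(String[] args) {
        int[] pow2 = powersOfTwo(31);
        System.out.println(pow2[10]); // 1024
        System.out.println(pow2[30]); // 75497471, i.e. 2^30 - 998244353
    }
}
```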

But it's not obvious what went wrong in my approach.


Generating test data

We can use the editorial implementation above as a benchmark to troubleshoot the bug in our solution.
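import java.util.Random;

private static final int MAX_N = 100_000;

// Stress test: compare beauty (mine) with beauty3 (editorial) on random strings
// until they disagree, then print the offending input.
public void generateTests() {
Random random = new Random();
while (true) {
int N = 1 + random.nextInt(MAX_N); // random length in [1, MAX_N]
char[] c = new char[N];
for (int i = 0; i < N; i++) {
c[i] = random.nextBoolean() ? '1' : '0';
}

String S = new String(c);
int b1 = beauty(S);
int b3 = beauty3(S);
if (b1 != b3) {
System.err.println(S);
break;
}
}
}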

We generate random binary strings and compare the results of both implementations on each one until we find a discrepancy. This yields an input on which our approach and the editorial's diverge.


Bugfix

Note that the cnt array is of type long in the editorial solution, while the s array in my solution is of type int and is updated with the modular operations *, + and -.

The bug was to compute the count modulo the prime number and store it in an int, rather than computing it exactly in a long. Applying the mod operation did avoid overflow, but the real problem is parity: the code ends up with cnt_i mod MOD, which differs from cnt_i by adding or subtracting MOD some number of times. Since MOD is odd, doing so an odd number of times flips the parity, which can happen as soon as the exact count exceeds MOD.
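In symbols, writing q for the quotient of cnt_i by MOD (q is introduced here for the argument):

```latex
\mathrm{cnt}_i \bmod \mathrm{MOD} = \mathrm{cnt}_i - q \cdot \mathrm{MOD},
\qquad q = \left\lfloor \mathrm{cnt}_i / \mathrm{MOD} \right\rfloor,
\qquad
\mathrm{MOD} \text{ odd} \;\Rightarrow\; \mathrm{cnt}_i \bmod \mathrm{MOD} \equiv \mathrm{cnt}_i - q \pmod 2 .
```

So the reduced value keeps the parity of cnt_i only when q is even, and q can only be nonzero once cnt_i >= MOD.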

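cnt_i can indeed exceed MOD: it is at most N(N+1)/2, about 5 * 10^9 for N = 10^5, a bound reached by cnt_0 on a string made only of '1' characters.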

On a reproduction case, I had

cnt[0] = 1 086 784 617, which is odd
       =    88 540 264 [998 244 353], which is even

so the parity got changed by the modulo operation.
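The arithmetic can be checked in a few lines. The sketch below (the class ParityCheck and method parityFlips are hypothetical) also evaluates the largest possible count, cnt_0 for a string of 10^5 '1' characters, which equals N(N+1)/2: it overflows an int but fits in a long, and its parity is flipped by the reduction as well.

```java
class ParityCheck {
    private static final long MOD = 998_244_353L;

    // True when reducing a non-negative count modulo MOD changes its parity.
    static boolean parityFlips(long cnt) {
        return cnt % 2 != (cnt % MOD) % 2;
    }

    public static void main(String[] args) {
        // Count observed on the reproduction case: odd before reduction, even after
        long cnt = 1_086_784_617L;
        System.out.println(cnt % MOD);          // 88540264
        System.out.println(parityFlips(cnt));   // true

        // Worst case: all-'1' string with N = 10^5, bit 0, cnt_0 = N(N+1)/2
        long n = 100_000L;
        long worst = n * (n + 1) / 2;
        System.out.println(worst);                      // 5000050000
        System.out.println(worst > Integer.MAX_VALUE);  // true
        System.out.println(parityFlips(worst));         // true (quotient by MOD is 5, an odd number)
    }
}
```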


Fixed Solution

public int beauty(String S) {
int N = S.length();

int count = 0;
long sum = 0;
int[] c = new int[N];
long[] s = new long[N];
for (int i = 0; i < N; i++) {
s[i] = sum;
boolean bit = S.charAt(i) == '1';
if (bit) {
count++;
}
sum += count;
c[i] = count;
}

int beauty = 0;
for (int i = 0; i < N; i++) {
long bitCount = 1L * (N-i) * c[N-1-i] - s[N-1-i];
beauty = MODULAR_ARITHMETICS.add(beauty, ((int) (bitCount % 2)) * MODULAR_ARITHMETICS.exponent(2, i));
}
return beauty;
}
Notice that we now compute exact values for the bit counts.
The mod operation is only needed to fit the output value, which spans up to N bits, in an int; the counts themselves are kept exact in long variables.

Another callout: the ModularArithmetics utility seems overkill just to add 2 numbers or compute powers of 2.

See the new successful submission.
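Following up on the ModularArithmetics callout, here is a hedged sketch of a leaner variant (the class SubXorParity is hypothetical, not the submitted code). As the editorial's suffix sum shows, cnt_i is the sum of p+1 over the positions p <= N-1-i holding a '1'. Since only its parity matters, and parity, unlike a residue modulo the odd MOD, is preserved by addition modulo 2, a running boolean can replace the exact counts, while 2^i [MOD] is built incrementally.

```java
class SubXorParity {
    private static final int MOD = 998_244_353;

    // Parity-only variant: cnt_i = sum of (p + 1) over '1' positions p <= N-1-i,
    // and only cnt_i mod 2 is needed, so a running boolean replaces the exact counts.
    static int beauty(String s) {
        int n = s.length();
        // prefixParity[j] = parity of the sum of (p + 1) over '1' positions p <= j
        boolean[] prefixParity = new boolean[n];
        boolean parity = false;
        for (int p = 0; p < n; p++) {
            // (p + 1) is odd exactly when p is even; only odd terms flip the parity
            if (s.charAt(p) == '1' && p % 2 == 0) {
                parity = !parity;
            }
            prefixParity[p] = parity;
        }
        long answer = 0;
        long pow2 = 1; // 2^i mod MOD, updated in O(1) per bit
        for (int i = 0; i < n; i++) {
            if (prefixParity[n - 1 - i]) { // cnt_i is odd: bit i is set
                answer = (answer + pow2) % MOD;
            }
            pow2 = pow2 * 2 % MOD;
        }
        return (int) answer;
    }

    public static void main(String[] args) {
        System.out.println(beauty("10"));   // 3
        System.out.println(beauty("101"));  // 6
        System.out.println(beauty("1111")); // 12
    }
}
```

The design choice is to avoid large intermediate values altogether, so neither int overflow nor modular reduction can touch the parity.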




Conclusion

MOD = 998 244 353 is a large prime number, hence odd, so reducing a count modulo MOD does not preserve its parity: you should not add or subtract in modular arithmetic when you need an exact parity. Compute exact values instead, using the long type when int would overflow.


Appendix

The ModularArithmetics implementation is:

/**
* Modular arithmetics.
*/
public class ModularArithmetics {

private final int m;
public ModularArithmetics(int m) {
this.m = m;
}

/**
* Right-to-left binary method.
*
* @param b Base
* @param e Exponent
* @return b^e [m]
*/
public int exponent(int b, int e) {
int pow = 1;

int base = b % m;
int exponent = e;

while (exponent > 0) {
if ((exponent & 1) != 0) {
pow = multiply(pow, base);
}
exponent >>= 1;
base = multiply(base, base);
}

return pow;
}

/**
* Handles overflow.
*
* @param a
* @param b
* @return a * b % m
*/
public int multiply(int a, int b) {
return (int) ((1L * a * b) % m);
}

/**
*
* @param a
* @return a^2 % m
*/
public int square(int a) {
return multiply(a, a);
}

/**
*
* @param a
* @param b
* @return (a + b) % m, assuming 0 <= a, b < m
*/
public int add(int a, int b) {
return (a + b) % m;
}

/**
*
* @param a
* @param b
* @return (a - b) mod m, normalized to [0, m)
*/
public int substract(int a, int b) {
return positive(a-b);
}

/**
*
* @param a
* @return a mod m, normalized to [0, m)
*/
private int positive(int a) {
return (a % m + m) % m;
}

/**
*
* @param a
* @return u such that a * u = 1 [m]
*/
public int inverse(int a) {
// a * u + m * v = 1, m is prime
int u = bezoutCoefficient(a, m);
return positive(u);
}

/**
* u such that
* a * u + b * v = gcd(a, b)
*
* @param a
* @param b
* @return the Bezout coefficient u
*/
private int bezoutCoefficient(int a, int b) {
int s = 0, old_s = 1;
int t = 1, old_t = 0;
int r = b, old_r = a;
int prov;

while (r != 0) {
int q = old_r / r;

prov = r;
r = old_r - q * prov;
old_r = prov;

prov = s;
s = old_s - q * prov;
old_s = prov;

prov = t;
t = old_t - q * prov;
old_t = prov;
}

return old_s;
}
}