Introduction
This article describes an incorrect approach to the CodeChef SUB_XOR problem from March Long Challenge 1.
It failed task #3 of the test suite.
Let's find out why.
Problem statement
We map a string containing only '0' and '1' characters to the number it represents in binary.
Given such a binary string S, compute the XOR of the numbers represented by all its substrings (every start position and length counts separately), modulo 998 244 353.
Constraints
- 1 <= N <= 10^5, where N is the length of S
- 1 <= T <= 10^2, where T is the number of test cases
Input Parsing
package codechef.mar22;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.PrintWriter;
import java.util.Random;

class SubarrayXOR {
    private static final int MOD = 998_244_353;

    public static void main(String[] args) throws Exception {
        // The judge feeds the input through standard input.
        // For local testing, read the sample file instead:
        // InputStream inputStream = new FileInputStream("SUB_XOR");
        InputStream inputStream = System.in;
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(inputStream));
        PrintWriter writer = new PrintWriter(new BufferedOutputStream(System.out));
        String[] tokens;
        tokens = bufferedReader.readLine().split(" ");
        int T = Integer.parseInt(tokens[0]);
        while (T > 0) {
            tokens = bufferedReader.readLine().split(" ");
            int N = Integer.parseInt(tokens[0]); // length of S (beauty uses S.length())
            tokens = bufferedReader.readLine().split(" ");
            String S = tokens[0];
            SubarrayXOR subarrayXOR = new SubarrayXOR();
            writer.println(subarrayXOR.beauty(S));
            T--;
        }
        writer.close();
        inputStream.close();
    }

    // beauty(String) and the other methods shown below are members of this class.
}
Input
3
2
10
3
101
4
1111
Output
3
6
12
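The six substrings of the second sample, S = 101, can be checked by hand. Each substring's value and their XOR are shown below, and the result matches the expected output 6:

```latex
\begin{aligned}
&\texttt{1} \to 1,\quad \texttt{0} \to 0,\quad \texttt{1} \to 1,\quad
 \texttt{10} \to 2,\quad \texttt{01} \to 1,\quad \texttt{101} \to 5 \\
&1 \oplus 0 \oplus 1 \oplus 2 \oplus 1 \oplus 5 = 6
\end{aligned}
```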
Brute-force solution
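// Reference implementation: XOR of all substring values, computed directly.
// Only valid for N <= 31, since each substring value must fit in an int.
public int beauty2(String S) {
    int N = S.length();
    int xor = 0;
    for (int n = 1; n <= N; n++) {            // substring length
        for (int i = 0; i < N-n+1; i++) {     // substring start index
            int subStringNumber = toNumber(S, i, n);
            xor ^= subStringNumber;
        }
    }
    return xor % MOD; // reduce like the expected output
}

// Value of the substring S[i..i+n-1]; its last character is the least significant bit.
private int toNumber(String S, int i, int n) {
    int number = 0;
    for (int j = i+n-1; j >= i; j--) {
        if (S.charAt(j) == '1') {
            number += 1 << (i+n-1 - j);
        }
    }
    return number;
}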
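For every substring length n, 1 <= n <= N, and every start index i, 0 <= i < N-n+1, convert the substring to its number and XOR it into the result.
The last character of a substring is its least significant bit. Therefore toNumber scans the substring from right to left, giving weight 2^0 to the last character, 2^1 to the one before it, and so on.
There are O(N^2) substrings, and each conversion costs O(N). The overall complexity is therefore O(N^3), which times out for N up to 10^5. Because each substring value is held in an int, this version is also only usable on short strings, as a reference for small inputs.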
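A reference that also works beyond 31 characters can be built with java.math.BigInteger, which has no size limit. This is an assumption on my part, not part of the original solution, and the class name BruteForceBig is hypothetical. It is still roughly cubic, so it is only meant for cross-checking small inputs:

```java
import java.math.BigInteger;

final class BruteForceBig {
    static final BigInteger MOD = BigInteger.valueOf(998_244_353);

    // XOR of the values of all substrings of s, reduced modulo MOD at the end.
    // Hypothetical cross-checking helper; its cost grows roughly as N^3.
    static int beauty(String s) {
        BigInteger xor = BigInteger.ZERO;
        int n = s.length();
        for (int start = 0; start < n; start++) {
            for (int end = start + 1; end <= n; end++) {
                // new BigInteger(String, 2) parses a binary string; leading zeros are allowed.
                xor = xor.xor(new BigInteger(s.substring(start, end), 2));
            }
        }
        return xor.mod(MOD).intValue();
    }
}
```

On the three sample strings it returns 3, 6 and 12.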
My solution
private static final ModularArithmetics MODULAR_ARITHMETICS = new ModularArithmetics(MOD);

public int beauty(String S) {
    int N = S.length();
    int count = 0;
    int sum = 0;
    int[] c = new int[N]; // c[j]: number of '1' characters in S[0..j]
    int[] s = new int[N]; // s[j]: c[0] + ... + c[j-1]   <- problematic line 1
    for (int i = 0; i < N; i++) {
        s[i] = sum;
        boolean bit = S.charAt(i) == '1';
        if (bit) {
            count++;
        }
        sum = MODULAR_ARITHMETICS.add(sum, count); // <- problematic line 2
        c[i] = count;
    }
    int beauty = 0;
    for (int i = 0; i < N; i++) {
        int xor = MODULAR_ARITHMETICS.substract(MODULAR_ARITHMETICS.multiply(N-i, c[N-1-i]), s[N-1-i]); // <- problematic line 3
        beauty = MODULAR_ARITHMETICS.add(beauty, (xor % 2) * MODULAR_ARITHMETICS.exponent(2, i));
    }
    return beauty;
}
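We count the substrings in which a '1' character lands on bit i, 0 <= i < N, where N is the length of the input string and bit 0 is the least significant bit. The editorial denotes this number cnt_i. A '1' at index j lands on bit i exactly in the substrings that end at index j+i and start at or before j. It therefore contributes j+1 to cnt_i for every j <= N-1-i. If cnt_i is odd, we set the i-th bit of the output value.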
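Evaluating cnt_i = Sum_{0 <= j <= N-1-i} (j+1) * S[j] straight from its definition makes the parity rule concrete. The quadratic sketch below uses hypothetical class and method names and is only an illustration for small inputs, not the submitted solution:

```java
final class DirectCount {
    static final int MOD = 998_244_353;

    // cnt[i] = sum over 0 <= j <= N-1-i of (j + 1) * S[j]:
    // a '1' at index j sits on bit i in the j + 1 substrings ending at index j + i.
    static long[] counts(String s) {
        int n = s.length();
        long[] cnt = new long[n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j <= n - 1 - i; j++) {
                if (s.charAt(j) == '1') {
                    cnt[i] += j + 1;
                }
            }
        }
        return cnt;
    }

    // Bit i of the XOR is set iff cnt[i] is odd; the answer is reduced modulo MOD.
    static int beauty(String s) {
        long[] cnt = counts(s);
        long answer = 0;
        long pow = 1; // 2^i mod MOD
        for (int i = 0; i < cnt.length; i++) {
            if (cnt[i] % 2 == 1) {
                answer = (answer + pow) % MOD;
            }
            pow = pow * 2 % MOD;
        }
        return (int) answer;
    }
}
```

For S = 101 the counts are [4, 1, 1]. Bits 1 and 2 are therefore set, which gives 6.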
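I applied the following formula:
cnt_i = (N-i) * c[N-1-i] - s[N-1-i]
with
- c[j] = Sum_{0 <= k <= j} S[k], the number of '1' characters in S[0..j] (S[k] counts as 1 for '1' and 0 for '0'),
- s[j] = Sum_{0 <= k < j} c[k], the sum of the first j values of c.
Both are prefix arrays, filled in a single left-to-right pass.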
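The formula follows from the definition of cnt_i by splitting the weight j+1. Write m = N-1-i and treat S[j] as 0 or 1:

```latex
\begin{aligned}
\mathrm{cnt}_i &= \sum_{j=0}^{m} (j+1)\,S[j] = \sum_{j=0}^{m} \bigl((m+1) - (m-j)\bigr)\,S[j] \\
&= (m+1)\,c[m] - \sum_{j=0}^{m} (m-j)\,S[j] = (m+1)\,c[m] - \sum_{k=0}^{m-1} c[k] \\
&= (m+1)\,c[m] - s[m] = (N-i)\,c[N-1-i] - s[N-1-i]
\end{aligned}
```

The second line holds because S[j] is counted once in each of c[j], ..., c[m-1], that is, m-j times.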
Space complexity is O(N). Runtime is O(N log N) rather than O(N), because exponent(2, i) performs O(log i) multiplications for each of the N bits.
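Three lines of the code above are problematic: the declaration int[] s, the update sum = MODULAR_ARITHMETICS.add(sum, count), and the computation of xor with multiply and substract. Compare them with the fixed solution section below.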
Editorial solution
public int beauty3(String S) {
    int N = S.length();
    long[] cnt = new long[N];
    for(int i = 0; i< N; i++)
        if(S.charAt(i) == '1')
            cnt[N-i-1] += (1+i);//Adding contribution of on bits
    for(int i = N-2; i>= 0; i--)
        cnt[i] += cnt[i+1]; // Taking suffix sum to recover cnt
    //Converting cnt to decimal number
    long ans = 0, f = 1;
    for(int i = 0; i < N; i++){
        cnt[i] %= 2; // Only the parity of count matters
        ans += f*cnt[i]%MOD;
        if(ans >= MOD)
            ans -= MOD;
        f = (f*2)%MOD;
    }
    return (int) ans;
}
Generating test data
private static final int MAX_N = 100_000;

// Compares beauty and beauty3 on random binary strings until they disagree,
// then prints the failing input. Requires import java.util.Random.
public void generateTests() {
    Random random = new Random();
    while (true) {
        int N = 1 + random.nextInt(MAX_N);
        char[] c = new char[N];
        for (int i = 0; i < N; i++) {
            c[i] = random.nextBoolean() ? '1' : '0';
        }
        String S = new String(c);
        int b1 = beauty(S);
        int b3 = beauty3(S);
        if (b1 != b3) {
            System.err.println(S);
            break;
        }
    }
}
To find an input on which our approach and the editorial's diverge, we generate random binary strings. We compare the results of both implementations on each string until they disagree.
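A bounded, reproducible variant of this search is sketched below as an alternative design. The class and method names are hypothetical. It takes the two implementations as functions and uses a fixed seed, so that a failing string can be regenerated. It also stops after a given number of trials instead of looping forever:

```java
import java.util.Random;
import java.util.function.ToIntFunction;

final class DifferentialTester {
    // Generates up to `trials` random binary strings of length 1..maxN from a fixed seed
    // and returns the first one on which the two implementations disagree, or null.
    static String findMismatch(ToIntFunction<String> first, ToIntFunction<String> second,
                               int maxN, int trials, long seed) {
        Random random = new Random(seed);
        for (int t = 0; t < trials; t++) {
            int n = 1 + random.nextInt(maxN);
            char[] c = new char[n];
            for (int i = 0; i < n; i++) {
                c[i] = random.nextBoolean() ? '1' : '0';
            }
            String s = new String(c);
            if (first.applyAsInt(s) != second.applyAsInt(s)) {
                return s;
            }
        }
        return null;
    }
}
```

With a SubarrayXOR instance named solver, a call could look like findMismatch(solver::beauty, solver::beauty3, 100_000, 1_000, 42L).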
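Bugfix
The prefix formula is correct; the arithmetic is not. Only the parity of cnt_i matters, but the failing version reduces sum, the product and the difference modulo MOD, so it ends up testing the parity of cnt_i mod MOD. Because MOD = 998 244 353 is odd, x mod MOD and x have different parities whenever floor(x / MOD) is odd. With N up to 10^5, cnt_i can reach N(N+1)/2, about 5 * 10^9, well above MOD. The fix computes the counts exactly in 64-bit long arithmetic, which cannot overflow at these sizes, and applies the modulus only to the final answer.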
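Consider a concrete case, assumed here for illustration: an all-ones string of length N = 10^5. Every substring ends with '1', so cnt_0 = 1 + 2 + ... + N = 5 000 050 000, which is even. The sketch below (hypothetical class name) computes this count exactly and again reduced modulo MOD at every step. The reduced value equals the one produced by the failing version's modular arithmetic, and it is odd, so bit 0 would be wrongly set:

```java
final class ParityDemo {
    static final int MOD = 998_244_353;

    // cnt_0 of an all-ones string of length n, computed exactly in 64-bit arithmetic.
    static long exactCount(int n) {
        long count = 0;
        for (int j = 0; j < n; j++) {
            count += j + 1;
        }
        return count;
    }

    // The same count, reduced modulo MOD after every addition.
    static long reducedCount(int n) {
        long count = 0;
        for (int j = 0; j < n; j++) {
            count = (count + j + 1) % MOD;
        }
        return count;
    }
}
```

For n = 100 000, exactCount returns 5 000 050 000 (even) and reducedCount returns 8 828 235 (odd).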
Fixed Solution
public int beauty(String S) {
    int N = S.length();
    int count = 0;
    long sum = 0; // exact prefix sum of c; can reach about 5 * 10^9, so it needs 64 bits
    int[] c = new int[N];
    long[] s = new long[N];
    for (int i = 0; i < N; i++) {
        s[i] = sum;
        boolean bit = S.charAt(i) == '1';
        if (bit) {
            count++;
        }
        sum += count; // no modular reduction: it would change the parity
        c[i] = count;
    }
    int beauty = 0;
    for (int i = 0; i < N; i++) {
        // exact number of substrings with a '1' on bit i
        long bitCount = 1L * (N-i) * c[N-1-i] - s[N-1-i];
        beauty = MODULAR_ARITHMETICS.add(beauty, ((int) (bitCount % 2)) * MODULAR_ARITHMETICS.exponent(2, i));
    }
    return beauty;
}
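The fixed version still calls exponent(2, i) for every bit. Below is a minimal sketch of a variant that keeps the same prefix formula and exact 64-bit counts but maintains a running power of two instead, so the runtime is O(N). The class name SubXorLinear is hypothetical, and MOD is inlined rather than taken from ModularArithmetics:

```java
final class SubXorLinear {
    static final int MOD = 998_244_353;

    static int beauty(String s) {
        int n = s.length();
        long[] c = new long[n];   // c[j]: number of '1' characters in s[0..j]
        long[] pre = new long[n]; // pre[j]: c[0] + ... + c[j-1]
        long count = 0;
        long sum = 0;
        for (int j = 0; j < n; j++) {
            pre[j] = sum;
            if (s.charAt(j) == '1') {
                count++;
            }
            c[j] = count;
            sum += count;
        }
        long answer = 0;
        long pow = 1; // 2^i mod MOD, updated incrementally
        for (int i = 0; i < n; i++) {
            long cnt = (long) (n - i) * c[n - 1 - i] - pre[n - 1 - i]; // exact cnt_i
            if ((cnt & 1) == 1) {
                answer = (answer + pow) % MOD;
            }
            pow = pow * 2 % MOD;
        }
        return (int) answer;
    }
}
```

For the sample strings 10, 101 and 1111 it returns 3, 6 and 12.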
Conclusion
Appendix
The ModularArithmetics helper used above is implemented as follows:
/**
 * Modular arithmetic modulo a fixed m.
 */
public class ModularArithmetics {
    private final int m;

    public ModularArithmetics(int m) {
        this.m = m;
    }

    /**
     * Right-to-left binary exponentiation.
     *
     * @param b base
     * @param e exponent, e >= 0
     * @return b^e mod m
     */
    public int exponent(int b, int e) {
        int pow = 1;
        int base = b % m;
        int exponent = e;
        while (exponent > 0) {
            if ((exponent & 1) != 0) {
                pow = multiply(pow, base);
            }
            exponent >>= 1;
            base = multiply(base, base);
        }
        return pow;
    }

    /**
     * Multiplies in 64-bit arithmetic to avoid int overflow.
     *
     * @param a first factor, 0 <= a < m
     * @param b second factor, 0 <= b < m
     * @return (a * b) mod m
     */
    public int multiply(int a, int b) {
        return (int) ((1L * a * b) % m);
    }

    /**
     * @param a value to square, 0 <= a < m
     * @return (a * a) mod m
     */
    public int square(int a) {
        return multiply(a, a);
    }

    /**
     * @param a first term, 0 <= a < m
     * @param b second term, 0 <= b < m
     * @return (a + b) mod m
     */
    public int add(int a, int b) {
        return (a + b) % m;
    }

    /**
     * @param a minuend, 0 <= a < m
     * @param b subtrahend, 0 <= b < m
     * @return (a - b) mod m, in [0, m)
     */
    public int substract(int a, int b) {
        return positive(a-b);
    }

    /**
     * @param a any int
     * @return a reduced modulo m, in [0, m)
     */
    private int positive(int a) {
        return (a % m + m) % m;
    }

    /**
     * @param a value coprime with m
     * @return u such that a * u = 1 (mod m)
     */
    public int inverse(int a) {
        // a * u + m * v = 1, m is prime
        int u = bezoutCoefficient(a, m);
        return positive(u);
    }

    /**
     * Extended Euclidean algorithm.
     *
     * @param a first operand
     * @param b second operand
     * @return u such that a * u + b * v = gcd(a, b) for some v
     */
    private int bezoutCoefficient(int a, int b) {
        int s = 0, old_s = 1;
        int t = 1, old_t = 0;
        int r = b, old_r = a;
        int prov;
        while (r != 0) {
            int q = old_r / r;
            prov = r;
            r = old_r - q * prov;
            old_r = prov;
            prov = s;
            s = old_s - q * prov;
            old_s = prov;
            prov = t;
            t = old_t - q * prov;
            old_t = prov;
        }
        return old_s;
    }
}
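When testing this helper, the JDK's java.math.BigInteger can provide independent reference values through modPow and modInverse. The sketch below (hypothetical class name) is one way to produce them; it is not part of the solution:

```java
import java.math.BigInteger;

final class ModularReference {
    static final BigInteger M = BigInteger.valueOf(998_244_353);

    // Reference value for exponent(b, e): b^e mod M.
    static int exponent(int b, int e) {
        return BigInteger.valueOf(b).modPow(BigInteger.valueOf(e), M).intValue();
    }

    // Reference value for inverse(a): u with a * u = 1 (mod M).
    // BigInteger.modInverse throws ArithmeticException when gcd(a, M) != 1.
    static int inverse(int a) {
        return BigInteger.valueOf(a).modInverse(M).intValue();
    }
}
```

For example, comparing new ModularArithmetics(998_244_353).exponent(2, i) with ModularReference.exponent(2, i) over a range of i is a quick sanity check.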