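For each pair $(i,k)$ with $i<k$, let $num[i][k]$ be the number of indices $j$ such that $i<j<k$ and $A_i+A_j+A_k=0$. If $ans[i][k]$ denotes the answer for the subarray $(A_i,A_{i+1},\ldots,A_k)$, we can write

$$ans[i][k]=ans[i+1][k]+ans[i][k-1]-ans[i+1][k-1]+num[i][k].$$

This is inclusion-exclusion. $ans[i+1][k]$ counts the triples in $[i,k]$ that do not use $A_i$, and $ans[i][k-1]$ counts those that do not use $A_k$. Triples that use neither are counted twice, so we subtract $ans[i+1][k-1]$. Finally, $num[i][k]$ adds the triples that use both $A_i$ and $A_k$. If we take $ans$ of any interval with fewer than three elements to be $0$ and process $i$ in decreasing order, all of $ans$ is computed in $O(N^2)$ time once $num$ is known. After that, each query is answered in $O(1)$.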
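The recurrence can be checked on small inputs with a minimal Java sketch like the one below. This is not the solution itself. The class name RecurrenceSketch and its methods are hypothetical, indices are 0-based, and $num$ is computed naively in $O(N^3)$ purely so that the recurrence can be compared against a brute-force count on every interval.

```java
public class RecurrenceSketch {
    // Brute force: triples l <= i < j < k <= r (0-indexed, inclusive) with A[i]+A[j]+A[k] == 0.
    static long bruteCount(int[] a, int l, int r) {
        long cnt = 0;
        for (int i = l; i <= r; ++i)
            for (int j = i + 1; j <= r; ++j)
                for (int k = j + 1; k <= r; ++k)
                    if ((long) a[i] + a[j] + a[k] == 0) ++cnt;
        return cnt;
    }

    // num[i][k] = number of j with i < j < k and A[i]+A[j]+A[k] == 0 (naive, testing only).
    static long[][] naiveNum(int[] a) {
        int n = a.length;
        long[][] num = new long[n][n];
        for (int i = 0; i < n; ++i)
            for (int k = i + 1; k < n; ++k)
                for (int j = i + 1; j < k; ++j)
                    if ((long) a[i] + a[j] + a[k] == 0) ++num[i][k];
        return num;
    }

    // ans[i][k] = ans[i+1][k] + ans[i][k-1] - ans[i+1][k-1] + num[i][k].
    // i decreases, so the three smaller intervals are final before they are read;
    // entries with k <= i are never written and stay 0.
    static long[][] buildAns(long[][] num) {
        int n = num.length;
        long[][] ans = new long[n][n];
        for (int i = n - 1; i >= 0; --i)
            for (int k = i + 1; k < n; ++k)
                ans[i][k] = ans[i + 1][k] + ans[i][k - 1] - ans[i + 1][k - 1] + num[i][k];
        return ans;
    }

    public static void main(String[] args) {
        int[] a = {1, -1, 0, 2, -2};
        long[][] ans = buildAns(naiveNum(a));
        // Compare the recurrence against brute force on every interval [l, r].
        for (int l = 0; l < a.length; ++l)
            for (int r = l; r < a.length; ++r)
                if (ans[l][r] != bruteCount(a, l, r))
                    throw new AssertionError("mismatch on [" + l + ", " + r + "]");
        System.out.println("all intervals match");
    }
}
```

For $A=(1,-1,0,2,-2)$ the whole array contains two zero-sum triples, $(1,-1,0)$ and $(0,2,-2)$. The recurrence gives $ans[0][4]=ans[1][4]+ans[0][3]-ans[1][3]+num[0][4]=1+1-0+0=2$.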
Now I'll describe a way to compute $num[i][i+1],\ldots,num[i][N]$ in $O(N)$ time for a fixed $i$. Process $k=i+1,\ldots,N$ in increasing order while maintaining a hashmap (unordered_map in C++) that stores the number of occurrences of each integer among $A_{i+1},\ldots,A_{k-1}$. Then $num[i][k]$ equals the number of occurrences of $-A_i-A_k$ in this map. When $k$ is incremented by one, the only change to the map is the insertion of $A_k$. Because hashmap operations take expected $O(1)$ time, computing all of $num$ takes $O(N^2)$ time overall, and so does the whole solution.
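The scan for a fixed $i$ can be sketched in Java as follows. This is a minimal illustration under stated assumptions. The class NumRowSketch and the method numRow are hypothetical names, indices are 0-based, and java.util.HashMap plays the role of the hashmap.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

public class NumRowSketch {
    // Returns row with row[k] = num[i][k] for k = i+1..n-1 (0-indexed); other entries stay 0.
    static long[] numRow(int[] a, int i) {
        int n = a.length;
        long[] row = new long[n];
        // counts holds the multiset {A[i+1], ..., A[k-1]} for the current k.
        Map<Integer, Integer> counts = new HashMap<>();
        for (int k = i + 1; k < n; ++k) {
            // Middle elements j with A[j] = -A[i] - A[k].
            row[k] = counts.getOrDefault(-a[i] - a[k], 0);
            // The only update before moving to k+1: insert A[k].
            counts.merge(a[k], 1, Integer::sum);
        }
        return row;
    }

    public static void main(String[] args) {
        int[] a = {1, -1, 0, 2, -2};
        System.out.println(Arrays.toString(numRow(a, 0))); // prints [0, 0, 1, 0, 0]
    }
}
```

Running numRow once for every $i$ fills all of $num$ in $O(N^2)$ expected time. Each call builds a fresh map, so no explicit reset is needed.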
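However, because of the high constant factor of hashmaps, this solution does not earn full points. Since every entry of $A$ lies in the range $[-10^6,10^6]$, we can replace the hashmap with an array of size $2\cdot 10^6+1$ that stores the count of value $v$ at index $v+10^6$. This greatly improves the runtime. The target $-A_i-A_k$ can fall outside $[-10^6,10^6]$ (in which case $num[i][k]=0$), so the index must be bounds-checked. After each $i$, undoing the $O(N)$ insertions is much cheaper than clearing the entire array.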
```java
import java.io.*;
import java.util.*;

public class threesum {
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new FileReader("threesum.in"));
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("threesum.out")));
        String[] line = in.readLine().split(" ");
        int N = Integer.parseInt(line[0]);
        int Q = Integer.parseInt(line[1]);
        line = in.readLine().split(" ");
        int[] A = new int[N];
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; ++i) A[i] = Integer.parseInt(line[i]);
        // z[v + 1000000] = number of occurrences of v among A[i+1..j-1]
        int[] z = new int[2000001];
        for (int i = N-1; i >= 0; --i) {
            for (int j = i+1; j < N; ++j) {
                // ans[i][j] temporarily holds num[i][j]
                int ind = 1000000-A[i]-A[j];
                if (ind >= 0 && ind <= 2000000) ans[i][j] = z[ind];
                z[1000000+A[j]] ++;
            }
            // undo this row's insertions so z is all zeros again
            for (int j = i+1; j < N; ++j) {
                z[1000000+A[j]] --;
            }
        }
        // inclusion-exclusion turns num[i][j] into ans[i][j]
        for (int i = N-1; i >= 0; --i)
            for (int j = i+1; j < N; ++j)
                ans[i][j] += ans[i+1][j]+ans[i][j-1]-ans[i+1][j-1];
        for (int i = 0; i < Q; ++i) {
            line = in.readLine().split(" ");
            int a = Integer.parseInt(line[0]);
            int b = Integer.parseInt(line[1]);
            out.println(ans[a-1][b-1]); // queries are 1-indexed
        }
        out.close();
    }
}
```
Nevertheless, it was still possible to pass without replacing the hashmap with an array. Although this wasn't intended, I'll include two additional solutions for the sake of completeness.
In C++, gp_hash_table is somewhat faster than unordered_map, especially if you specify an initial capacity. See here for more information.
```cpp
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp> // for gp_hash_table
using namespace std;
using namespace __gnu_pbds;

void setIO(string name) {
    ios_base::sync_with_stdio(0); cin.tie(0);
    freopen((name+".in").c_str(),"r",stdin);
    freopen((name+".out").c_str(),"w",stdout);
}

int N,Q;
long long ans[5000][5000];
vector<int> A;

int main() {
    setIO("threesum");
    cin >> N >> Q;
    A.resize(N); for (int i = 0; i < N; ++i) cin >> A[i];
    for (int i = 0; i < N; ++i) {
        // initialize with capacity that is a power of 2
        gp_hash_table<int,int> g({},{},{},{},{1<<13});
        for (int j = i+1; j < N; ++j) {
            // g counts values among A[i+1..j-1]; ans[i][j] temporarily holds num[i][j]
            int res = -A[i]-A[j];
            auto it = g.find(res);
            if (it != end(g)) ans[i][j] = it->second;
            g[A[j]] ++;
        }
    }
    for (int i = N-1; i >= 0; --i) for (int j = i+1; j < N; ++j)
        ans[i][j] += ans[i+1][j]+ans[i][j-1]-ans[i+1][j-1];
    for (int i = 0; i < Q; ++i) {
        int a,b; cin >> a >> b;
        cout << ans[a-1][b-1] << "\n";
    }
}
```
In Java, a hashmap solution passes if StreamTokenizer is used to take care of input, although it uses much more memory than I would expect. (If anyone knows how to reduce the memory usage, could you let me know?)
```java
import java.io.*;
import java.util.*;

public class threesum {
    static StreamTokenizer in;

    // StreamTokenizer parses numeric tokens (including negative ones) into nval as doubles
    static int nextInt() throws IOException {
        in.nextToken();
        return (int)in.nval;
    }

    public static void main(String[] args) throws IOException {
        in = new StreamTokenizer(new BufferedReader(new FileReader("threesum.in")));
        PrintWriter out = new PrintWriter(new BufferedWriter(new FileWriter("threesum.out")));
        int N = nextInt(); int Q = nextInt();
        int[] A = new int[N];
        long[][] ans = new long[N][N];
        for (int i = 0; i < N; ++i) A[i] = nextInt();
        // z counts values among A[i+1..j-1]
        Map<Integer,Integer> z = new HashMap<>();
        for (int i = N-1; i >= 0; --i) {
            z.clear();
            for (int j = i+1; j < N; ++j) {
                int ind = -A[i]-A[j];
                ans[i][j] = z.getOrDefault(ind,0); // num[i][j]
                z.put(A[j],z.getOrDefault(A[j],0)+1);
            }
        }
        for (int i = N-1; i >= 0; --i)
            for (int j = i+1; j < N; ++j)
                ans[i][j] += ans[i+1][j]+ans[i][j-1]-ans[i+1][j-1];
        // array index expressions are evaluated left to right, so a is read before b
        for (int i = 0; i < Q; ++i)
            out.println(ans[nextInt()-1][nextInt()-1]);
        out.close();
    }
}
```