First, we show how to find the length of the shortest path between $a$ and $b$. It can be shown that greedily advancing from $a$ to the farthest-right reachable interval is always optimal. So, to each interval we can assign a unique "right parent" interval, and the distance from $a$ to $b$ can be found by repeatedly going to the parent of the current interval, ending when we reach an interval at $b$ or greater. To find these parents, we can use a two-pointer technique where we scan from left to right and keep track of the parent of the current node, which only ever moves to the right. Naively, we can directly travel from $a$ to $b$ using these parent pointers.
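The parent computation and the naive walk can be sketched as follows. This is only an illustration under assumptions that are not part of the original code: intervals are 0-indexed and given as two arrays `l` and `r` of endpoints, both strictly increasing, and the helper names `rightParents` and `naiveDist` are hypothetical.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Hypothetical helper: par[i] = largest index j with l[j] <= r[i], i.e. the
// farthest-right interval that still intersects interval i.
// Assumes l and r are strictly increasing and l[i] <= r[i].
vector<int> rightParents(const vector<int>& l, const vector<int>& r) {
    assert(l.size() == r.size());
    int n = l.size();
    vector<int> par(n);
    int j = 0;  // second pointer; it never moves left
    for (int i = 0; i < n; i++) {
        if (j < i) j = i;  // an interval always intersects itself
        while (j + 1 < n && l[j + 1] <= r[i]) j++;
        par[i] = j;
    }
    return par;
}

// Naive distance from a to b (a < b): follow parents until the index is >= b.
// Returns -1 if the walk gets stuck before reaching b.
int naiveDist(const vector<int>& par, int a, int b) {
    int steps = 0;
    while (a < b) {
        if (par[a] == a) return -1;  // no interval further right is reachable
        a = par[a];
        steps++;
    }
    return steps;
}
```

For example, on the intervals [1,4], [2,6], [3,9], [7,10], [8,12], [13,14] the parents are 2, 2, 4, 4, 4, 5, so going from interval 0 to interval 4 takes two steps, while interval 5 is unreachable.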
To speed this process up to $O(\log N)$ time per query, we can use the technique of binary jumping. For every $k$ from $1$ to $\log N$, we can calculate the $2^k$-th parent of each interval: the $2^k$-th parent of the $i$-th interval is the $2^{k-1}$-th parent of the $2^{k-1}$-th parent of the $i$-th interval. Then, to find the distance between $a$ and $b$, we can do the following process: for every $k$ from $\log N$ down to $0$, we check whether the $2^k$-th parent of $a$ is at or to the right of $b$. If so, then the shortest path between $a$ and $b$ has length at most $2^k$, and we move on to the next smaller $k$. Otherwise, we add $2^k$ to the distance and recursively find the shortest path between this parent and $b$.
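A minimal sketch of this jump table and query, assuming a 0-indexed array `par` of right parents in which a stuck interval is its own parent. The names `buildJumps` and `jumpDist` are hypothetical and do not appear in the original code.

```cpp
#include <vector>
using namespace std;

// up[k][i] = interval reached from i after 2^k greedy steps to the right.
// Assumes par[i] >= i, with par[i] == i when nothing further is reachable.
vector<vector<int>> buildJumps(const vector<int>& par) {
    int n = par.size();
    int LOG = 1;
    while ((1 << LOG) < n) LOG++;
    vector<vector<int>> up(LOG + 1, par);  // level 0 is par itself
    for (int k = 1; k <= LOG; k++)
        for (int i = 0; i < n; i++)
            up[k][i] = up[k - 1][up[k - 1][i]];
    return up;
}

// Distance from a to b (a < b), or -1 if b cannot be reached.
int jumpDist(const vector<vector<int>>& up, int a, int b) {
    int steps = 0;
    for (int k = (int)up.size() - 1; k >= 0; k--) {
        if (up[k][a] < b) {  // a jump of 2^k still lands strictly left of b
            a = up[k][a];
            steps += 1 << k;
        }
    }
    // After the loop, one more greedy step must reach b or beyond.
    return up[0][a] >= b ? steps + 1 : -1;
}
```

The loop only takes jumps that stay strictly left of $b$, so it finishes at the last interval before $b$ on the greedy path and the final step is added separately.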
This observation nets us partial credit: we can find the shortest path between $a$ and $b$, then iterate over all special intervals $s$ between $a$ and $b$, checking whether $\text{dist}(a,s)+\text{dist}(s,b)=\text{dist}(a,b)$. If so, then $s$ clearly lies on a shortest path from $a$ to $b$; otherwise, $\text{dist}(a,s)+\text{dist}(s,b)>\text{dist}(a,b)$ and any path going through $s$ must have length greater than $\text{dist}(a,b)$.
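The partial-credit check might look like the sketch below. It assumes a 0-indexed right-parent array `par` (a stuck interval is its own parent) and a string `special` of '0'/'1' flags, and it uses a plain walk for distances to keep the example short. `walkDist` and `countSpecialBrute` are hypothetical names.

```cpp
#include <string>
#include <vector>
using namespace std;

// Greedy distance from a to b for a <= b (0 when a == b), or -1 if stuck.
int walkDist(const vector<int>& par, int a, int b) {
    int steps = 0;
    while (a < b) {
        if (par[a] == a) return -1;
        a = par[a];
        steps++;
    }
    return steps;
}

// Counts special intervals s in [a, b] with dist(a,s) + dist(s,b) == dist(a,b).
// The endpoints a and b themselves are included in the count.
int countSpecialBrute(const vector<int>& par, const string& special, int a, int b) {
    int total = walkDist(par, a, b);
    if (total < 0) return 0;  // b is not reachable from a
    int cnt = 0;
    for (int s = a; s <= b; s++) {
        if (special[s] != '1') continue;
        int d1 = walkDist(par, a, s);
        int d2 = walkDist(par, s, b);
        if (d1 >= 0 && d2 >= 0 && d1 + d2 == total) cnt++;
    }
    return cnt;
}
```

Swapping `walkDist` for a binary-jumping distance would bring each check down to $O(\log N)$, but the loop over all $s$ still makes a query linear in the worst case.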
An interesting observation to make is that the binary jumping solution above works equally well with jumps going left (starting from $b$) as with jumps going right (starting from $a$). This will come in handy later on!
Now, we will take a closer look at the (not necessarily special) intervals $s$ which satisfy $\text{dist}(a,s)+\text{dist}(s,b)=\text{dist}(a,b)$. For convenience, let $L=\text{dist}(a,b)$. Consider the intervals $s$ which satisfy $\text{dist}(a,s)=i$ and $\text{dist}(s,b)=L-i$ for some $i\in[1,L-1]$. From our greedy strategy for finding shortest paths mentioned above, it can be shown that the set of intervals which satisfy $\text{dist}(a,s)\le i$ are simply those with indices $\le \text{greedy}_a(i)$, where $\text{greedy}_a(i)$ denotes the interval we arrive at if we start at $a$ and greedily go right $i$ times. Similarly, the set of intervals which satisfy $\text{dist}(s,b)\le L-i$ are those with indices $\ge \text{greedy}_b(L-i)$, where $\text{greedy}_b(L-i)$ denotes the interval we arrive at if we start at $b$ and greedily go left $L-i$ times.
So, the intervals $s$ which satisfy $\text{dist}(a,s)=i$ and $\text{dist}(s,b)=L-i$ are the intersection of these two sets: the intervals with $\text{greedy}_b(L-i)\le s\le \text{greedy}_a(i)$, which form a contiguous range of intervals. Notice that if we consider these ranges over all $i$, no interval is counted twice. Additionally, each range is nonempty (in particular, $\text{greedy}_a(i)$ always lies in range $i$). So, our answer can be expressed in the form $\sum_{i=1}^{L-1}(\text{number of special intervals between } \text{greedy}_b(L-i) \text{ and } \text{greedy}_a(i)\text{, inclusive})$.
(This potentially leaves out counting the intervals a and b, which we can take care of as special cases.)
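We can then use cumulative sums to simplify this expression. Define $\text{csum}(x)$ to be the number of special intervals between $1$ and $x$, inclusive. The sum can be rewritten as $\sum_{i=1}^{L-1}\left(\text{csum}(\text{greedy}_a(i))-\text{csum}(\text{greedy}_b(L-i)-1)\right)$. Splitting this into two sums and reindexing the second one gives $\sum_{i=1}^{L-1}\text{csum}(\text{greedy}_a(i))-\sum_{i=1}^{L-1}\text{csum}(\text{greedy}_b(L-i)-1)=\sum_{i=1}^{L-1}\text{csum}(\text{greedy}_a(i))-\sum_{i=1}^{L-1}\text{csum}(\text{greedy}_b(i)-1)$.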
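To make the formula concrete, the sketch below evaluates it directly by walking the two greedy paths one step at a time, without any binary jumping yet. It assumes 0-indexed intervals given by strictly increasing endpoint arrays `l` and `r`, a '0'/'1' string `special`, and $a<b$. The name `countViaFormula` is hypothetical, and `pre[x]` counts special intervals with index below `x`, so that $\text{csum}(x)$ corresponds to `pre[x + 1]` and $\text{csum}(x-1)$ to `pre[x]`.

```cpp
#include <string>
#include <vector>
using namespace std;

// Returns the number of special intervals on some shortest path from a to b
// (endpoints included), or -1 if b is unreachable from a.
long long countViaFormula(const vector<int>& l, const vector<int>& r,
                          const string& special, int a, int b) {
    int n = l.size();
    vector<int> pre(n + 1, 0);  // pre[x] = number of special intervals with index < x
    for (int i = 0; i < n; i++) pre[i + 1] = pre[i] + (special[i] == '1');

    // Greedy parents, found by a plain scan to keep the sketch short.
    auto rightPar = [&](int i) {
        int j = i;
        while (j + 1 < n && l[j + 1] <= r[i]) j++;
        return j;
    };
    auto leftPar = [&](int i) {
        int j = i;
        while (j - 1 >= 0 && r[j - 1] >= l[i]) j--;
        return j;
    };

    // greedyA[i] = greedy_a(i); stop as soon as we reach b or beyond.
    vector<int> greedyA = {a};
    while (greedyA.back() < b) {
        int nxt = rightPar(greedyA.back());
        if (nxt == greedyA.back()) return -1;
        greedyA.push_back(nxt);
    }
    int L = (int)greedyA.size() - 1;  // dist(a, b)

    // greedyB[i] = greedy_b(i) for i = 0 .. L-1.
    vector<int> greedyB = {b};
    for (int i = 1; i < L; i++) greedyB.push_back(leftPar(greedyB.back()));

    long long ans = (special[a] == '1') + (special[b] == '1');  // a and b handled separately
    for (int i = 1; i <= L - 1; i++) {
        ans += pre[greedyA[i] + 1];  // csum(greedy_a(i))
        ans -= pre[greedyB[i]];      // csum(greedy_b(i) - 1)
    }
    return ans;
}
```

On the seven intervals [0,10], [2,12], [8,20], [9,22], [19,30], [21,32], [29,40] with $a=0$ and $b=6$, we get $L=3$ and the two ranges are {2, 3} and {4, 5}, so every interval except interval 1 lies on some shortest path.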
Now, each of these sums can be computed independently, and the final answer is the first sum minus the second. We show how to compute the second sum; the first is symmetric.
Consider again our binary jumping solution for computing the shortest path between $a$ and $b$, where we now define the parent of an interval $i$ as the interval reached by greedily going left one step. We can imagine these parent relations forming a graph (specifically, a forest) with directed edges going left. Then, the shortest path length is the length of the path starting at node $b$ and going up the tree until we arrive at a node with index $\le a$. For each node, define $\text{par}(\text{node})$ to be the parent of the node. For each edge between some node and $\text{par}(\text{node})$, we can place a weight on the edge equal to $\text{csum}(\text{par}(\text{node})-1)$. Then, the expression $\sum_{i=1}^{L-1}\text{csum}(\text{greedy}_b(i)-1)$ is simply the total weight of all the edges going from $b$ to its $(L-1)$-th parent.
The sum of the weights of these edges can again be computed using binary jumps. For every $k$ from $1$ to $\log N$, we can precompute the sum of the weights on the path from each node $i$ to its $2^k$-th parent. We do this recursively: the sum of the weights on the path from node $i$ to its $2^k$-th parent is equal to the sum of the weights on the path from node $i$ to its $2^{k-1}$-th parent plus the sum of the weights on the path from the $2^{k-1}$-th parent of $i$ to the $2^{k-1}$-th parent of the $2^{k-1}$-th parent of $i$. We compute $\log N$ pieces of information for each of the nodes, which takes $O(N\log N)$ total time.
Then, to answer queries, we can traverse from $b$ to its $(L-1)$-th parent using $\log N$ jumps (as described earlier), while summing up the weights we just computed every time we jump a distance of $2^k$ for some $k$.
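The weighted jump table can be sketched in isolation as follows. The assumptions here are not from the original code: `parL[i]` is the 0-indexed left parent of `i` (with `parL[i] == i` when nothing further left is reachable), `pre[x]` is the number of special intervals with index below `x` (so the edge weight $\text{csum}(\text{par}(\text{node})-1)$ becomes `pre[parL[i]]`), and the query is given the number of steps explicitly. `Jump`, `buildWeightedJumps`, and `pathWeight` are hypothetical names.

```cpp
#include <vector>
using namespace std;

struct Jump {
    int to;       // node reached after the jump
    long long w;  // total edge weight collected along the jump
};

// jmp[k][i] = {2^k-th left parent of i, sum of edge weights on that path}.
vector<vector<Jump>> buildWeightedJumps(const vector<int>& parL,
                                        const vector<int>& pre, int LOG) {
    int n = parL.size();
    vector<vector<Jump>> jmp(LOG, vector<Jump>(n));
    for (int i = 0; i < n; i++) jmp[0][i] = {parL[i], pre[parL[i]]};
    for (int k = 1; k < LOG; k++) {
        for (int i = 0; i < n; i++) {
            Jump first = jmp[k - 1][i];          // first half: 2^(k-1) steps from i
            Jump second = jmp[k - 1][first.to];  // second half: 2^(k-1) more steps
            jmp[k][i] = {second.to, first.w + second.w};
        }
    }
    return jmp;
}

// Sum of the weights on the first `steps` edges of the path starting at b.
// Requires steps < 2^LOG, where LOG is the number of levels in jmp.
long long pathWeight(const vector<vector<Jump>>& jmp, int b, int steps) {
    long long total = 0;
    for (int k = 0; k < (int)jmp.size(); k++) {
        if ((steps >> k) & 1) {
            total += jmp[k][b].w;
            b = jmp[k][b].to;
        }
    }
    return total;
}
```

Calling `pathWeight(jmp, b, L - 1)` yields the second sum. A full solution would more likely accumulate the weights during the same descending loop that finds the distance, so that $L$ never has to be known in advance.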
    #include <bits/stdc++.h>
    using namespace std;

    const int mx = 200005;
    const int BIN = 30;

    int N, Q;

    //precompute binary jumps
    vector<vector<pair<int, long long>>> genJmp(string S, vector<int> csum){
        vector<vector<pair<int, long long>>> res(N, vector<pair<int, long long>>(BIN));
        res[0][0] = make_pair(-1, 0);
        vector<pair<int, int>> rangs(N);
        int Ls = 0;
        int Rs = 0;
        for(int i = 0; i < int(S.size()); i++){
            if(S[i] == 'L'){
                rangs[Ls].first = i;
                Ls++;
            }
            else{
                rangs[Rs].second = i;
                Rs++;
            }
        }
        //calculate the information for the 2^0th parent (can also be done in linear time)
        set<pair<int, int>> rights;
        for(int i = 0; i < N; i++){
            if(i > 0){
                auto it = rights.lower_bound(make_pair(rangs[i].first, -1));
                int lower_ind = (it->second);
                assert(lower_ind < i);
                res[i][0] = make_pair(lower_ind, csum[lower_ind]);
            }
            rights.insert(make_pair(rangs[i].second, i));
        }
        //generate binary jumps
        for(int i = 0; i < N; i++){
            for(int j = 1; j < BIN; j++){
                int lower_ind = res[i][j-1].first;
                assert(lower_ind < i);
                if(lower_ind == -1){
                    res[i][j] = make_pair(-1, 0);
                }
                else{
                    res[i][j] = make_pair(res[lower_ind][j-1].first, res[i][j-1].second+res[lower_ind][j-1].second);
                }
            }
        }
        return res;
    }

    int main() {
        cin.tie(0)->sync_with_stdio(0);
        cin >> N >> Q;
        string S;
        cin >> S;
        string special;
        cin >> special;
        //generate cumulative sums
        vector<int> csum(int(special.size()), 0);
        for(int i = 0; i < int(csum.size()); i++){
            if(i-1 >= 0) csum[i] = csum[i-1];
            csum[i]+=(special[i] == '1');
        }
        vector<int> csum_left(N);
        //change indices of left cumulative sums
        for(int i = 1; i < N; i++){
            csum_left[i] = csum[i-1];
        }
        //do a bunch of reversals so that we can use the same function for generating binary jumps, then reverse afterwards again
        vector<int> csum_right = csum;
        reverse(begin(csum_right), end(csum_right));
        vector<vector<pair<int, long long>>> jmp_left = genJmp(S, csum_left);
        reverse(begin(S), end(S));
        for(auto& u: S){
            u = ((u == 'L') ? 'R' : 'L');
        }
        reverse(begin(special), end(special));
        vector<vector<pair<int, long long>>> jmp_right = genJmp(S, csum_right);
        reverse(begin(jmp_right), end(jmp_right));
        for(int i = 0; i < int(jmp_right.size()); i++){
            for(int j = 0; j < int(jmp_right[i].size()); j++){
                jmp_right[i][j].first = N-1-jmp_right[i][j].first;
            }
        }
        reverse(begin(S), end(S));
        for(auto& u: S){
            u = ((u == 'L') ? 'R' : 'L');
        }
        reverse(begin(special), end(special));
        //answer queries
        for(int i = 1; i <= Q; i++){
            int a, b;
            cin >> a >> b;
            a--; b--; //0-indexed intervals
            int cur = b;
            long long sum = 0; //sum of edge weights going left from b
            int cur_right = a; //move this right as you go along
            long long sum_right = 0; //sum of edge weights going right from a
            int dist_traveled = 0;
            for(int j = BIN-1; j >= 0; j--){ //binary jump of size 2^j
                if(jmp_left[cur][j].first == -1) continue;
                if(jmp_left[cur][j].first > a){ //don't reach a or further to the left
                    sum_right+=jmp_right[cur_right][j].second; //add up edge weights
                    cur_right = jmp_right[cur_right][j].first; //update current node
                    sum+=jmp_left[cur][j].second; //add up edge weights
                    cur = jmp_left[cur][j].first; //update current node
                    dist_traveled+=(1<<j);
                    assert(cur_right != -1 && cur != -1);
                }
            }
            assert(cur_right < b);
            assert(cur > a);
            long long ans = sum_right-sum;
            if(special[a] == '1') ans++; //take care of a and b specially
            if(special[b] == '1') ans++;
            cout << dist_traveled+1 << " " << ans << "\n";
        }
        // you should actually read the stuff at the bottom
    }

    /* stuff you should look for
     * int overflow, array bounds
     * special cases (n=1?)
     * do smth instead of nothing and stay organized
     * WRITE STUFF DOWN
     * DON'T GET STUCK ON ONE APPROACH
     */
Bonus: Solve the case where the right endpoints of the intervals are not in increasing order (solution in this paper).