Looking at the diagram provided in the sample case, the locations of the cows essentially form an X in which each of the five squares that make up the X is recursively replaced by a smaller X.
Subtask 2: Define $f(k,dif)$ to be the number of cows $(x,y)$ in the square $[0,3^k)\times [0,3^k)$ such that $x-y=dif$. We can compute this in $\mathcal{O}(k)$ time by reducing to $k-1$, as $\texttt{gen_full}$ does in the code below. Since the pattern is symmetric about the main diagonal, $f(k,dif)=f(k,-dif)$, so assume $dif\ge 0$.
Case 1: $dif<3^{k-1}$
The diagram below displays the relevant positions for $k=2, dif=2$; on the diagonal $x-y=dif$, a `*` marks a cell containing a cow and a `.` marks a cell without one. In this case, $f(k,dif)=3\cdot f(k-1,dif)$.
x
012345678
0 10*000101
1 010.00010
2 1010.0101
3 00010*000
y 4 000010.00
5 0001010.0
6 10100010*
7 010000010
8 101000101
Case 2: $dif\ge 3^{k-1}$
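The diagram below displays the relevant positions for $k=2, dif=6$; on the diagonal $x-y=dif$, a `*` marks a cell containing a cow and a `.` marks a cell without one. In this case, $f(k,dif)=f(k-1,dif-2\cdot 3^{k-1})$. The second argument may be negative; since $f(k-1,-t)=f(k-1,t)$ by symmetry, the code takes the absolute value of $dif$ before recursing.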
x
012345678
0 101000*01
1 0100000*0
2 10100010*
3 000101000
y 4 000010000
5 000101000
6 101000101
7 010000010
8 101000101
Full solution: We use the same idea of reducing from $3^k$ to $3^{k-1}$. For the details, see $\texttt{rec}$ in the code below.
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
// po3[i] holds 3^i for every i in [0, 40).
vector<ll> po3 = [](){
    vector<ll> powers(40, 1);
    for (size_t e = 1; e < powers.size(); ++e)
        powers[e] = 3*powers[e-1];
    return powers;
}();
// full[j] is filled by gen_full for every level j in [0, k].
ll full[40];

// Fills full[0..k], where full[k] is the number of cows (x,y) in
// [0,3^k) x [0,3^k) with x-y == dif (the sign of dif is irrelevant).
// Each level is derived from the level below it:
//   |dif| >= 3^(k-1): only one sub-square meets the diagonal, so the
//                     count carries over with the difference shifted
//                     by 2*3^(k-1);
//   otherwise:        three sub-squares each contribute the same count.
void gen_full(int k, ll dif) {
    const ll gap = abs(dif);
    if (k == 0) {
        full[0] = (gap == 0) ? 1 : 0;
        return;
    }
    const ll third = po3[k-1];
    const bool far = gap >= third;
    // The shifted difference may be negative; the recursive call
    // takes its absolute value again.
    gen_full(k-1, far ? gap - 2*third : gap);
    full[k] = (far ? 1 : 3) * full[k-1];
}
ll rec(ll x, ll y, int k) {
x %= po3[k], y %= po3[k];
// count # of cows in [0,3^k) x [0,3^k)
// on the segment from (x-min(x,y),y-min(x,y)) to (x,y)
if (k == 0) return 1;
if (x < y) swap(x,y);
if (x-y >= po3[k-1]) {
if (x < 2*po3[k-1]) return 0;
if (y < po3[k-1]) return rec(x,y,k-1);
if (y >= po3[k-1]) return full[k-1];
}
if (x < po3[k-1]) return rec(x,y,k-1);
if (y < po3[k-1]) return full[k-1];
if (x < 2*po3[k-1]) return full[k-1]+rec(x,y,k-1);
if (y < 2*po3[k-1]) return 2*full[k-1];
return 2*full[k-1]+rec(x,y,k-1);
}
// Number of cows on the diagonal segment ending at (x,y), measured
// inside the 3^39 x 3^39 square; 0 when either coordinate is negative.
ll diag(ll x, ll y) {
    const bool inside = x >= 0 && y >= 0;
    if (!inside) return 0;
    gen_full(39, x-y);  // prepare full[] for this diagonal
    return rec(x, y, 39);
}
// Reads Q queries (d, x, y) and prints, for each one, the difference of
// two prefix counts along the diagonal: up to (x+d,y+d) minus up to
// (x-1,y-1).
int main() {
    int queries;
    cin >> queries;
    for (; queries; --queries) {
        ll d, x, y;
        cin >> d >> x >> y;
        const ll upTo = diag(x+d, y+d);
        const ll before = diag(x-1, y-1);
        cout << upTo - before << "\n";
    }
}
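Alternatively, we can ignore the diagram and do dynamic programming on the base-3 digits directly to count the number of $k\in [0,d]$ such that there is a cow at $(x+k,y+k)$. We determine the digits of $k$ from least significant to most significant. If we've determined the first $i$ digits so far, we should keep track of the following information:
- How the number formed by the first $i$ digits of $k$ compares with the number formed by the first $i$ digits of $d$ (less than, equal to, or greater than).
- Whether adding the first $i$ digits of $k$ to $x$ produces a carry into digit $i$.
- Whether adding the first $i$ digits of $k$ to $y$ produces a carry into digit $i$.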
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define F0R(i,a) for (int i = 0; i < a; ++i)
// Digit DP over base-3 digits, least significant first. For each query
// (d, x, y) it counts the k in [0, d] for which, at each of the 39
// processed digit positions, the digits of x+k and y+k have the same
// parity.
int main() {
    // pw[i] = 3^i
    vector<ll> pw(40, 1);
    for (int i = 1; i < 40; ++i)
        pw[i] = 3*pw[i-1];
    // Table[cmp][xc][yc]: cmp is 0/1/2 when the chosen low digits of k
    // are less than / equal to / greater than those of d; xc and yc are
    // the pending carries of x+k and y+k.
    using Table = array<array<array<ll,2>,2>,3>;
    int Q; cin >> Q;
    for (; Q; --Q) {
        ll d, x, y; cin >> d >> x >> y;
        Table cur{};
        cur[1][0][0] = 1;  // no digits chosen yet: equal, no carries
        for (int pos = 0; pos < 39; ++pos) {
            const int dDig = d/pw[pos]%3;
            const int xDig = x/pw[pos]%3;
            const int yDig = y/pw[pos]%3;
            Table nxt{};
            for (int j = 0; j < 3; ++j) {
                for (int xc = 0; xc < 2; ++xc) {
                    for (int yc = 0; yc < 2; ++yc) {
                        const int xs = xDig+xc+j, ys = yDig+yc+j;
                        // Both resulting digits must share parity.
                        if ((xs%3)%2 != (ys%3)%2) continue;
                        for (int cmp = 0; cmp < 3; ++cmp) {
                            // A more significant digit overrides the
                            // earlier comparison unless it ties.
                            const int to = j > dDig ? 2 : (j < dDig ? 0 : cmp);
                            nxt[to][xs/3][ys/3] += cur[cmp][xc][yc];
                        }
                    }
                }
            }
            cur = nxt;
        }
        // k <= d, with no carry left over on either coordinate.
        cout << cur[0][0][0]+cur[1][0][0] << "\n";
    }
}
Danny Mittal's code:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class LargestPasture {
    /**
     * Number of base-3 digits processed (positions 0..38). The result is only meaningful when
     * every x + k and y + k with 0 <= k <= d is below 3^39 -- presumably guaranteed by the
     * input limits; confirm against the problem statement.
     */
    private static final int DIGITS = 39;

    /** POW3[e] == 3^e for 0 <= e < DIGITS. */
    private static final long[] POW3 = new long[DIGITS];

    static {
        POW3[0] = 1;
        for (int e = 1; e < DIGITS; e++) {
            POW3[e] = 3L * POW3[e - 1];
        }
    }

    /**
     * Reads the number of queries, then one query {@code d x y} per line, and prints one
     * count per query.
     *
     * @param args unused
     * @throws IOException if reading standard input fails
     */
    public static void main(String[] args) throws IOException {
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        StringBuilder out = new StringBuilder();
        // trim() so that stray whitespace around the count does not break parseInt.
        int n = Integer.parseInt(in.readLine().trim());
        for (int j = 1; j <= n; j++) {
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            long d = Long.parseLong(tokenizer.nextToken());
            long x = Long.parseLong(tokenizer.nextToken());
            long y = Long.parseLong(tokenizer.nextToken());
            out.append(countCows(d, x, y)).append('\n');
        }
        System.out.print(out);
    }

    /**
     * Counts the values {@code k} in {@code [0, d]} such that, at every processed base-3 digit
     * position, the digit of {@code x + k} and the digit of {@code y + k} have the same parity.
     *
     * <p>The digits of {@code k} are chosen from least to most significant. Two 2x2 tables,
     * indexed by the pending carries of {@code x + k} and {@code y + k}, are kept:
     * {@code free} counts every valid choice of the low digits, while {@code tight} counts only
     * those whose low digits form a number not exceeding the corresponding low digits of
     * {@code d}.
     *
     * @param d length of the diagonal walk (non-negative)
     * @param x starting x coordinate (non-negative)
     * @param y starting y coordinate (non-negative)
     * @return the number of qualifying {@code k}
     */
    static long countCows(long d, long x, long y) {
        long[][] free = new long[2][2];
        long[][] tight = new long[2][2];
        // With no digits chosen there is exactly one (empty) choice and no carry.
        free[0][0] = 1;
        tight[0][0] = 1;
        for (int e = 0; e < DIGITS; e++) {
            int lim = (int) ((d / POW3[e]) % 3L);
            int xDigit = (int) ((x / POW3[e]) % 3L);
            int yDigit = (int) ((y / POW3[e]) % 3L);
            long[][] nextFree = new long[2][2];
            long[][] nextTight = new long[2][2];
            for (int xCarry = 0; xCarry < 2; xCarry++) {
                for (int yCarry = 0; yCarry < 2; yCarry++) {
                    for (int digit = 0; digit < 3; digit++) {
                        int xSum = xDigit + digit + xCarry;
                        int ySum = yDigit + digit + yCarry;
                        if ((xSum % 3) % 2 != (ySum % 3) % 2) {
                            continue;
                        }
                        int xNext = xSum / 3;
                        int yNext = ySum / 3;
                        nextFree[xNext][yNext] += free[xCarry][yCarry];
                        if (digit < lim) {
                            // A smaller leading digit makes any lower digits acceptable.
                            nextTight[xNext][yNext] += free[xCarry][yCarry];
                        } else if (digit == lim) {
                            // A tie defers to the comparison of the lower digits.
                            nextTight[xNext][yNext] += tight[xCarry][yCarry];
                        }
                        // digit > lim exceeds d no matter what the lower digits are.
                    }
                }
            }
            free = nextFree;
            tight = nextTight;
        }
        // Only choices that leave no carry on either coordinate are counted.
        return tight[0][0];
    }
}