Looking at the diagram provided in the sample case, the cows essentially form an X in which each of the five squares that make up the X is recursively replaced by a smaller X.
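Subtask 2: Define $f(k,\mathit{dif})$ to be the number of cows $(x,y)$ in the square $[0,3^k)\times[0,3^k)$ such that $x-y=\mathit{dif}$. We can compute this in $O(k)$ time by reducing from $k$ to $k-1$, as `gen_full` does in the code below. Assume $\mathit{dif}\ge 0$; this loses nothing because the pattern is symmetric under swapping $x$ and $y$, so $f(k,-\mathit{dif})=f(k,\mathit{dif})$.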
Case 1: $\mathit{dif}<3^{k-1}$
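The diagram below displays the relevant positions for $k=2$, $\mathit{dif}=2$ (`*` marks a cow on the diagonal $x-y=\mathit{dif}$, `.` marks an empty cell on it). In this case, $f(k,\mathit{dif})=3\cdot f(k-1,\mathit{dif})$: the diagonal crosses the three sub-squares along the main block diagonal at the same offset $\mathit{dif}$, and the other sub-squares it crosses contain no cows.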
```
     x
     012345678
   0 10*000101
   1 010.00010
   2 1010.0101
   3 00010*000
 y 4 000010.00
   5 0001010.0
   6 10100010*
   7 010000010
   8 101000101
```
Case 2: $\mathit{dif}\ge 3^{k-1}$
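The diagram below displays the relevant positions for $k=2$, $\mathit{dif}=6$ (`*` marks a cow on the diagonal $x-y=\mathit{dif}$). In this case, $f(k,\mathit{dif})=f(k-1,\mathit{dif}-2\cdot 3^{k-1})$, since only the bottom-left sub-square can contain cows on this diagonal. If $\mathit{dif}-2\cdot 3^{k-1}$ is negative, use $f(k-1,-a)=f(k-1,a)$.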
```
     x
     012345678
   0 101000*01
   1 0100000*0
   2 10100010*
   3 000101000
 y 4 000010000
   5 000101000
   6 101000101
   7 010000010
   8 101000101
```
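Full solution: We use the same idea of reducing from $3^k$ to $3^{k-1}$; for the details, see `rec` in the code below. Each query is then answered as a difference of two prefix counts along the diagonal, `diag(x+d, y+d) - diag(x-1, y-1)`.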
// Recursive solution. For each query (d, x, y), print the number of t in
// [0, d] such that (x + t, y + t) holds a cow. It is computed as a difference
// of two prefix counts along the diagonal through (x, y).
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

using ll = long long;

// Number of base-3 levels. 3^39 is the largest power of 3 that fits in a
// signed 64-bit integer; coordinates passed to diag are assumed below it.
constexpr int kLevels = 39;

// po3[i] == 3^i for i in [0, kLevels].
const std::vector<ll> po3 = [] {
    std::vector<ll> res{1};
    for (int i = 1; i <= kLevels; ++i) res.push_back(3 * res.back());
    return res;
}();

// full[j] = number of cows on a complete diagonal of a 3^j x 3^j square.
// The diagonal's difference is the one gen_full reaches at level j while
// reducing the query's difference. Written by gen_full, read by rec, so
// gen_full must run for the current difference before rec is called.
ll full[kLevels + 1];

// Fills full[0..k]: full[k] counts cows (a, b) in [0, 3^k) x [0, 3^k) with
// a - b == dif (or -dif), and each lower level is filled recursively.
void gen_full(int k, ll dif) {
    dif = std::abs(dif);  // the pattern is symmetric under swapping x and y
    if (k == 0) {
        full[0] = (dif == 0);  // the single cell (0, 0) holds a cow
        return;
    }
    if (dif >= po3[k - 1]) {
        // Only sub-square (2,0) can contain cows on this diagonal; shift the
        // difference into that sub-square's local coordinates.
        gen_full(k - 1, dif - 2 * po3[k - 1]);
        full[k] = full[k - 1];
    } else {
        // Sub-squares (0,0), (1,1), (2,2) each contribute the same count.
        gen_full(k - 1, dif);
        full[k] = 3 * full[k - 1];
    }
}

// Counts cows on the diagonal segment inside [0, 3^k) x [0, 3^k) running from
// (x - m, y - m) to (x, y), where m = min(x, y) after reducing x and y modulo
// 3^k. Requires full[0..k-1] to be filled for this segment's difference.
ll rec(ll x, ll y, int k) {
    x %= po3[k];
    y %= po3[k];
    if (k == 0) return 1;        // the single cell (0, 0) holds a cow
    if (x < y) std::swap(x, y);  // symmetric pattern: assume x >= y
    const ll third = po3[k - 1];
    if (x - y >= third) {
        // The segment crosses only sub-squares (1,0), (2,1), (2,0); of these
        // only (2,0) contains cows.
        if (x < 2 * third) return 0;             // ends in (1,0)
        if (y < third) return rec(x, y, k - 1);  // ends inside (2,0)
        return full[k - 1];                      // ends in (2,1): all of (2,0)
    }
    // Here 0 <= x - y < 3^(k-1): the segment crosses sub-squares (0,0),
    // (1,1), (2,2), which contain cows, and (1,0), (2,1), which do not.
    if (x < third) return rec(x, y, k - 1);                    // ends in (0,0)
    if (y < third) return full[k - 1];                         // ends in (1,0)
    if (x < 2 * third) return full[k - 1] + rec(x, y, k - 1);  // ends in (1,1)
    if (y < 2 * third) return 2 * full[k - 1];                 // ends in (2,1)
    return 2 * full[k - 1] + rec(x, y, k - 1);                 // ends in (2,2)
}

// Number of cows (a, b) with a - b == x - y, 0 <= a <= x and 0 <= b <= y,
// i.e. on the diagonal through (x, y) from the edge of the grid up to (x, y).
// Returns 0 when x or y is negative.
ll diag(ll x, ll y) {
    if (x < 0 || y < 0) return 0;
    gen_full(kLevels, x - y);
    return rec(x, y, kLevels);
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    int Q;
    std::cin >> Q;
    while (Q--) {
        ll d, x, y;
        std::cin >> d >> x >> y;
        // Cows on (x + t, y + t) for t in [0, d], as a difference of prefixes.
        std::cout << diag(x + d, y + d) - diag(x - 1, y - 1) << '\n';
    }
}
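Alternatively, we can ignore the diagram and do dynamic programming on the base-3 digits directly. A cell $(x,y)$ contains a cow exactly when, in every base-3 digit position, the digits of $x$ and $y$ have the same parity. We count the number of $k\in[0,d]$ such that $(x+k,y+k)$ contains a cow, determining the digits of $k$ from least significant to most significant. If we've determined the lowest $i$ digits so far, we should keep track of the following information:

- how the number formed by the lowest $i$ digits of $k$ compares with the number formed by the lowest $i$ digits of $d$ (less, equal, or greater);
- the carry into digit $i$ of $x+k$;
- the carry into digit $i$ of $y+k$.

At the end, we need $k\le d$ and both carries equal to $0$.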
// Digit-DP solution. For each query (d, x, y), print the number of t in [0, d]
// such that every base-3 digit of x + t has the same parity as the
// corresponding base-3 digit of y + t. The digits of t are chosen from least
// to most significant.
#include <array>
#include <iostream>
#include <vector>

using ll = long long;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    // 3^39 is the largest power of 3 that fits in a signed 64-bit integer.
    constexpr int kDigits = 39;
    std::vector<ll> po3{1};
    for (int i = 1; i <= kDigits; ++i) po3.push_back(3 * po3.back());

    // dp[cmp][xc][yc] after fixing the lowest i digits of t counts the digit
    // choices so far whose produced digits of x + t and y + t all match in
    // parity, where
    //   cmp: lowest i digits of t vs. lowest i digits of d
    //        (0: less, 1: equal, 2: greater),
    //   xc:  carry into digit i of x + t,
    //   yc:  carry into digit i of y + t.
    using State = std::array<std::array<std::array<ll, 2>, 2>, 3>;

    int Q;
    std::cin >> Q;
    while (Q--) {
        ll d, x, y;
        std::cin >> d >> x >> y;
        State dp{};
        dp[1][0][0] = 1;  // empty prefix: equal to d's empty prefix, no carries
        for (int i = 0; i < kDigits; ++i) {
            State nxt{};
            const int dd = static_cast<int>(d / po3[i] % 3);
            const int xd = static_cast<int>(x / po3[i] % 3);
            const int yd = static_cast<int>(y / po3[i] % 3);
            for (int cmp = 0; cmp < 3; ++cmp) {
                for (int xc = 0; xc < 2; ++xc) {
                    for (int yc = 0; yc < 2; ++yc) {
                        const ll ways = dp[cmp][xc][yc];
                        if (ways == 0) continue;
                        for (int j = 0; j < 3; ++j) {  // digit i of t
                            const int xs = xd + xc + j;
                            const int ys = yd + yc + j;
                            // Digit i of x + t and y + t must share parity.
                            if (xs % 3 % 2 != ys % 3 % 2) continue;
                            // A more significant digit decides the comparison.
                            const int ncmp = j < dd ? 0 : (j > dd ? 2 : cmp);
                            nxt[ncmp][xs / 3][ys / 3] += ways;
                        }
                    }
                }
            }
            dp = nxt;
        }
        // Require t <= d and no carry out of the top digit.
        std::cout << dp[0][0][0] + dp[1][0][0] << '\n';
    }
}
Danny Mittal's code:
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

/**
 * Digit-DP solution. For each query (d, x, y), prints how many offsets t in
 * [0, d] make every base-3 digit of x + t have the same parity as the
 * corresponding base-3 digit of y + t.
 */
public class LargestPasture {
    /** Number of base-3 digits processed (place values 3^0 .. 3^38). */
    private static final int DIGITS = 39;

    public static void main(String[] args) throws IOException {
        long[] pow3 = new long[DIGITS];
        pow3[0] = 1;
        for (int e = 1; e < DIGITS; e++) {
            pow3[e] = 3L * pow3[e - 1];
        }
        BufferedReader in = new BufferedReader(new InputStreamReader(System.in));
        StringBuilder out = new StringBuilder();
        int n = Integer.parseInt(in.readLine());
        for (int q = 0; q < n; q++) {
            StringTokenizer tokenizer = new StringTokenizer(in.readLine());
            long d = Long.parseLong(tokenizer.nextToken());
            long x = Long.parseLong(tokenizer.nextToken());
            long y = Long.parseLong(tokenizer.nextToken());
            out.append(countMatchingOffsets(d, x, y, pow3)).append('\n');
        }
        System.out.print(out);
    }

    /**
     * Counts t in [0, d] such that x + t and y + t have matching base-3 digit
     * parities in every one of the pow3.length lowest positions.
     *
     * After processing the lowest e digits of t, free[cx][cy] counts all
     * digit choices with matching parities so far. bounded[cx][cy] counts only
     * those whose lowest e digits are <= the lowest e digits of d. cx and cy
     * are the carries into digit e of x + t and y + t.
     *
     * @param d    upper bound on the offset t (inclusive)
     * @param x    base x coordinate
     * @param y    base y coordinate
     * @param pow3 pow3[e] == 3^e for every processed digit e
     * @return number of valid offsets t in [0, d]
     */
    private static long countMatchingOffsets(long d, long x, long y, long[] pow3) {
        long[][] free = new long[2][2];
        long[][] bounded = new long[2][2];
        free[0][0] = 1;     // empty prefix, no carries
        bounded[0][0] = 1;  // empty prefix is <= d's empty prefix
        for (int e = 0; e < pow3.length; e++) {
            int lim = (int) ((d / pow3[e]) % 3L);
            int xDigit = (int) ((x / pow3[e]) % 3L);
            int yDigit = (int) ((y / pow3[e]) % 3L);
            long[][] nextFree = new long[2][2];
            long[][] nextBounded = new long[2][2];
            for (int cx = 0; cx < 2; cx++) {
                for (int cy = 0; cy < 2; cy++) {
                    for (int digit = 0; digit < 3; digit++) {
                        int xSum = xDigit + digit + cx;
                        int ySum = yDigit + digit + cy;
                        if ((xSum % 3) % 2 != (ySum % 3) % 2) {
                            continue;  // digit e of x + t and y + t differ in parity
                        }
                        int cxNext = xSum / 3;
                        int cyNext = ySum / 3;
                        nextFree[cxNext][cyNext] += free[cx][cy];
                        // This more significant digit decides the comparison with d:
                        // smaller frees the lower digits, equal keeps the bound,
                        // larger exceeds d and contributes nothing.
                        if (digit < lim) {
                            nextBounded[cxNext][cyNext] += free[cx][cy];
                        } else if (digit == lim) {
                            nextBounded[cxNext][cyNext] += bounded[cx][cy];
                        }
                    }
                }
            }
            free = nextFree;
            bounded = nextBounded;
        }
        // Require t <= d and no carry out of the top digit.
        return bounded[0][0];
    }
}