(Analysis by Nick Wu)

In this problem, we have several red points and several blue points, each color with its own internal ordering. We start at the first red point and must end at the last red point, visiting every point along the way while preserving the ordering within each color. The energy needed to travel between two points is the square of the distance between them, and we wish to minimize the total energy used.

Consider the last step we take to reach the last red point: we arrive either from the second-to-last red point or directly from the last blue point. In either case, the path leading up to that step should itself be optimal. That is, we want the cheapest path that ends at the last blue point with exactly one red point left to visit, or the cheapest path that ends at the second-to-last red point with no blue points left to visit.

This motivates a dynamic programming solution. Let $f(n, m)$ be the minimum energy expended to visit the first $n$ red points and the first $m$ blue points, ending at the $n$th red point, and let $g(n, m)$ be the minimum energy expended to visit the first $n$ red points and the first $m$ blue points, ending at the $m$th blue point. $f(n, m)$ depends on $f(n-1, m)$ and $g(n-1, m)$, while $g(n, m)$ depends on $f(n, m-1)$ and $g(n, m-1)$.
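
Written out, the dependencies above give the following recurrences. The notation here is an assumption introduced for illustration: $R_i$ is the $i$th red point, $B_j$ is the $j$th blue point, and $d(p, q)$ is the squared distance between $p$ and $q$. Any term that refers to a nonexistent point (index 0) is treated as infinite.

$$f(n, m) = \min\big(f(n-1, m) + d(R_{n-1}, R_n),\; g(n-1, m) + d(B_m, R_n)\big)$$

$$g(n, m) = \min\big(f(n, m-1) + d(R_n, B_m),\; g(n, m-1) + d(B_{m-1}, B_m)\big)$$

The base case is $f(1, 0) = 0$, since we start at the first red point, and the answer is $f$ evaluated at the total number of red and blue points.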

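import java.io.*;
import java.util.*;
public class checklist {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new FileReader("checklist.in"));
        PrintWriter pw = new PrintWriter(new BufferedWriter(new FileWriter("checklist.out")));
        // First line: the number of red (holstein) and blue (guernsey) points.
        StringTokenizer st = new StringTokenizer(br.readLine());
        int n = Integer.parseInt(st.nextToken());
        int m = Integer.parseInt(st.nextToken());
        State[] holstein = new State[n];
        State[] guernsey = new State[m];
        // Each following line holds the coordinates of one point.
        for(int i = 0; i < n; i++) {
            st = new StringTokenizer(br.readLine());
            holstein[i] = new State(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken()));
        }
        for(int i = 0; i < m; i++) {
            st = new StringTokenizer(br.readLine());
            guernsey[i] = new State(Integer.parseInt(st.nextToken()), Integer.parseInt(st.nextToken()));
        }
        br.close();
        // dp[i][j][0] is f(i, j): i red and j blue points visited, ending on a red point.
        // dp[i][j][1] is g(i, j): i red and j blue points visited, ending on a blue point.
        long[][][] dp = new long[n+1][m+1][2];
        for(int i = 0; i < dp.length; i++) {
            for(int j = 0; j < dp[i].length; j++) {
                Arrays.fill(dp[i][j], 1L << 60);
            }
        }
        // We start at the first red point with no energy spent.
        dp[1][0][0] = 0;
        for(int i = 0; i < dp.length; i++) {
            for(int j = 0; j < dp[i].length; j++) {
                for(int k = 0; k < 2; k++) {
                    // Skip states that would end on a point that does not exist.
                    if(k == 0 && i == 0) continue;
                    if(k == 1 && j == 0) continue;
                    State source;
                    if(k == 0) source = holstein[i-1];
                    else source = guernsey[j-1];
                    // Extend the path to the next red point.
                    if(i < n) {
                        dp[i+1][j][0] = Math.min(dp[i+1][j][0], dp[i][j][k] + source.dist(holstein[i]));
                    }
                    // Extend the path to the next blue point.
                    if(j < m) {
                        dp[i][j+1][1] = Math.min(dp[i][j+1][1], dp[i][j][k] + source.dist(guernsey[j]));
                    }
                }
            }
        }
        // All points visited, ending at the last red point.
        pw.println(dp[n][m][0]);
        pw.close();
    }

    static class State {
        public int x, y;
        public State(int a, int b) {
            x=a;
            y=b;
        }
        // Squared Euclidean distance, computed in long to avoid overflow.
        public long dist(State s) {
            long dx = x - s.x;
            long dy = y - s.y;
            return dx*dx + dy*dy;
        }
    }

}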
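
The solution above pushes each state forward to its successors. As a cross-check, the same recurrence can also be written in the "pull" form used to define $f$ and $g$, where each state looks back at the states it depends on. The following is a minimal sketch of that variant rather than the official solution: the class name `ChecklistSketch`, the helpers `minEnergy` and `sq`, and the small input in `main` are all hypothetical and chosen only for illustration.

```java
import java.util.Arrays;

// Hypothetical illustration: pull-style version of the f/g recurrence.
class ChecklistSketch {
    // Large sentinel for unreachable states; small enough that adding a distance cannot overflow.
    static final long INF = Long.MAX_VALUE / 4;

    // Squared Euclidean distance between two points given as {x, y}.
    static long sq(int[] a, int[] b) {
        long dx = a[0] - b[0];
        long dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }

    // Minimum energy to visit all red and blue points in order,
    // starting at red[0] and ending at red[red.length - 1].
    static long minEnergy(int[][] red, int[][] blue) {
        int n = red.length, m = blue.length;
        long[][] f = new long[n + 1][m + 1]; // f[i][j]: ends at the i-th red point
        long[][] g = new long[n + 1][m + 1]; // g[i][j]: ends at the j-th blue point
        for (long[] row : f) Arrays.fill(row, INF);
        for (long[] row : g) Arrays.fill(row, INF);
        f[1][0] = 0; // start at the first red point
        for (int i = 1; i <= n; i++) {
            for (int j = 0; j <= m; j++) {
                if (i >= 2) {
                    // Arrive at red i from red i-1 ...
                    f[i][j] = Math.min(f[i][j], f[i - 1][j] + sq(red[i - 2], red[i - 1]));
                    // ... or from blue j.
                    if (j >= 1) {
                        f[i][j] = Math.min(f[i][j], g[i - 1][j] + sq(blue[j - 1], red[i - 1]));
                    }
                }
                if (j >= 1) {
                    // Arrive at blue j from red i ...
                    g[i][j] = Math.min(g[i][j], f[i][j - 1] + sq(red[i - 1], blue[j - 1]));
                    // ... or from blue j-1.
                    if (j >= 2) {
                        g[i][j] = Math.min(g[i][j], g[i][j - 1] + sq(blue[j - 2], blue[j - 1]));
                    }
                }
            }
        }
        return f[n][m];
    }

    public static void main(String[] args) {
        // Made-up input: three red points on the x-axis, two blue points above them.
        int[][] red = {{0, 0}, {1, 0}, {2, 0}};
        int[][] blue = {{0, 3}, {1, 3}};
        System.out.println(minEnergy(red, blue)); // prints 20
    }
}
```

On this input the cheapest route visits both blue points right after the first red point (costs 9, 1, 9, 1), which is where the total of 20 comes from. Both formulations fill the same table of $O(nm)$ states with constant work per state, so the choice between them is a matter of taste.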