xxxxxxxxxx
// Boost Software License - Version 1.0 - August 17th, 2003
//
// Permission is hereby granted, free of charge, to any person or organization
// obtaining a copy of the software and accompanying documentation covered by
// this license (the "Software") to use, reproduce, display, distribute,
// execute, and transmit the Software, and to prepare derivative works of the
// Software, and to permit third-parties to whom the Software is furnished to
// do so, all subject to the following:
//
// The copyright notices in the Software and this entire statement, including
// the above license grant, this restriction and the following disclaimer,
// must be included in all copies of the Software, in whole or in part, and
// all derivative works of the Software, unless such copies or derivative
// works are solely in the form of machine-executable object code generated by
// a source language processor.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
// SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
// FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
// ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
import std.stdio;
import std.string;
import std.range;
import std.typecons;
import std.datetime.stopwatch;
import std.container.binaryheap;
import std.algorithm.comparison;
// A fast algorithm for Pythagorean triples using Euclid's formula.
//
// We start by generating all coprime pairs (m, n) with m > n > 0
// starting with the initial values (2, 1) and (3, 1) and the
// generators:
// (m, n) => (2*m-n, m)
// (m, n) => (2*m+n, m)
// (m, n) => (m+2*n, n)
//
// See
// https://en.wikipedia.org/wiki/Coprime_integers#Generating_all_coprime_pairs
//
// Related:
// https://en.wikipedia.org/wiki/Tree_of_primitive_Pythagorean_triples
//
// We then use Euclid's formula by iterating over the triples (m, n, k)
// where m and n are coprime and k >= 1 using the lexicographic order
// on (k*(m*m+n*n), k*min(2*m*n, m*m-n*n)). This iteration is done using
// a priority queue.
//
// Requires O(n) space and O(n log(n)) time for n triples.
//
// Note that Pythagorean triples can be enumerated even faster if you
// aren't constrained to the specific ordering (by z, then by x) induced by
// the original algorithm.
/// A Pythagorean triple (x, y, z) with x <= y and x*x + y*y == z*z.
alias Triple = Tuple!(int, "x", int, "y", int, "z");

/**
 * Infinite input range of all Pythagorean triples (primitive and
 * non-primitive), ordered by hypotenuse z and then by the shorter leg x.
 *
 * Only the coprime-pair tree rooted at (2, 1) is expanded. Its pairs are
 * exactly the coprime pairs (m, n), m > n > 0, of opposite parity. For those
 * pairs, Euclid's formula k*(m*m - n*n, 2*m*n, m*m + n*n), k >= 1, produces
 * every Pythagorean triple exactly once (legs sorted so that x <= y).
 *
 * The sibling tree rooted at (3, 1) contains only odd/odd pairs. It is
 * skipped because its triples are all duplicates: for odd coprime m, n the
 * triple equals the one from ((m+n)/2, (m-n)/2) with multiplier 2*k. This
 * removes the need for a "seen" set of every triple emitted so far.
 *
 * All arithmetic is done in `int`, so values wrap around once they exceed
 * int.max.
 */
class PythagoreanTriples {
private:
    /// Heap entry: the generators (m, n, k) and the triple they produce.
    struct Item {
        int m, n, k; // generators
        int x, y, z; // triple (x <= y)

        /// Compares by (z, x) in *reverse*, so that the max-heap
        /// `BinaryHeap` yields the item with the smallest (z, x) first.
        int opCmp(ref const Item other) const {
            static int sign(int a, int b) {
                return (a > b) - (a < b);
            }
            immutable c = sign(other.z, z);
            return c != 0 ? c : sign(other.x, x);
        }

        /// Applies Euclid's formula to (m, n) scaled by k.
        this(int m, int n, int k) {
            this.m = m;
            this.n = n;
            this.k = k;
            int t1 = k * (m * m - n * n);
            int t2 = 2 * k * m * n;
            x = min(t1, t2);
            y = max(t1, t2);
            z = k * (m * m + n * n);
        }
    }

    /// Frontier of generator triples not yet emitted.
    BinaryHeap!(Item[]) pqueue;

public:
    /// The sequence never ends.
    enum empty = false;

    /// The current triple.
    Triple front;

    this() {
        // Root of the opposite-parity coprime-pair tree; see class comment.
        pqueue = heapify([Item(2, 1, 1)]);
        popFront();
    }

    /// Advances `front` to the next triple in (z, x) order.
    void popFront() {
        // Every child pushed below has a strictly larger z than its parent,
        // so popping the minimum (z, x) yields triples in sorted order.
        auto f = pqueue.removeAny();
        with (f) {
            if (k == 1) {
                // Children of (m, n) in the coprime-pair tree; each one is a
                // new primitive triple.
                pqueue.insert(Item(2 * m - n, m, 1));
                pqueue.insert(Item(2 * m + n, m, 1));
                pqueue.insert(Item(m + 2 * n, n, 1));
            }
            // Next multiple of the same primitive triple.
            pqueue.insert(Item(m, n, k + 1));
            front = Triple(x, y, z);
        }
    }
}
// Feeds the triples produced by PythagoreanTriples, in order, to `until`
// and stops at the first triple for which it returns true. Because the
// range is infinite, it only returns once `until` returns true.
void pyth_euclid(bool delegate(int, int, int) until) {
    auto source = new PythagoreanTriples();
    for (;; source.popFront()) {
        const t = source.front;
        if (until(t.x, t.y, t.z))
            break;
    }
}
// Naive implementation for comparison purposes
void pyth_simple(bool delegate(int, int, int) until) {
int i = 0;
for (int z = 1;; z++) {
for (int x = 1; x <= z; x++) {
for (int y = x; y <= z; y++) {
if (x * x + y * y == z * z) {
if (until(x, y, z)) return;
}
}
}
}
}
// Benchmarks the naive and the Euclid-based enumerators on the same number
// of triples and prints the elapsed time of each in milliseconds.
void main() {
    const n = 3000;
    int count;

    // Stop condition handed to both enumerators: lets the first n triples
    // through, then returns true on the next one.
    bool until(int x, int y, int z) {
        if (count-- <= 0) return true;
        // writeln("%d^2 + %d^2 = %d^2".format(x, y, z));
        return false;
    }

    // Runs one enumerator with a fresh budget of n triples and returns the
    // elapsed time in milliseconds.
    long elapsedMs(void function(bool delegate(int, int, int)) enumerate) {
        StopWatch watch;
        watch.start();
        count = n;
        enumerate(&until);
        watch.stop();
        return watch.peek().total!"msecs";
    }

    writefln("Simple implementation: %d ms", elapsedMs(&pyth_simple));
    writefln("Fast implementation: %d ms", elapsedMs(&pyth_euclid));
}