Problem 216

216 Investigating the primality of numbers of the form 2*n^2-1
http://projecteuler.net/index.php?section=problems&id=216
ミラー・ラビンを使うのも、ひとつの方法でしょう。
しかし、それは問題の主旨に反するような気もする。
では、どうするか…
2*n^2-1が、ある素数pで割り切れるならば、
n^{2}\equiv\frac{p+1}{2}\, (\mathrm{mod}\, p)
が成り立つことに注目。
この等式の解は mod p で考えれば、0個か2個。
mod p での √ が必要になるが、それは前にやった。
4n+1素数の平方和分解 - 落書き、時々落学
で、つまり、
n\equiv\sqrt{\frac{p+1}{2}}\,(\mathrm{mod}\, p)
なる、n に対しては, 2*n^2-1 は p で割り切れる。
あとは、これを利用して、ふるいにかければ良い。

{-# OPTIONS_GHC -fbang-patterns #-}
import Control.Monad (when)
import Control.Monad.ST (ST, runST)
import Data.Array.ST (STUArray, newArray, readArray, writeArray)
import Mod (jacobi, sqrtMod)

primesUpTo :: Integer -> [Integer]
primesUpTo n = runST $ do
                 p <- newArray (1,div n 2) True :: ST s (STUArray s Integer Bool)
                 let u = floor.sqrt.fromIntegral $ div n 2
                     loop !i = when (i <= u) $ sieve i (2*i*(i+1)) >> loop (i+1)
                     sieve !i !j = when (j <= div n 2) $ writeArray p j False >> sieve i (j+2*i+1)
                     primes !i !ps | i > div n 2 = return.reverse $ ps
                                   | otherwise   = do q <- readArray p i
                                                      if q then primes (i+1) (2*i+1:ps)
                                                         else primes (i+1) ps
                 loop 1
                 primes 1 [2]

p216 :: Integer -> Integer
p216 n = runST $ do
           q <- newArray (2,n) True :: ST s (STUArray s Integer Bool)
           let u = floor $ sqrt 2 * fromIntegral n
               ps = tail $ primesUpTo u
               loop [] = count 2 0
               loop !(p:ps) | jacobi (div (p+1) 2) p == -1 = loop ps
                            | otherwise = sieve p s >> sieve p (p-s) >> check p >> loop ps
                   where s = sqrtMod (div (p+1) 2) p
               sieve !i !j = when (j<=n) $ writeArray q j False >> sieve i (j+i)
               check !i = when (div (i+1) 2 == j*j) $ writeArray q j True
                   where j = floor.sqrt.fromIntegral $ div (i+1) 2
               count !i !c | i > n     = return c
                           | otherwise =  do p <- readArray q i
                                             if p then count (i+1) (c+1)
                                                else count (i+1) c
           loop ps

main :: IO ()
main = print $ p216 (5*10^7)

ミラー・ラビンよりは速いけど、やっぱり、まだ遅い。

追記

C++で実装したら、10倍くらい速くなった。

#include <iostream>
#include <cmath>
#define SIGN(p) ((p) ? 1 : -1)
#define N 50000000
#define U ( (int) ( sqrt( 2 ) * N ))
#define S sqrt( 0.5 *  U )
using namespace std;

// input : a, b
// output : x, y  s.t. ax + by = gcd(a, b)
int extGcd( int a, int b, int& x, int& y ) {
  if ( b == 0 ) {
    x = 1; y = 0; return a;
  }
  int g = extGcd( b, a % b, y, x );
  y -= a / b * x;
  return g;
}

// xn = 1 (mod p)
int invMod(int n, int p) {
  int x, y;
  return extGcd ( n, p, x, y ) == 1 ? ( x + p ) % p : 0;
}

// jacobi symbol ( a / n )
int jacobi(int a, int n) {
  if (a < 0) return SIGN(n % 4 == 1) * jacobi(-a, n);
  if (a < 2) return a;
  if (a % 2 == 0) return SIGN(n % 8 == 1 || n % 8 == 7) * jacobi(a / 2, n);
  return SIGN(a % 4 == 1 || n % 4 == 1) * jacobi(n % a, a);
}

// a^n (mod p)
int powMod(int a, int n, int p) {
  if (n == 0) return 1;
  if (n % 2 == 0) return powMod( (long long) a * a % p, n / 2, p);
  return (long long) a * powMod( a, n - 1, p ) % p;
}

// sqrt n (mod p)
int sqrtMod(int n, int p) {
  int w = 2, s = 0, q = p - 1, m = invMod( n, p );
  while ( jacobi ( w, p ) != -1 ) w++;
  while ( q % 2 == 0 ) q /= 2, s++;
  int v = powMod( w, q, p );
  int r = powMod( n, (q + 1) / 2, p );
  do {
    int i = 0, u = (long long) r * r * m % p;
    while ( u % p != 1 ) u = (long long) u * u % p, i++;
    if ( i == 0 ) return r;
    r = (long long) r * powMod (v, 1 << (s - i - 1), p) % p;
  } while ( true );
}

int main() {
  bool *p = new bool [ U / 2 + 1], *q = new bool [ N + 1 ];
  for ( int i = 1; i <= U / 2; i++ ) p[i] = true;
  for ( int i = 1; i <= N; i++ ) q[i] = true;

  for ( int i = 1; i <= S; i++ ) if ( p[i] ) // sieve prime
    for ( int j = 2 * i * ( i + 1 ); j <= U / 2; j += 2 * i + 1 ) p[j] = false;

  for ( int i = 3; i <= U ; i += 2 ) // sieve 2*n^2-1 is prime
    if ( p[i / 2] ) {
      if ( jacobi ( i / 2 + 1, i ) != 1 ) continue;
      int s = sqrtMod( i / 2 + 1, i );
      for ( int j = s; j <= N ; j += i ) q[j] = false;
      for ( int j = i - s; j <= N ; j += i ) q[j] = false;
      int t = sqrt(  i / 2 + 1 );
      if (  i / 2 + 1 == t * t ) q[t] = true;
    }

  int c = 0;
  for (int i = 2; i <= N; i++ ) if ( q[i] ) c++;
  cout << c << endl;
}