有限体上での連立方程式(ガウスの消去法)

有限体上で連立方程式を解くプログラム(ガウスの消去法の実装).

例えば,2元体上で

1 1 1 1
1 1 0 1
0 0 1 0

,つまり
x1 + x2 + x3 = 1
x1 + x2 = 1
x3 = 0
は掃き出すと

1 1 1 1
0 0 1 0
0 0 0 0

となり,上三角ではなくなるが,もちろん解は存在する(x1 = 1, x2 = x3 = 0) .

  • これを使えば,FlipItの解の存在判定が解ける.


以下はガウスの消去法本体. Ax = b (mod q) を解くのが目的.
a は係数行列,つまり a = [A | b]である.
x は解を記録する配列.
q は素数で有限体の位数.
連立方程式に解が存在するとき,Trueを返し,解無しの場合Falseを返す.

// 有限体上の線型方程式系 Ax = b (mod q)を解く
// a = [A | b]: m × n の係数行列
// x: 解を記録するベクトル
// 計算量: O(min(m, n) * m * n)
bool gauss(int **a, int *x, int m, int n, int q) {
  int rank = 0, *pivot = new int [n];
  // 前進消去
  for (int i = 0, j = 0; i < m && j < n-1; ++j) {
    int p = -1, tmp = 0;
    // ピボットを探す
    for (int k = i; p < 0 && k < m; ++k) {
      if (a[k][j] != 0) p = k;  // 有限体上なので非零で十分
    }
    // ランク落ち対策
    if (p == -1) continue;
    // 第i行と第p行を入れ替える
    for (int k = j; k < n; ++k)
      tmp = a[i][k], a[i][k] = a[p][k], a[p][k] = tmp;
    // 第i行を使って掃き出す
    for (int k = i+1; k < m; ++k) {
      tmp = - a[k][j] * invMod(a[i][j], q) % q;
      for (int l = j; l < n; ++l)
        a[k][l] += tmp * a[i][l];
    }
    // 第i行を正規化: a[i][j] = 1 にする
    tmp = invMod(a[i][j], q);
    for (int k = j; k < n; ++k)
      a[i][k] = a[i][k] * tmp % q;
    pivot[i++] = j, rank++;
  }
  // 解の存在のチェック
  for (int i = rank; i < m; ++i)
    if (a[i][n-1] != 0) return false;
  // 解をxに代入(後退代入)
  for (int i = 0; i < rank; ++i)
    x[i] = a[i][n-1];
  for (int i = rank-1; i >= 0; --i) {
    for (int j = pivot[i] + 1; j < n-1; ++j)
      x[i] -= a[i][j] * x[j];
    x[i] -= x[i] / q * q, x[i] = (x[i] + q) % q;  // 0 <= x[i] < q に調整
  }
  return true;
}

# なんか有限体ではなく素体というべきなのかもしれないと,いまさらながら思いはじめる.

ソースコード全体.適当なテストコードを含む.

#include <iostream>
#include <cmath>

using namespace std;

void show(int **a, int m, int n);

// input : a, b
// output : x, y  s.t. ax + by = (符号付き)gcd(a, b)
int extGcd( int a, int b, int& x, int& y ) {
  if ( b == 0 ) {
    x = 1; y = 0; return a;
  }
  int g = extGcd( b, a % b, y, x );
  y -= (a / b) * x;
  return g;
}

// xn = 1 (mod p)
int invMod(int n, int p) {
  int x, y, g = extGcd ( n, p, x, y );
  if (g == 1) return x;
  else if (g == -1) return -x;
  else return 0; // gcd(n, p) != 1,解なし
}

// 有限体上の線型方程式系 Ax = b (mod q)を解く
// a = [A | b]: m × n の係数行列
// x: 解を記録するベクトル
// 計算量: O(min(m, n) * m * n)
bool gauss(int **a, int *x, int m, int n, int q) {
  int rank = 0, *pivot = new int [n];
  // 前進消去
  for (int i = 0, j = 0; i < m && j < n-1; ++j) {
    int p = -1, tmp = 0;
    // ピボットを探す
    for (int k = i; p < 0 && k < m; ++k) {
      if (a[k][j] != 0) p = k;  // 有限体上なので非零で十分
    }
    // ランク落ち対策
    if (p == -1) continue;
    // 第i行と第p行を入れ替える
    for (int k = j; k < n; ++k)
      tmp = a[i][k], a[i][k] = a[p][k], a[p][k] = tmp;
    // 第i行を使って掃き出す
    for (int k = i+1; k < m; ++k) {
      tmp = - a[k][j] * invMod(a[i][j], q) % q;
      for (int l = j; l < n; ++l)
        a[k][l] += tmp * a[i][l];
    }
    // 第i行を正規化: a[i][j] = 1 にする
    tmp = invMod(a[i][j], q);
    for (int k = j; k < n; ++k)
      a[i][k] = a[i][k] * tmp % q;
    pivot[i++] = j, rank++;
  }
  // 解の存在のチェック
  for (int i = rank; i < m; ++i)
    if (a[i][n-1] != 0) return false;
  // 解をxに代入(後退代入)
  for (int i = 0; i < rank; ++i)
    x[i] = a[i][n-1];
  for (int i = rank-1; i >= 0; --i) {
    for (int j = pivot[i] + 1; j < n-1; ++j)
      x[i] -= a[i][j] * x[j];
    x[i] -= x[i] / q * q, x[i] = (x[i] + q) % q;  // 0 <= x[i] < q に調整
  }
  return true;
}

int main() {
  const int m = 4, n = 5;
  int **a, *x;
  a = new int* [m];
  for (int i = 0; i < m; ++i)
    a[i] = new int[n];
  a[0][0] = 1, a[0][1] = 1, a[0][2] = 1, a[0][3] = 0, a[0][4] = 1;
  a[1][0] = 1, a[1][1] = 1, a[1][2] = 0, a[1][3] = 1, a[1][4] = 0;
  a[2][0] = 1, a[2][1] = 0, a[2][2] = 1, a[2][3] = 1, a[2][4] = 0;
  a[3][0] = 0, a[3][1] = 1, a[3][2] = 1, a[3][3] = 1, a[3][4] = 1;
  x = new int[n-1];

  show(a, m, n);
  gauss(a, x, m, n, 2);
  show(a, m, n);
  for (int i = 0; i < n-1; ++i)
    cout << x[i] << " ";
  cout << endl;
}

void show(int **a, int m, int n) {
  cout << "==============================" << endl;
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j)
      cout << a[i][j] << " ";
    cout << endl;
  }
  cout << "==============================" << endl;
}