情報系の手考ノート

数学とか情報系の技術とか調べたり勉強したりしてメモしていきます.

ある特定の剰余環上での2の羃倍を計算する

多倍長演算のアルゴリズムに Schönhage-Strassen Algorithm というものがあります。 これは、n桁の数同士の積が畳み込み演算で表わされることから、畳み込み定理を用いて乗算を高速化するアルゴリズムです。 畳み込み定理を利用するために剰余環上で離散フーリエ変換と似た計算を行うらしいです。 その仮定で m = 2^ k + 1 を法とする剰余環上である整数 \overline{y}

\begin{aligned}
\overline{y} = \overline{x \cdot 2^ p}
\end{aligned}

を計算する必要が出てきました。 ただ計算するだけなら問題ないのですが、 p は非常に大きい場合があり愚直に計算すると y が64bit整数の表現範囲を越える可能性があります。

つまり、愚直に計算をするためには多倍長演算を行うか、2を掛けて剰余をとるという計算を p 回繰り返す必要があります。 多倍長演算を高速化するのに、そんな効率の悪い方法でいいのかと調べてみたら、その解決策となる計算方法を見つけました(調べたのが昔すぎてどこのサイトか覚えてないですが...)。 それが

\begin{aligned}
x \cdot 2^ p = y_2 \cdot 2^ {2k} + y_1 \cdot 2^ k + y_0
\end{aligned}

ならば

\begin{aligned}
x \cdot 2^ p \equiv y_0 + y_2 - y_1 \qquad (\rm{mod} \quad 2^ k + 1)
\end{aligned}

というものでした。 なんでこうなるのか、というのが当時全く理解できてなかったのですが、最近導出できたのでその前提条件や導出仮定をまとめます。 ついでに実装もしてみます。

前提

対象はある自然数 k によって法が m = 2^ k + 1 と表わされるような整数環上の剰余環 \mathbb{Z}/m\mathbb{Z}です。 以下、代表元を x とする剰余環上の元を \overline{x} と書きます。

命題

ある 0 以上 2^ k 以下の整数 x  0 以上 2k 未満の整数 p に対して、

\begin{aligned}
x \cdot 2^ p = y_2 \cdot 2^ {2k} + y_1 \cdot 2^ k + y_0
\end{aligned}

を満たす 0 以上 2^ k 未満の整数 y_0, y_1, y_2 に対して

\begin{aligned}
\overline{x \cdot 2^ p} = \overline{y_0 + y_2 - y_1}
\end{aligned}

が成立する。

導出

まず、 p の範囲を 0 以上 2k 未満に制限したのは 、 \overline{2^ {2k}} = \overline{1} であるためです。  \overline{2^ {2k}} = \overline{1} であるため、 \overline{2^ {2k + \alpha}} = \overline{2^ \alpha} となるので、任意の p に対して考える必要がなくなります。

さらに p の範囲を制限したことで、 x \cdot 2^ p の範囲が 0 以上 2^ {3k} 未満に限定できます(実際にはもう少し範囲が限定できますが必要がないのでしません)。

したがって除法の定理から、ある y_0, y_1, y_2 が唯一つ存在して

\begin{aligned}
x \cdot 2^ p = y_2 \cdot 2^ {2k} + y_1 \cdot 2^ k + y_0
\end{aligned}

となることが言えます。

ここで y' = x \cdot 2^ p とすると、除法の定理よりある整数 q, r を用いて

\begin{aligned}
y' &= q m + r \\
   &= q \cdot 2^ k + q + r
\end{aligned}

と書くことができます( 0 \le q \le 2^ {2k}, 0 \le r \lt 2^ k )。 さらも、 q + r  0 以上 2^ k 未満の整数 s, y'_0 を用いて

\begin{aligned}
q + r = s \cdot 2^ k + y'_0
\end{aligned}

として置きかえると

\begin{aligned}
y' &= q \cdot 2^ k + q + r \\
   &= q \cdot 2^ k + s \cdot 2^ k + y'_0 \\
   &= ( q + s ) \cdot 2^ k + y'_0
\end{aligned}

と書くことができます。 同様にして q + s  0 以上 2^ k 未満の整数 y'_1, y'_2 を用いて

\begin{aligned}
q + s = y'_2 \cdot 2^ k + y'_1
\end{aligned}

として置きかえると

\begin{aligned}
y' &= ( q + s ) \cdot 2^ k + y'_0 \\
   &= ( y'_2 \cdot 2^ k + y'_1 ) \cdot 2^ k + y'_0 \\
   &= y'_2 \cdot 2^ {2k} + y'_1 \cdot 2^ k + y'_0
\end{aligned}

と書くことができます。

除法の定理より y_0, y_1, y_2 はそれぞれ y'_0, y'_1, y'_2 と一致するため、 y_0, y_1, y_2 には

\begin{aligned}
q + r = s \cdot 2^ k + y_0 \\
q + s = y_2 \cdot 2^ k + y_1
\end{aligned}

という関係が成立します。 この関係式から q を削除すると

\begin{aligned}
r - s & = s \cdot 2^ k + y_0 - y_2 \cdot 2^ k - y_1 \\
r &= s (2^ k + 1) + y_0 - y_2 \cdot 2^ k - y_1 \\
r &= s m + y_0 - y_2 \cdot 2^ k - y_1
\end{aligned}

となります。 ここで r  q  m で割った余りだったため

\begin{aligned}
\overline{x \cdot 2^ p} &= \overline{r} \\
                                   &= \overline{s m + y_0 - y_2 \cdot 2^ k - y_1} \\
                                   &= \overline{y_0 - y_2 \cdot 2^ k - y_1}
\end{aligned}

が成立し、 \overline{2^ k} = \overline{-1} であるため

\begin{aligned}
\overline{x \cdot 2^ p} = \overline{y_0 + y_2 - y_1}
\end{aligned}

となり、命題は示されました。

実際の計算

実際の計算ではもう少し工夫をします。

 y_0 + y_2 - y_1 という計算を行ないますが、 x  0 以上 2^ k 以下であるという制約から y_0, y_2 のどちらも非ゼロとなることはありません。 さらに y_0, y_1, y_2 はプログラム上ではビットシフトのみで計算できます。

以上をふまえてc++で簡単に実装したコードが以下のようになります。

std::uint64_t fmp(std::uint64_t k, std::uint64_t x, std::uint64_t p) {
    std::uint64_t m = (1 << k) + 1;         // 法
    std::uint64_t mask = (1 << k) - 1;
    std::uint64_t y;

    p = p % (2 * k);                        // p の範囲を制限

    if (p < k) {                            // y_2 がゼロの場合
        std::uint64_t y_0, y_1;
        y_0 = (x << p) & mask;              // ビットシフトで y_0 を計算
        y_1 = (x >> (k - p)) & mask;        // ビットシフトで y_1 を計算
        y = y_0 - y_1;
        return y < 0 ? y + m : y;           // 負数になったときに範囲を調整
    }
    else {                                  // y_1 がゼロの場合
        std::uint64_t y_2, y_1;
        y_2 = (x >> (2 * k - p)) & mask;    // ビットシフトで y_2 を計算
        y_1 = (x << (p - k)) & mask;        // ビットシフトで y_1 を計算
        y = y_2 - y_1;
        return y < 0 ? y + m : y;           // 負数になったときに範囲を調整
    }
}

まとめ

この方法で特定の剰余環上での2の羃との積を計算することができます。

剰余環上では、直感的でない方法でいろいろな計算ができるので、ほかにも剰余環関連のアルゴリズム等を調べてみたいです。