; Copyright (C) 2026 Kiyotsugu Arai
; SPDX-License-Identifier: LGPL-3.0-or-later
;
; mpn_x64_mont.asm — Montgomery CIOS 乗算 / REDC (BMI2+ADX)
;
; 関数:
;   mpn_mont_mul_mulx(rp, ap, bp, n, mp, m_inv, scratch)
;   mpn_mont_redc_mulx(rp, tp, tn, mp, n, m_inv)
;   mpn_mont_mul_8(rp, ap, bp, mp, m_inv)  — n=8 特化 SOS
;   mpn_mont_sqr_8(rp, ap, mp, m_inv)    — n=8 統合 SQR+REDC
;   mpn_mont_redc_16(rp, tp, mp, m_inv)   — n=16 特化 REDC
;   mpn_mont_redc_32(rp, tp, mp, m_inv)   — n=32 特化 REDC
;   mpn_mont_mul_4(rp, ap, bp, mp, m_inv)  — n=4 特化 SOS
;   mpn_mont_sqr_4(rp, ap, mp, m_inv)    — n=4 統合 SQR+REDC
;   mpn_mont_mul_2(rp, ap, bp, mp, m_inv)  — n=2 特化 SOS
;   mpn_mont_sqr_2(rp, ap, mp, m_inv)    — n=2 統合 SQR+REDC
;   mpn_mont_sqr_16(rp, ap, mp, m_inv)   — n=16 統合 SQR+REDC
;
; 前提: CPU が BMI2 + ADX をサポートしていること (呼び出し元で CPUID チェック済み)
;
; Windows x64 calling convention:
;   rcx = 1st, rdx = 2nd, r8 = 3rd, r9 = 4th
;   戻り値: rax
;   非破壊: rbx, rbp, rdi, rsi, r12-r15
;   破壊可: rax, rcx, rdx, r8, r9, r10, r11

; =====================================================================
; Macro: inline addmul_1 body (4x unrolled + 1x tail)
;
; Computes dest[0..n-1] += src[0..n-1] * rdx + r9 and leaves the
; carry-out word in r9.
;
; In:       rdx = multiplier (read only), rsi = src, rbx = dest, rcx = n
;           r9  = carry-in (every call site in this file passes 0)
;           r8  = must be 0 (used as the zero addend of ADCX/ADOX)
; Out:      r9  = final carry word
;           rsi and rbx advanced by 8*n, rcx = 0
; Clobbers: r10-r15, rdi, rax, flags
;
; The 4x body runs two independent carry chains: CF (ADCX) adds the
; previous high word (or the carry-in) into each low word, and OF (ADOX)
; adds the existing dest word. dest + src*d + carry always fits in one
; extra word, so folding CF and OF into the last high word cannot
; overflow.
; =====================================================================
INLINE_ADDMUL_1 MACRO
LOCAL am_4x, am_tail_check, am_tail, am_done
    cmp     rcx, 4
    jb      am_tail_check
am_4x:
    xor     eax, eax            ; CF=0, OF=0
    mulx    r10, r11, [rsi]     ; r10:r11 = src[0] * d
    mulx    r12, r13, [rsi+8]   ; r12:r13 = src[1] * d
    mulx    r14, r15, [rsi+16]  ; r14:r15 = src[2] * d
    mulx    rdi, rax, [rsi+24]  ; rdi:rax = src[3] * d
    adcx    r11, r9             ; lo0 + carry-in        (CF chain)
    adox    r11, [rbx]          ;     + dest[0]         (OF chain)
    mov     QWORD PTR [rbx], r11
    adcx    r13, r10            ; lo1 + hi0
    adox    r13, [rbx+8]        ;     + dest[1]
    mov     QWORD PTR [rbx+8], r13
    adcx    r15, r12            ; lo2 + hi1
    adox    r15, [rbx+16]       ;     + dest[2]
    mov     QWORD PTR [rbx+16], r15
    adcx    rax, r14            ; lo3 + hi2
    adox    rax, [rbx+24]       ;     + dest[3]
    mov     QWORD PTR [rbx+24], rax
    mov     r9, rdi             ; carry-out = hi3 + CF + OF
    adcx    r9, r8
    adox    r9, r8
    lea     rsi, [rsi+32]
    lea     rbx, [rbx+32]
    sub     rcx, 4
    cmp     rcx, 4
    jae     am_4x
am_tail_check:
    test    rcx, rcx
    jz      am_done
am_tail:                        ; one word per pass, plain ADD/ADC chain
    mulx    r10, r11, [rsi]     ; r10:r11 = src[k] * d
    add     r11, r9             ; + carry
    adc     r10, 0
    add     QWORD PTR [rbx], r11 ; dest[k] += low word
    adc     r10, 0
    mov     r9, r10             ; r9 = carry into next word
    lea     rsi, [rsi+8]
    lea     rbx, [rbx+8]
    dec     rcx
    jnz     am_tail
am_done:
ENDM


; =====================================================================
; n=8 レジスタ常駐 Montgomery 乗算マクロ群
; IPP 解析に基づくシフトトリック: 累算器 r8-r14 + r15(carry)
; rbp/rbx を MULX hi/lo テンポラリとして交互使用
; =====================================================================

; 8-limb addmul, limb 0: opens the dual (CF/OF) carry chains.
; Flags are live afterwards: only flag-neutral instructions (e.g. MOV)
; may sit between this macro and ADDMUL_REST_8.
; In:  rdx = multiplier, rsi = src[0..7], r8-r14 = acc[0..6], r15 = acc[7]
; Out: r8  = acc[0] + lo(src[0]*rdx), the word being shifted out
;            (callers either store it or discard it); OF = its carry
;      rbx = hi(src[0]*rdx), consumed by ADDMUL_REST_8
;      rax = 0, CF = 0; rbp clobbered
ADDMUL_FIRST_8 MACRO
    xor     eax, eax            ; rax = 0, CF = OF = 0
    mulx    rbx, rbp, [rsi]     ; rbx:rbp = src[0] * rdx
    adox    r8, rbp             ; r8 = acc[0] + lo0, carry -> OF
ENDM

; 8-limb addmul, limbs 1..7 plus carry capture (shift-accumulator form).
; Continues the CF/OF chains opened by ADDMUL_FIRST_8. Each accumulator
; register is rewritten with the word one position above it, so
; afterwards r8-r15 hold words 1..8 of (acc + src*rdx). In other words,
; the sum is shifted down by one word.
; In:  state right after ADDMUL_FIRST_8 (rbx = hi of limb 0, rax = 0,
;      CF = 0, OF = carry from limb 0); rdx and rsi are not modified
; Out: r8-r14 = new acc[0..6], r15 = new top word (hi7 + CF + OF),
;      rax = 0
; Clobbers: rbx, rbp, flags
ADDMUL_REST_8 MACRO
    adcx    r9, rbx             ; r9 = acc[1] + hi0          (CF chain)
    mulx    rbp, r8, [rsi+8]    ; rbp:r8 = src[1] * rdx
    adox    r8, r9              ; new acc[0] = lo1 + r9      (OF chain)
    adcx    r10, rbp            ; r10 = acc[2] + hi1
    mulx    rbx, r9, [rsi+16]
    adox    r9, r10             ; new acc[1] = lo2 + r10
    adcx    r11, rbx
    mulx    rbp, r10, [rsi+24]
    adox    r10, r11            ; new acc[2]
    adcx    r12, rbp
    mulx    rbx, r11, [rsi+32]
    adox    r11, r12            ; new acc[3]
    adcx    r13, rbx
    mulx    rbp, r12, [rsi+40]
    adox    r12, r13            ; new acc[4]
    adcx    r14, rbp
    mulx    rbx, r13, [rsi+48]
    adox    r13, r14            ; new acc[5]
    adcx    r15, rbx            ; r15 = acc[7] + hi6
    mulx    rbp, r14, [rsi+56]
    adox    r14, r15            ; new acc[6] = lo7 + r15
    mov     r15, rax            ; r15 = 0 (MOV leaves CF/OF intact)
    adcx    r15, rbp            ; + hi7 + CF
    adox    r15, rax            ; + OF  -> new acc[7]
ENDM

; REDC_ITER iter_idx (定義は下の ADDMUL_BLOCK_4 の後):
;   REDC 1 反復: q = r8 * m_inv, acc += mp * q, 1 ワード右シフト, P[8+i] を最上位に加算
; 入力: rsi = mp, rdi = product buffer, r8-r14 = acc[0..6], r15 = acc[7]
;        [rsp+168] = m_inv
; 出力: r8-r14 = シフト後 acc[0..6], r15 = P[8+i] + carry (溢れは P[9+i..16] へ伝播)
; 4-word addmul block at a fixed offset (no pointer advance).
; Computes dest[off/8 .. off/8+3] += src[off/8 .. off/8+3] * rdx + r9.
; In:  rdx = multiplier, rsi = src base, rbx = dest base,
;      r9 = carry-in, r8 = must be 0, off = byte offset (assembly-time)
; Out: r9 = carry-out word
; Clobbers: r10-r15, rdi, rax, flags
; NOTE(review): not invoked in the portion of the file reviewed here;
; mpn_mont_redc_16 spells out the same sequence inline.
ADDMUL_BLOCK_4 MACRO off
    xor     eax, eax            ; CF = OF = 0
    mulx    r10, r11, [rsi+off]
    mulx    r12, r13, [rsi+off+8]
    mulx    r14, r15, [rsi+off+16]
    mulx    rdi, rax, [rsi+off+24]
    adcx    r11, r9             ; lo0 + carry-in   (CF chain)
    adox    r11, [rbx+off]      ;     + dest word  (OF chain)
    mov     [rbx+off], r11
    adcx    r13, r10            ; lo1 + hi0
    adox    r13, [rbx+off+8]
    mov     [rbx+off+8], r13
    adcx    r15, r12            ; lo2 + hi1
    adox    r15, [rbx+off+16]
    mov     [rbx+off+16], r15
    adcx    rax, r14            ; lo3 + hi2
    adox    rax, [rbx+off+24]
    mov     [rbx+off+24], rax
    mov     r9, rdi             ; carry-out = hi3 + CF + OF
    adcx    r9, r8
    adox    r9, r8
ENDM

; ---------------------------------------------------------------------
; REDC_ITER iter_idx: one Montgomery reduction step of the n=8 kernels.
;
; Window: r8..r15 hold T[i..i+7] (i = iter_idx); higher words live in
; the product buffer P at rdi. The step computes q = T[i] * m_inv,
; adds q * mp[0..7] to the window, drops the low word and adds P[8+i]
; into the new top word. A carry out of that add ripples through
; P[9+i..16] in memory.
;
; In:  rsi = mp, rdi = P, r8..r15 = window, [rsp+168] = m_inv
; Out: r8..r14 = shifted window, r15 = carry word + P[8+i]
; Clobbers: rax, rbx, rbp, rdx, flags
; ---------------------------------------------------------------------
REDC_ITER MACRO iter_idx
LOCAL ri_settled
    mov     rdx, r8                         ; rdx = T[i]
    imul    rdx, QWORD PTR [rsp+168]        ; q = T[i] * m_inv (mod 2^64)
    ADDMUL_FIRST_8                          ; low word cancels, carry in OF
    ADDMUL_REST_8                           ; window shifted, r15 = carry word
    add     r15, QWORD PTR [rdi + (8 + iter_idx)*8]
    jnc     ri_settled
    ; carry out of the top word: add 1 to P[9+i], P[10+i], ... until it stops
_prop_idx = 9 + iter_idx
    WHILE _prop_idx LE 16
        add     QWORD PTR [rdi + _prop_idx*8], 1
        jnc     ri_settled
_prop_idx = _prop_idx + 1
    ENDM
ri_settled:
ENDM


.code

; =====================================================================
; void mpn_mont_mul_mulx(uint64_t* rp, const uint64_t* ap,
;                         const uint64_t* bp, size_t n,
;                         const uint64_t* mp, uint64_t m_inv,
;                         uint64_t* scratch)
;
; Montgomery multiplication: rp = ap * bp * R^{-1} mod mp, R = 2^(64n).
;
; For each i = 0..n-1 the routine does two things:
;   - adds ap * bp[i] into scratch[i..];
;   - adds mp * q into scratch[i..], with q = scratch[i] * m_inv (mod 2^64).
; The second step only clears scratch[i] when m_inv = -mp[0]^{-1} mod 2^64;
; the caller must supply that value.
; scratch[n..2n] then gets a single conditional subtraction of mp.
; That yields a value below mp only if the pre-subtraction value is below
; 2*mp (presumably ensured by callers passing ap, bp < mp; not checked).
;
; In:  rcx = rp, rdx = ap, r8 = bp, r9 = n
;      [rsp+40] = mp, [rsp+48] = m_inv, [rsp+56] = scratch (at entry)
; Out: rp[0..n-1]; rp is written only after the main loop finishes.
;
; scratch needs at least 2n+1 words and is zeroed here.
; n == 0 returns immediately without touching rp or scratch.
;
; addmul_1 is inlined (INLINE_ADDMUL_1), so all 2n row updates share one
; prologue/epilogue instead of 2n calls with 7 push + 7 pop each.
; The PROC FRAME unwind directives are required by the Windows x64 ABI
; for any function that saves nonvolatile registers.
; =====================================================================
mpn_mont_mul_mulx PROC FRAME
    push    rbx
    .pushreg rbx
    push    rsi
    .pushreg rsi
    push    rdi
    .pushreg rdi
    push    r12
    .pushreg r12
    push    r13
    .pushreg r13
    push    r14
    .pushreg r14
    push    r15
    .pushreg r15
    push    rbp
    .pushreg rbp
    sub     rsp, 64             ; 8 qwords of locals
    .allocstack 64
    .endprolog

    ; Stack layout: 8 pushes (64 B) + 64 B locals = 128 B up to the
    ; return address; stack args follow the 32 B shadow area:
    ;   [rsp+168] = 5th arg (mp)
    ;   [rsp+176] = 6th arg (m_inv)
    ;   [rsp+184] = 7th arg (scratch)
    ;
    ; Locals:
    ;   [rsp+0]  = rp
    ;   [rsp+8]  = ap
    ;   [rsp+16] = n
    ;   [rsp+24] = mp
    ;   [rsp+32] = m_inv
    ;   [rsp+40] = scratch
    ;   [rsp+48] = bp

    mov     [rsp+0], rcx
    mov     [rsp+8], rdx
    mov     [rsp+16], r9
    mov     rax, [rsp+168]
    mov     [rsp+24], rax
    mov     rax, [rsp+176]
    mov     [rsp+32], rax
    mov     rax, [rsp+184]
    mov     [rsp+40], rax
    mov     [rsp+48], r8

    ; n == 0: nothing to compute. Without this guard the negative-index
    ; subtraction loop below would start at index 0, write rp[0] and
    ; never reach zero again.
    test    r9, r9
    jz      mm_done

    ; Zero scratch[0..2n] (2n+1 words). DF is clear per the x64 ABI.
    mov     rdi, rax            ; rdi = scratch
    xor     eax, eax
    lea     rcx, [r9+r9+1]      ; rcx = 2n + 1
    rep     stosq

    ; r8 = 0 is the zero addend INLINE_ADDMUL_1 expects. bp is already
    ; saved in [rsp+48] and nothing inside the loop writes r8, so it is
    ; set once here.
    xor     r8d, r8d

    ; === Outer loop: i = 0 .. n-1 (rbp = i) ===
    xor     ebp, ebp

mm_outer:
    ; --- Phase 1: scratch[i..i+n-1] += ap[0..n-1] * bp[i] ---
    mov     rsi, [rsp+8]        ; ap
    mov     rbx, [rsp+40]       ; scratch
    lea     rbx, [rbx + rbp*8]  ; &scratch[i]
    mov     rax, [rsp+48]       ; bp
    mov     rdx, [rax + rbp*8]  ; bp[i] -> MULX multiplier
    mov     rcx, [rsp+16]       ; n
    xor     r9d, r9d            ; carry-in = 0

    INLINE_ADDMUL_1

    ; Carry word (r9) -> scratch[i+n], ripple up to scratch[2n]
    mov     rax, [rsp+40]
    mov     rcx, rbp
    add     rcx, [rsp+16]       ; rcx = i + n
    add     [rax + rcx*8], r9
    jnc     mm_phase2
    mov     rdi, [rsp+16]
    shl     rdi, 1              ; rdi = 2n (last valid index)
mm_carry1:
    inc     rcx
    cmp     rcx, rdi
    ja      mm_phase2
    add     QWORD PTR [rax + rcx*8], 1
    jc      mm_carry1

mm_phase2:
    ; --- Phase 2: q = scratch[i] * m_inv, scratch[i..i+n-1] += mp * q ---
    mov     rax, [rsp+40]       ; scratch
    mov     rax, [rax + rbp*8]  ; scratch[i]
    imul    rax, QWORD PTR [rsp+32]  ; q = scratch[i] * m_inv (mod 2^64)
    mov     rdx, rax            ; q -> MULX multiplier

    mov     rsi, [rsp+24]       ; mp
    mov     rbx, [rsp+40]       ; scratch
    lea     rbx, [rbx + rbp*8]  ; &scratch[i]
    mov     rcx, [rsp+16]       ; n
    xor     r9d, r9d            ; carry-in = 0

    INLINE_ADDMUL_1

    ; Carry word (r9) -> scratch[i+n], ripple up to scratch[2n]
    mov     rax, [rsp+40]
    mov     rcx, rbp
    add     rcx, [rsp+16]       ; rcx = i + n
    add     [rax + rcx*8], r9
    jnc     mm_next
    mov     rdi, [rsp+16]
    shl     rdi, 1              ; rdi = 2n
mm_carry2:
    inc     rcx
    cmp     rcx, rdi
    ja      mm_next
    add     QWORD PTR [rax + rcx*8], 1
    jc      mm_carry2

mm_next:
    inc     rbp
    cmp     rbp, [rsp+16]
    jb      mm_outer

    ; === Conditional subtraction: rp = scratch[n..2n] - mp if no borrow ===
    mov     rax, [rsp+40]       ; scratch
    mov     r11, [rsp+16]       ; r11 = n
    mov     rbx, [rsp+0]        ; rp
    mov     rsi, [rsp+24]       ; mp

    ; Negative-index loop: rcx runs -n .. -1, so INC both advances the
    ; index and terminates the loop without disturbing the SBB borrow.
    lea     rdi, [rax + r11*8]
    lea     rdi, [rdi + r11*8]  ; rdi = &scratch[2n]
    lea     rsi, [rsi + r11*8]  ; rsi = &mp[n]
    lea     rbx, [rbx + r11*8]  ; rbx = &rp[n]
    mov     rcx, r11
    neg     rcx                 ; rcx = -n
    clc
mm_sub_loop:
    mov     r10, [rdi + rcx*8]  ; scratch[n+k]
    sbb     r10, [rsi + rcx*8]  ; - mp[k] - borrow
    mov     [rbx + rcx*8], r10  ; rp[k]
    inc     rcx                 ; INC leaves CF intact
    jnz     mm_sub_loop

    ; Top word: scratch[2n] - borrow. No borrow => value >= mp, keep rp.
    mov     r10, [rdi]          ; scratch[2n]
    sbb     r10, 0
    jnc     mm_done

    ; Borrow => value < mp: copy scratch[n..2n-1] to rp unchanged.
    mov     rax, [rsp+40]       ; scratch
    mov     rbx, [rsp+0]        ; rp
    lea     rdi, [rax + r11*8]  ; &scratch[n]
    xor     r9d, r9d
mm_copy:
    mov     r10, [rdi + r9*8]
    mov     [rbx + r9*8], r10
    inc     r9
    cmp     r9, r11
    jb      mm_copy

mm_done:
    add     rsp, 64
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mont_mul_mulx ENDP


; =====================================================================
; void mpn_mont_redc_mulx(uint64_t* rp, uint64_t* tp, size_t tn,
;                          const uint64_t* mp, size_t n, uint64_t m_inv)
;
; Montgomery reduction: rp = tp * R^{-1} mod mp, R = 2^(64n).
; tp[0..tn-1] is reduced in place and therefore destroyed.
; tn must be >= 2n+1: tp[2n] is read as the top word of the result.
; Carries are rippled at most up to tp[tn-1]; words above tp[2n] do not
; take part in the final result.
;
; For each i = 0..n-1: q = tp[i] * m_inv (mod 2^64), tp[i..] += mp * q.
; This clears tp[i] only if m_inv = -mp[0]^{-1} mod 2^64 (caller's job).
; One conditional subtraction of mp is then applied to tp[n..2n].
;
; In:  rcx = rp, rdx = tp, r8 = tn, r9 = mp
;      [rsp+40] = n, [rsp+48] = m_inv (at entry)
; n == 0 returns immediately without touching rp or tp.
;
; Intended for reducing the output of mpn::square(); replaces the scalar
; carry loop of the C++ mont_redc with ADCX/ADOX chains.
; PROC FRAME unwind directives are required by the Windows x64 ABI for
; functions that save nonvolatile registers.
; =====================================================================
mpn_mont_redc_mulx PROC FRAME
    push    rbx
    .pushreg rbx
    push    rsi
    .pushreg rsi
    push    rdi
    .pushreg rdi
    push    r12
    .pushreg r12
    push    r13
    .pushreg r13
    push    r14
    .pushreg r14
    push    r15
    .pushreg r15
    push    rbp
    .pushreg rbp
    sub     rsp, 48             ; 6 qwords of locals
    .allocstack 48
    .endprolog

    ; Stack layout: 8 pushes (64 B) + 48 B locals = 112 B up to the
    ; return address; stack args follow the 32 B shadow area:
    ;   [rsp+152] = 5th arg (n)
    ;   [rsp+160] = 6th arg (m_inv)
    ;
    ; Locals:
    ;   [rsp+0]  = rp
    ;   [rsp+8]  = tp
    ;   [rsp+16] = tn
    ;   [rsp+24] = mp
    ;   [rsp+32] = n
    ;   [rsp+40] = m_inv

    mov     [rsp+0], rcx
    mov     [rsp+8], rdx
    mov     [rsp+16], r8
    mov     [rsp+24], r9
    mov     rax, [rsp+160]
    mov     [rsp+40], rax
    mov     rax, [rsp+152]
    mov     [rsp+32], rax

    ; n == 0: nothing to compute. Without this guard the negative-index
    ; subtraction loop below would start at index 0 and run past rp.
    test    rax, rax
    jz      mr_done

    ; r8 = 0 is the zero addend INLINE_ADDMUL_1 expects. tn is already
    ; saved in [rsp+16] and nothing inside the loop writes r8.
    xor     r8d, r8d

    ; === Outer loop: i = 0 .. n-1 (rbp = i) ===
    xor     ebp, ebp

mr_outer:
    ; q = tp[i] * m_inv
    mov     rax, [rsp+8]        ; tp
    mov     rax, [rax + rbp*8]  ; tp[i]
    imul    rax, QWORD PTR [rsp+40]  ; q = tp[i] * m_inv (mod 2^64)
    mov     rdx, rax            ; q -> MULX multiplier

    ; tp[i..i+n-1] += mp[0..n-1] * q
    mov     rsi, [rsp+24]       ; mp
    mov     rbx, [rsp+8]        ; tp
    lea     rbx, [rbx + rbp*8]  ; &tp[i]
    mov     rcx, [rsp+32]       ; n
    xor     r9d, r9d            ; carry-in = 0

    INLINE_ADDMUL_1

    ; Carry word (r9) -> tp[i+n], ripple up to tp[tn-1]
    mov     rax, [rsp+8]        ; tp
    mov     rcx, rbp
    add     rcx, [rsp+32]       ; rcx = i + n
    add     [rax + rcx*8], r9
    jnc     mr_next
    mov     rdi, [rsp+16]       ; tn
    dec     rdi                 ; rdi = tn - 1 (last valid index)
mr_carry:
    inc     rcx
    cmp     rcx, rdi
    ja      mr_next
    add     QWORD PTR [rax + rcx*8], 1
    jc      mr_carry

mr_next:
    inc     rbp
    cmp     rbp, [rsp+32]
    jb      mr_outer

    ; === Conditional subtraction: rp = tp[n..2n] - mp if no borrow ===
    mov     rax, [rsp+8]        ; tp
    mov     r11, [rsp+32]       ; r11 = n
    mov     rbx, [rsp+0]        ; rp
    mov     rsi, [rsp+24]       ; mp

    ; Negative-index loop: rcx runs -n .. -1, so INC both advances the
    ; index and terminates the loop without disturbing the SBB borrow.
    lea     rdi, [rax + r11*8]
    lea     rdi, [rdi + r11*8]  ; rdi = &tp[2n]
    lea     rsi, [rsi + r11*8]  ; rsi = &mp[n]
    lea     rbx, [rbx + r11*8]  ; rbx = &rp[n]
    mov     rcx, r11
    neg     rcx                 ; rcx = -n
    clc
mr_sub_loop:
    mov     r10, [rdi + rcx*8]  ; tp[n+k]
    sbb     r10, [rsi + rcx*8]  ; - mp[k] - borrow
    mov     [rbx + rcx*8], r10  ; rp[k]
    inc     rcx                 ; INC leaves CF intact
    jnz     mr_sub_loop

    ; Top word: tp[2n] - borrow. No borrow => value >= mp, keep rp.
    mov     r10, [rdi]          ; tp[2n]
    sbb     r10, 0
    jnc     mr_done

    ; Borrow => value < mp: copy tp[n..2n-1] to rp unchanged.
    mov     rax, [rsp+8]        ; tp
    mov     rbx, [rsp+0]        ; rp
    lea     rdi, [rax + r11*8]  ; &tp[n]
    xor     r9d, r9d
mr_copy:
    mov     r10, [rdi + r9*8]
    mov     [rbx + r9*8], r10
    inc     r9
    cmp     r9, r11
    jb      mr_copy

mr_done:
    add     rsp, 48
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mont_redc_mulx ENDP


; =====================================================================
; void mpn_mont_mul_8(uint64_t* rp, const uint64_t* ap,
;                      const uint64_t* bp, const uint64_t* mp,
;                      uint64_t m_inv)
;
; n = 8 (512-bit) SOS Montgomery multiplication with register-resident
; accumulators: rp = ap * bp * R^{-1} mod mp, R = 2^512.
; No scratch argument: the 17-word product buffer P lives on the stack.
;
; In: rcx = rp, rdx = ap, r8 = bp, r9 = mp, [rsp+40] = m_inv (at entry)
;
; Phase 1: full 8x8 product into P[0..15] (accumulator r8-r15, shift trick)
; Phase 2: 8 REDC steps (REDC_ITER) using the same register-resident addmul
; Phase 3: one conditional subtraction of mp, taking P[16] into account
;
; REDC_ITER reads m_inv from [rsp+168] and expects rsi = mp, rdi = P, so
; this frame layout must stay in sync with that macro.
; PROC FRAME unwind directives are required by the Windows x64 ABI for
; functions that save nonvolatile registers.
; =====================================================================
mpn_mont_mul_8 PROC FRAME
    push    rbx
    .pushreg rbx
    push    rsi
    .pushreg rsi
    push    rdi
    .pushreg rdi
    push    r12
    .pushreg r12
    push    r13
    .pushreg r13
    push    r14
    .pushreg r14
    push    r15
    .pushreg r15
    push    rbp
    .pushreg rbp
    sub     rsp, 184            ; P[0..16] 136 B + locals 40 B + pad 8 B
    .allocstack 184
    .endprolog

    ; Stack layout: 8 pushes (64 B) + 184 B = 248 B up to the return
    ; address (the pad leaves rsp 16-byte aligned here). The 5th arg
    ; follows the 32 B shadow area:
    ;   [rsp+288] = 5th arg (m_inv)
    ;
    ; Locals:
    ;   [rsp+0..135] = P[0..16] product buffer (17 qwords)
    ;   [rsp+136]    = rp
    ;   [rsp+144]    = ap
    ;   [rsp+152]    = bp
    ;   [rsp+160]    = mp
    ;   [rsp+168]    = m_inv (read by REDC_ITER)

    mov     [rsp+136], rcx
    mov     [rsp+144], rdx
    mov     [rsp+152], r8
    mov     [rsp+160], r9
    mov     rax, [rsp+288]
    mov     [rsp+168], rax

    mov     rdi, rsp            ; rdi = P
    mov     rsi, rdx            ; rsi = ap
    mov     rcx, r8             ; rcx = bp

    ; ===== Phase 1: 8x8 full product (SOS) =====
    ; Accumulator r8..r15 starts at zero. Each row adds ap * b[i]; the
    ; word shifted out after limb 0 is final and is stored to P[i].
    xor     r8d, r8d
    xor     r9d, r9d
    xor     r10d, r10d
    xor     r11d, r11d
    xor     r12d, r12d
    xor     r13d, r13d
    xor     r14d, r14d
    xor     r15d, r15d

    ; Row 0: b[0]
    mov     rdx, [rcx]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi], r8     ; P[0] (MOV keeps CF/OF for REST)
    ADDMUL_REST_8

    ; Row 1: b[1]
    mov     rdx, [rcx+8]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi+8], r8   ; P[1]
    ADDMUL_REST_8

    ; Row 2: b[2]
    mov     rdx, [rcx+16]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi+16], r8  ; P[2]
    ADDMUL_REST_8

    ; Row 3: b[3]
    mov     rdx, [rcx+24]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi+24], r8  ; P[3]
    ADDMUL_REST_8

    ; Row 4: b[4]
    mov     rdx, [rcx+32]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi+32], r8  ; P[4]
    ADDMUL_REST_8

    ; Row 5: b[5]
    mov     rdx, [rcx+40]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi+40], r8  ; P[5]
    ADDMUL_REST_8

    ; Row 6: b[6]
    mov     rdx, [rcx+48]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi+48], r8  ; P[6]
    ADDMUL_REST_8

    ; Row 7: b[7]
    mov     rdx, [rcx+56]
    ADDMUL_FIRST_8
    mov     QWORD PTR [rdi+56], r8  ; P[7]
    ADDMUL_REST_8

    ; P[8..15] <- accumulator, P[16] = 0 (receives REDC carries later)
    mov     [rdi+64], r8
    mov     [rdi+72], r9
    mov     [rdi+80], r10
    mov     [rdi+88], r11
    mov     [rdi+96], r12
    mov     [rdi+104], r13
    mov     [rdi+112], r14
    mov     [rdi+120], r15
    mov     QWORD PTR [rdi+128], 0

    ; ===== Phase 2: REDC (8 steps) =====
    ; Load the low window P[0..7]; each REDC_ITER pulls in the next P word.
    mov     r8, [rdi]
    mov     r9, [rdi+8]
    mov     r10, [rdi+16]
    mov     r11, [rdi+24]
    mov     r12, [rdi+32]
    mov     r13, [rdi+40]
    mov     r14, [rdi+48]
    mov     r15, [rdi+56]

    mov     rsi, [rsp+160]      ; mp

    REDC_ITER 0
    REDC_ITER 1
    REDC_ITER 2
    REDC_ITER 3
    REDC_ITER 4
    REDC_ITER 5
    REDC_ITER 6
    REDC_ITER 7

    ; ===== Phase 3: conditional subtraction =====
    ; r8-r15 = result[0..7], P[16] = result[8]
    mov     rbx, [rsp+136]      ; rp
    mov     rsi, [rsp+160]      ; mp

    mov     rax, r8
    sub     rax, [rsi]
    mov     [rbx], rax

    mov     rax, r9
    sbb     rax, [rsi+8]
    mov     [rbx+8], rax

    mov     rax, r10
    sbb     rax, [rsi+16]
    mov     [rbx+16], rax

    mov     rax, r11
    sbb     rax, [rsi+24]
    mov     [rbx+24], rax

    mov     rax, r12
    sbb     rax, [rsi+32]
    mov     [rbx+32], rax

    mov     rax, r13
    sbb     rax, [rsi+40]
    mov     [rbx+40], rax

    mov     rax, r14
    sbb     rax, [rsi+48]
    mov     [rbx+48], rax

    mov     rax, r15
    sbb     rax, [rsi+56]
    mov     [rbx+56], rax

    ; Top word P[16] - borrow. No borrow => result >= mp, keep difference.
    mov     rax, [rdi+128]
    sbb     rax, 0
    jnc     m8_done

    ; Borrow => result < mp: overwrite rp with the unreduced result.
    mov     [rbx], r8
    mov     [rbx+8], r9
    mov     [rbx+16], r10
    mov     [rbx+24], r11
    mov     [rbx+32], r12
    mov     [rbx+40], r13
    mov     [rbx+48], r14
    mov     [rbx+56], r15

m8_done:
    add     rsp, 184
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mont_mul_8 ENDP


; =====================================================================
; void mpn_mont_sqr_8(uint64_t* rp, const uint64_t* ap,
;                      const uint64_t* mp, uint64_t m_inv)
;
; n = 8 (512-bit) Montgomery squaring: rp = ap^2 * R^{-1} mod mp,
; R = 2^512. Uses symmetry: 28 off-diagonal + 8 diagonal = 36 MULX for
; the square, versus 64 for the general 8x8 product.
;
; In: rcx = rp, rdx = ap, r8 = mp, r9 = m_inv
;
; Phase 1a: off-diagonal upper triangle a[i]*a[j] (i < j), ADCX/ADOX
; Phase 1b: double P[1..15], carry into P[16]
; Phase 1c: add the diagonal terms a[i]^2
; Phase 2:  8 REDC steps (REDC_ITER, shared with mpn_mont_mul_8)
; Phase 3:  one conditional subtraction of mp, taking P[16] into account
;
; REDC_ITER reads m_inv from [rsp+168] and expects rsi = mp, rdi = P, so
; this frame layout must stay in sync with that macro.
; PROC FRAME unwind directives are required by the Windows x64 ABI for
; functions that save nonvolatile registers.
; =====================================================================
mpn_mont_sqr_8 PROC FRAME
    push    rbx
    .pushreg rbx
    push    rsi
    .pushreg rsi
    push    rdi
    .pushreg rdi
    push    r12
    .pushreg r12
    push    r13
    .pushreg r13
    push    r14
    .pushreg r14
    push    r15
    .pushreg r15
    push    rbp
    .pushreg rbp
    sub     rsp, 184            ; P[0..16] 136 B + locals 48 B
    .allocstack 184
    .endprolog

    ; Frame layout (compatible with mpn_mont_mul_8 and REDC_ITER):
    ;   [rsp+0..135] = P[0..16] product buffer (17 qwords)
    ;   [rsp+136]    = rp
    ;   [rsp+160]    = mp
    ;   [rsp+168]    = m_inv (read by REDC_ITER)

    mov     [rsp+136], rcx
    mov     rsi, rdx            ; rsi = ap
    mov     [rsp+160], r8       ; mp
    mov     [rsp+168], r9       ; m_inv

    mov     rdi, rsp            ; rdi = P

    ; --- Zero P[0..16] ---
    xor     eax, eax
    mov     [rdi], rax
    mov     [rdi+8], rax
    mov     [rdi+16], rax
    mov     [rdi+24], rax
    mov     [rdi+32], rax
    mov     [rdi+40], rax
    mov     [rdi+48], rax
    mov     [rdi+56], rax
    mov     [rdi+64], rax
    mov     [rdi+72], rax
    mov     [rdi+80], rax
    mov     [rdi+88], rax
    mov     [rdi+96], rax
    mov     [rdi+104], rax
    mov     [rdi+112], rax
    mov     [rdi+120], rax
    mov     [rdi+128], rax

    ; ===== Phase 1a: off-diagonal upper triangle (28 MULX) =====
    ; Each row ends by folding leftover CF/OF into the next word up.
    ; That word has not been written yet (still zero), so the final ADD
    ; cannot carry out.

    ; --- Row 0: a[0]*a[1..7] -> P[1..8] (P is zero, so plain stores) ---
    mov     rdx, [rsi]
    mulx    rbx, rcx, [rsi+8]
    mov     [rdi+8], rcx
    mulx    rbp, rcx, [rsi+16]
    add     rcx, rbx
    mov     [rdi+16], rcx
    mulx    rbx, rcx, [rsi+24]
    adc     rcx, rbp
    mov     [rdi+24], rcx
    mulx    rbp, rcx, [rsi+32]
    adc     rcx, rbx
    mov     [rdi+32], rcx
    mulx    rbx, rcx, [rsi+40]
    adc     rcx, rbp
    mov     [rdi+40], rcx
    mulx    rbp, rcx, [rsi+48]
    adc     rcx, rbx
    mov     [rdi+48], rcx
    mulx    rbx, rcx, [rsi+56]
    adc     rcx, rbp
    mov     [rdi+56], rcx
    adc     rbx, 0
    mov     [rdi+64], rbx

    ; --- Row 1: a[1]*a[2..7] -> P[3..9] (ADCX: high words, ADOX: P) ---
    mov     rdx, [rsi+8]
    xor     eax, eax            ; rax = 0, CF = OF = 0
    mulx    rbx, rcx, [rsi+16]
    adox    rcx, [rdi+24]
    mov     [rdi+24], rcx
    mulx    rbp, rcx, [rsi+24]
    adcx    rcx, rbx
    adox    rcx, [rdi+32]
    mov     [rdi+32], rcx
    mulx    rbx, rcx, [rsi+32]
    adcx    rcx, rbp
    adox    rcx, [rdi+40]
    mov     [rdi+40], rcx
    mulx    rbp, rcx, [rsi+40]
    adcx    rcx, rbx
    adox    rcx, [rdi+48]
    mov     [rdi+48], rcx
    mulx    rbx, rcx, [rsi+48]
    adcx    rcx, rbp
    adox    rcx, [rdi+56]
    mov     [rdi+56], rcx
    mulx    rbp, rcx, [rsi+56]
    adcx    rcx, rbx
    adox    rcx, [rdi+64]
    mov     [rdi+64], rcx
    ; last high word -> P[9]
    mov     rcx, rax
    adcx    rcx, rbp
    adox    rcx, [rdi+72]
    mov     [rdi+72], rcx
    ; leftover CF + OF -> P[10]
    mov     rcx, rax
    adcx    rcx, rcx            ; rcx = CF
    adox    rax, rax            ; rax = OF
    add     rcx, rax
    add     [rdi+80], rcx

    ; --- Row 2: a[2]*a[3..7] -> P[5..10], leftover -> P[11] ---
    mov     rdx, [rsi+16]
    xor     eax, eax
    mulx    rbx, rcx, [rsi+24]
    adox    rcx, [rdi+40]
    mov     [rdi+40], rcx
    mulx    rbp, rcx, [rsi+32]
    adcx    rcx, rbx
    adox    rcx, [rdi+48]
    mov     [rdi+48], rcx
    mulx    rbx, rcx, [rsi+40]
    adcx    rcx, rbp
    adox    rcx, [rdi+56]
    mov     [rdi+56], rcx
    mulx    rbp, rcx, [rsi+48]
    adcx    rcx, rbx
    adox    rcx, [rdi+64]
    mov     [rdi+64], rcx
    mulx    rbx, rcx, [rsi+56]
    adcx    rcx, rbp
    adox    rcx, [rdi+72]
    mov     [rdi+72], rcx
    mov     rcx, rax
    adcx    rcx, rbx
    adox    rcx, [rdi+80]
    mov     [rdi+80], rcx
    mov     rcx, rax
    adcx    rcx, rcx
    adox    rax, rax
    add     rcx, rax
    add     [rdi+88], rcx

    ; --- Row 3: a[3]*a[4..7] -> P[7..11], leftover -> P[12] ---
    mov     rdx, [rsi+24]
    xor     eax, eax
    mulx    rbx, rcx, [rsi+32]
    adox    rcx, [rdi+56]
    mov     [rdi+56], rcx
    mulx    rbp, rcx, [rsi+40]
    adcx    rcx, rbx
    adox    rcx, [rdi+64]
    mov     [rdi+64], rcx
    mulx    rbx, rcx, [rsi+48]
    adcx    rcx, rbp
    adox    rcx, [rdi+72]
    mov     [rdi+72], rcx
    mulx    rbp, rcx, [rsi+56]
    adcx    rcx, rbx
    adox    rcx, [rdi+80]
    mov     [rdi+80], rcx
    mov     rcx, rax
    adcx    rcx, rbp
    adox    rcx, [rdi+88]
    mov     [rdi+88], rcx
    mov     rcx, rax
    adcx    rcx, rcx
    adox    rax, rax
    add     rcx, rax
    add     [rdi+96], rcx

    ; --- Row 4: a[4]*a[5..7] -> P[9..12], leftover -> P[13] ---
    mov     rdx, [rsi+32]
    xor     eax, eax
    mulx    rbx, rcx, [rsi+40]
    adox    rcx, [rdi+72]
    mov     [rdi+72], rcx
    mulx    rbp, rcx, [rsi+48]
    adcx    rcx, rbx
    adox    rcx, [rdi+80]
    mov     [rdi+80], rcx
    mulx    rbx, rcx, [rsi+56]
    adcx    rcx, rbp
    adox    rcx, [rdi+88]
    mov     [rdi+88], rcx
    mov     rcx, rax
    adcx    rcx, rbx
    adox    rcx, [rdi+96]
    mov     [rdi+96], rcx
    mov     rcx, rax
    adcx    rcx, rcx
    adox    rax, rax
    add     rcx, rax
    add     [rdi+104], rcx

    ; --- Row 5: a[5]*a[6..7] -> P[11..13], leftover -> P[14] ---
    mov     rdx, [rsi+40]
    xor     eax, eax
    mulx    rbx, rcx, [rsi+48]
    adox    rcx, [rdi+88]
    mov     [rdi+88], rcx
    mulx    rbp, rcx, [rsi+56]
    adcx    rcx, rbx
    adox    rcx, [rdi+96]
    mov     [rdi+96], rcx
    mov     rcx, rax
    adcx    rcx, rbp
    adox    rcx, [rdi+104]
    mov     [rdi+104], rcx
    mov     rcx, rax
    adcx    rcx, rcx
    adox    rax, rax
    add     rcx, rax
    add     [rdi+112], rcx

    ; --- Row 6: a[6]*a[7] -> P[13..14], carry -> P[15] (single product) ---
    mov     rdx, [rsi+48]
    mulx    rbx, rcx, [rsi+56]
    add     [rdi+104], rcx
    adc     [rdi+112], rbx
    adc     QWORD PTR [rdi+120], 0

    ; ===== Phase 1b: double P[1..15] =====
    ; P[0] is still 0, so the chain starts at P[1] with a plain ADD.
    mov     rax, [rdi+8]
    add     rax, rax
    mov     [rdi+8], rax
    mov     rax, [rdi+16]
    adc     rax, rax
    mov     [rdi+16], rax
    mov     rax, [rdi+24]
    adc     rax, rax
    mov     [rdi+24], rax
    mov     rax, [rdi+32]
    adc     rax, rax
    mov     [rdi+32], rax
    mov     rax, [rdi+40]
    adc     rax, rax
    mov     [rdi+40], rax
    mov     rax, [rdi+48]
    adc     rax, rax
    mov     [rdi+48], rax
    mov     rax, [rdi+56]
    adc     rax, rax
    mov     [rdi+56], rax
    mov     rax, [rdi+64]
    adc     rax, rax
    mov     [rdi+64], rax
    mov     rax, [rdi+72]
    adc     rax, rax
    mov     [rdi+72], rax
    mov     rax, [rdi+80]
    adc     rax, rax
    mov     [rdi+80], rax
    mov     rax, [rdi+88]
    adc     rax, rax
    mov     [rdi+88], rax
    mov     rax, [rdi+96]
    adc     rax, rax
    mov     [rdi+96], rax
    mov     rax, [rdi+104]
    adc     rax, rax
    mov     [rdi+104], rax
    mov     rax, [rdi+112]
    adc     rax, rax
    mov     [rdi+112], rax
    mov     rax, [rdi+120]
    adc     rax, rax
    mov     [rdi+120], rax
    ; carry -> P[16]. MOV, not XOR: CF must survive until the ADC.
    mov     eax, 0
    adc     rax, 0
    mov     [rdi+128], rax

    ; ===== Phase 1c: add diagonal terms a[i]^2 (one ADD/ADC chain) =====
    ; MULX and MOV leave flags untouched, so CF flows across all words.
    mov     rdx, [rsi]
    mulx    rbx, rax, rdx
    add     [rdi], rax
    adc     [rdi+8], rbx

    mov     rdx, [rsi+8]
    mulx    rbx, rax, rdx
    adc     [rdi+16], rax
    adc     [rdi+24], rbx

    mov     rdx, [rsi+16]
    mulx    rbx, rax, rdx
    adc     [rdi+32], rax
    adc     [rdi+40], rbx

    mov     rdx, [rsi+24]
    mulx    rbx, rax, rdx
    adc     [rdi+48], rax
    adc     [rdi+56], rbx

    mov     rdx, [rsi+32]
    mulx    rbx, rax, rdx
    adc     [rdi+64], rax
    adc     [rdi+72], rbx

    mov     rdx, [rsi+40]
    mulx    rbx, rax, rdx
    adc     [rdi+80], rax
    adc     [rdi+88], rbx

    mov     rdx, [rsi+48]
    mulx    rbx, rax, rdx
    adc     [rdi+96], rax
    adc     [rdi+104], rbx

    mov     rdx, [rsi+56]
    mulx    rbx, rax, rdx
    adc     [rdi+112], rax
    adc     [rdi+120], rbx
    adc     QWORD PTR [rdi+128], 0

    ; ===== Phase 2: REDC (8 steps) =====
    ; Load the low window P[0..7]; each REDC_ITER pulls in the next P word.
    mov     r8, [rdi]
    mov     r9, [rdi+8]
    mov     r10, [rdi+16]
    mov     r11, [rdi+24]
    mov     r12, [rdi+32]
    mov     r13, [rdi+40]
    mov     r14, [rdi+48]
    mov     r15, [rdi+56]

    mov     rsi, [rsp+160]      ; mp

    REDC_ITER 0
    REDC_ITER 1
    REDC_ITER 2
    REDC_ITER 3
    REDC_ITER 4
    REDC_ITER 5
    REDC_ITER 6
    REDC_ITER 7

    ; ===== Phase 3: conditional subtraction =====
    ; r8-r15 = result[0..7], P[16] = result[8]
    mov     rbx, [rsp+136]      ; rp
    mov     rsi, [rsp+160]      ; mp

    mov     rax, r8
    sub     rax, [rsi]
    mov     [rbx], rax

    mov     rax, r9
    sbb     rax, [rsi+8]
    mov     [rbx+8], rax

    mov     rax, r10
    sbb     rax, [rsi+16]
    mov     [rbx+16], rax

    mov     rax, r11
    sbb     rax, [rsi+24]
    mov     [rbx+24], rax

    mov     rax, r12
    sbb     rax, [rsi+32]
    mov     [rbx+32], rax

    mov     rax, r13
    sbb     rax, [rsi+40]
    mov     [rbx+40], rax

    mov     rax, r14
    sbb     rax, [rsi+48]
    mov     [rbx+48], rax

    mov     rax, r15
    sbb     rax, [rsi+56]
    mov     [rbx+56], rax

    ; Top word P[16] - borrow. No borrow => result >= mp, keep difference.
    mov     rax, [rdi+128]
    sbb     rax, 0
    jnc     s8_done

    ; Borrow => result < mp: overwrite rp with the unreduced result.
    mov     [rbx], r8
    mov     [rbx+8], r9
    mov     [rbx+16], r10
    mov     [rbx+24], r11
    mov     [rbx+32], r12
    mov     [rbx+40], r13
    mov     [rbx+48], r14
    mov     [rbx+56], r15

s8_done:
    add     rsp, 184
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mont_sqr_8 ENDP


; =====================================================================
; void mpn_mont_redc_16(uint64_t* rp, uint64_t* tp,
;                        const uint64_t* mp, uint64_t m_inv)
;
; n=16 特化 Montgomery REDC: rp = tp * R^{-1} mod mp
; tp[0..32] は 33 ワード (2*16+1), 破壊される
; 1024-bit モジュラスに最適化
;
; rcx = rp, rdx = tp, r8 = mp, r9 = m_inv
;
; 最適化:
;   - 16-word addmul を完全アンロール (ポインタ進行なし, 固定オフセット)
;   - rsi = mp レジスタ常駐 (16 反復で再ロードなし)
;   - r8 = 0 定数レジスタ常駐
;   - 1 反復あたりスタックリード: m_inv と tp_end の 2 回 (キャリー伝播時は &tp[32] も)
; =====================================================================
mpn_mont_redc_16 PROC
    push    rbx
    push    rsi
    push    rdi
    push    r12
    push    r13
    push    r14
    push    r15
    push    rbp
    sub     rsp, 32

    ; スタック配置:
    ;   8 push × 8 = 64, sub 32 → total = 96
    ;   [rsp+0]  = rp
    ;   [rsp+8]  = m_inv
    ;   [rsp+16] = tp_end (&tp[16], ループ終了判定)
    ;   [rsp+24] = &tp[32] (キャリー伝播上限)

    mov     [rsp+0], rcx
    mov     [rsp+8], r9
    mov     rsi, r8             ; rsi = mp (ループ中レジスタ常駐)
    mov     rbp, rdx            ; rbp = &tp[i] (ループで +8 進行)
    lea     rax, [rdx + 128]
    mov     [rsp+16], rax       ; tp_end = &tp[16]
    lea     rax, [rdx + 256]
    mov     [rsp+24], rax       ; &tp[32]

    xor     r8d, r8d            ; r8 = 0 (ループ全体で維持)

r16_outer:
    ; q = tp[i] * m_inv
    mov     rax, [rbp]
    imul    rax, QWORD PTR [rsp+8]
    mov     rdx, rax            ; q → MULX multiplier
    mov     rbx, rbp            ; rbx = &tp[i] (dest)
    xor     r9d, r9d            ; carry_in = 0

    ; === 16-word addmul (4 ブロック × 4 ワード, ADCX/ADOX) ===

    ; Block 0: mp[0..3] * q + tp[i..i+3]
    xor     eax, eax            ; CF=0, OF=0
    mulx    r10, r11, [rsi]
    mulx    r12, r13, [rsi+8]
    mulx    r14, r15, [rsi+16]
    mulx    rdi, rax, [rsi+24]
    adcx    r11, r9
    adox    r11, [rbx]
    mov     [rbx], r11
    adcx    r13, r10
    adox    r13, [rbx+8]
    mov     [rbx+8], r13
    adcx    r15, r12
    adox    r15, [rbx+16]
    mov     [rbx+16], r15
    adcx    rax, r14
    adox    rax, [rbx+24]
    mov     [rbx+24], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8

    ; Block 1: mp[4..7] * q + tp[i+4..i+7]
    xor     eax, eax
    mulx    r10, r11, [rsi+32]
    mulx    r12, r13, [rsi+40]
    mulx    r14, r15, [rsi+48]
    mulx    rdi, rax, [rsi+56]
    adcx    r11, r9
    adox    r11, [rbx+32]
    mov     [rbx+32], r11
    adcx    r13, r10
    adox    r13, [rbx+40]
    mov     [rbx+40], r13
    adcx    r15, r12
    adox    r15, [rbx+48]
    mov     [rbx+48], r15
    adcx    rax, r14
    adox    rax, [rbx+56]
    mov     [rbx+56], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8

    ; Block 2: mp[8..11] * q + tp[i+8..i+11]
    xor     eax, eax
    mulx    r10, r11, [rsi+64]
    mulx    r12, r13, [rsi+72]
    mulx    r14, r15, [rsi+80]
    mulx    rdi, rax, [rsi+88]
    adcx    r11, r9
    adox    r11, [rbx+64]
    mov     [rbx+64], r11
    adcx    r13, r10
    adox    r13, [rbx+72]
    mov     [rbx+72], r13
    adcx    r15, r12
    adox    r15, [rbx+80]
    mov     [rbx+80], r15
    adcx    rax, r14
    adox    rax, [rbx+88]
    mov     [rbx+88], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8

    ; Block 3: mp[12..15] * q + tp[i+12..i+15]
    xor     eax, eax
    mulx    r10, r11, [rsi+96]
    mulx    r12, r13, [rsi+104]
    mulx    r14, r15, [rsi+112]
    mulx    rdi, rax, [rsi+120]
    adcx    r11, r9
    adox    r11, [rbx+96]
    mov     [rbx+96], r11
    adcx    r13, r10
    adox    r13, [rbx+104]
    mov     [rbx+104], r13
    adcx    r15, r12
    adox    r15, [rbx+112]
    mov     [rbx+112], r15
    adcx    rax, r14
    adox    rax, [rbx+120]
    mov     [rbx+120], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8

    ; === キャリー伝播: tp[i+16] += r9 ===
    add     [rbp + 128], r9
    jnc     r16_nc
    lea     rax, [rbp + 136]    ; &tp[i+17]
    mov     rcx, [rsp+24]       ; &tp[32]
r16_carry:
    cmp     rax, rcx
    ja      r16_nc
    add     QWORD PTR [rax], 1
    lea     rax, [rax + 8]
    jc      r16_carry

r16_nc:
    lea     rbp, [rbp + 8]      ; tp[i] → tp[i+1]
    cmp     rbp, [rsp+16]       ; tp_end
    jb      r16_outer

    ; === 条件付き減算: tp[16..31] - mp → rp ===
    ; rbp = &tp[16] (ループ終了時), rsi = mp
    mov     rbx, [rsp+0]        ; rp

    mov     rax, [rbp]
    sub     rax, [rsi]
    mov     [rbx], rax

    mov     rax, [rbp+8]
    sbb     rax, [rsi+8]
    mov     [rbx+8], rax

    mov     rax, [rbp+16]
    sbb     rax, [rsi+16]
    mov     [rbx+16], rax

    mov     rax, [rbp+24]
    sbb     rax, [rsi+24]
    mov     [rbx+24], rax

    mov     rax, [rbp+32]
    sbb     rax, [rsi+32]
    mov     [rbx+32], rax

    mov     rax, [rbp+40]
    sbb     rax, [rsi+40]
    mov     [rbx+40], rax

    mov     rax, [rbp+48]
    sbb     rax, [rsi+48]
    mov     [rbx+48], rax

    mov     rax, [rbp+56]
    sbb     rax, [rsi+56]
    mov     [rbx+56], rax

    mov     rax, [rbp+64]
    sbb     rax, [rsi+64]
    mov     [rbx+64], rax

    mov     rax, [rbp+72]
    sbb     rax, [rsi+72]
    mov     [rbx+72], rax

    mov     rax, [rbp+80]
    sbb     rax, [rsi+80]
    mov     [rbx+80], rax

    mov     rax, [rbp+88]
    sbb     rax, [rsi+88]
    mov     [rbx+88], rax

    mov     rax, [rbp+96]
    sbb     rax, [rsi+96]
    mov     [rbx+96], rax

    mov     rax, [rbp+104]
    sbb     rax, [rsi+104]
    mov     [rbx+104], rax

    mov     rax, [rbp+112]
    sbb     rax, [rsi+112]
    mov     [rbx+112], rax

    mov     rax, [rbp+120]
    sbb     rax, [rsi+120]
    mov     [rbx+120], rax

    ; tp[32] - borrow
    mov     rax, [rbp+128]
    sbb     rax, 0
    jnc     r16_done

    ; borrow あり → tp[16..31] をそのまま rp にコピー
    mov     rax, [rbp]
    mov     [rbx], rax
    mov     rax, [rbp+8]
    mov     [rbx+8], rax
    mov     rax, [rbp+16]
    mov     [rbx+16], rax
    mov     rax, [rbp+24]
    mov     [rbx+24], rax
    mov     rax, [rbp+32]
    mov     [rbx+32], rax
    mov     rax, [rbp+40]
    mov     [rbx+40], rax
    mov     rax, [rbp+48]
    mov     [rbx+48], rax
    mov     rax, [rbp+56]
    mov     [rbx+56], rax
    mov     rax, [rbp+64]
    mov     [rbx+64], rax
    mov     rax, [rbp+72]
    mov     [rbx+72], rax
    mov     rax, [rbp+80]
    mov     [rbx+80], rax
    mov     rax, [rbp+88]
    mov     [rbx+88], rax
    mov     rax, [rbp+96]
    mov     [rbx+96], rax
    mov     rax, [rbp+104]
    mov     [rbx+104], rax
    mov     rax, [rbp+112]
    mov     [rbx+112], rax
    mov     rax, [rbp+120]
    mov     [rbx+120], rax

r16_done:
    add     rsp, 32
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mont_redc_16 ENDP


; =====================================================================
; void mpn_mont_redc_32(uint64_t* rp, uint64_t* tp,
;                        const uint64_t* mp, uint64_t m_inv)
;
; n=32 specialised Montgomery REDC: rp = tp * R^{-1} mod mp, R = 2^2048.
; tp[0..64] holds 65 words (2*32+1) and is overwritten.
; Tuned for 2048-bit moduli.
;
; In (Windows x64): rcx = rp, rdx = tp, r8 = mp, r9 = m_inv
; Out:              rp[0..31]
; Callee-saved registers used internally are pushed/popped.
;
; Implementation notes:
;   - The 32-word addmul is fully unrolled as 8 x ADDMUL_BLOCK_4. That
;     macro is defined elsewhere in this file. Here it is invoked with
;     rdx = q, rsi = mp, rbx = &tp[i], r9 = carry in/out and r8 = 0.
;     NOTE(review): this contract is inferred from this call site.
;   - rsi keeps mp and r8 keeps the constant 0 for the whole reduction.
;   - The final conditional subtraction is a loop. A fully unrolled
;     32-word version would only bloat the code.
; =====================================================================
mpn_mont_redc_32 PROC
    push    rbx
    push    rsi
    push    rdi
    push    r12
    push    r13
    push    r14
    push    r15
    push    rbp
    sub     rsp, 32

    ; Stack frame:
    ;   [rsp+0]  = rp
    ;   [rsp+8]  = m_inv
    ;   [rsp+16] = tp_end (&tp[32]), outer-loop bound
    ;   [rsp+24] = &tp[64], upper bound of carry propagation

    mov     [rsp+0], rcx
    mov     [rsp+8], r9
    mov     rsi, r8             ; rsi = mp (kept in a register)
    mov     rbp, rdx            ; rbp = &tp[i]
    lea     rax, [rdx + 256]
    mov     [rsp+16], rax       ; tp_end = &tp[32]
    lea     rax, [rdx + 512]
    mov     [rsp+24], rax       ; &tp[64]

    xor     r8d, r8d            ; r8 = 0

r32_outer:
    ; q = tp[i] * m_inv  (mod 2^64)
    mov     rax, [rbp]
    imul    rax, QWORD PTR [rsp+8]
    mov     rdx, rax            ; MULX multiplier
    mov     rbx, rbp            ; destination = &tp[i]
    xor     r9d, r9d            ; carry in = 0

    ; === tp[i..i+31] += q * mp[0..31]  (8 blocks x 4 words) ===
    ADDMUL_BLOCK_4 0
    ADDMUL_BLOCK_4 32
    ADDMUL_BLOCK_4 64
    ADDMUL_BLOCK_4 96
    ADDMUL_BLOCK_4 128
    ADDMUL_BLOCK_4 160
    ADDMUL_BLOCK_4 192
    ADDMUL_BLOCK_4 224

    ; === Carry out: tp[i+32] += r9, rippling up to tp[64] ===
    add     [rbp + 256], r9
    jnc     r32_nc
    lea     rax, [rbp + 264]    ; &tp[i+33]
    mov     rcx, [rsp+24]       ; &tp[64]
r32_carry:
    cmp     rax, rcx
    ja      r32_nc
    add     QWORD PTR [rax], 1
    lea     rax, [rax + 8]      ; lea keeps CF from the add
    jc      r32_carry

r32_nc:
    lea     rbp, [rbp + 8]      ; i += 1
    cmp     rbp, [rsp+16]
    jb      r32_outer

    ; === Conditional subtraction: rp = tp[32..63] - mp (loop) ===
    ; rbp = &tp[32], rsi = mp.
    ; The borrow must survive from one SBB to the next, so the loop
    ; bookkeeping may only use flag-neutral LEA and CF-preserving DEC.
    ; A CMP here would overwrite CF and corrupt every word after the
    ; first, as well as the final borrow tested against tp[64].
    mov     rbx, [rsp+0]        ; rp
    xor     r9d, r9d            ; word index = 0; also clears CF (no borrow-in)
    mov     ecx, 32             ; words remaining (mov keeps CF = 0)
r32_sub:
    mov     rax, [rbp + r9*8]
    sbb     rax, [rsi + r9*8]
    mov     [rbx + r9*8], rax
    lea     r9, [r9 + 1]        ; advance index, flags untouched
    dec     rcx                 ; sets ZF for the exit test, preserves CF
    jnz     r32_sub

    ; tp[64] - borrow: a borrow here means tp[32..64] < mp, so the
    ; difference just written to rp must be discarded.
    mov     rax, [rbp + 256]
    sbb     rax, 0
    jnc     r32_done

    ; borrow -> copy tp[32..63] to rp unchanged (no flag dependency here)
    xor     r9d, r9d
r32_copy:
    mov     rax, [rbp + r9*8]
    mov     [rbx + r9*8], rax
    inc     r9
    cmp     r9, 32
    jb      r32_copy

r32_done:
    add     rsp, 32
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mont_redc_32 ENDP


; =====================================================================
; void mpn_mont_sqr_16(uint64_t* rp, const uint64_t* ap,
;                       const uint64_t* mp, uint64_t m_inv)
;
; n=16 fused SOS Montgomery square: rp = ap^2 * R^{-1} mod mp, R = 2^1024.
; The square exploits symmetry: 120 cross products + 16 squares
; = 136 MULX. The reduction then costs 16 x 16 = 256 MULX.
;
; In (Windows x64): rcx = rp, rdx = ap, r8 = mp, r9 = m_inv
;
; Phase 1a: upper triangle a[i]*a[j], i<j (120 MULX, unrolled ADCX/ADOX)
; Phase 1b: double P[1..31]
; Phase 1c: add the diagonal squares a[i]^2 (16 MULX)
; Phase 2 : 16 REDC iterations, each an inline 16-word addmul
;           emitted as 4 blocks of 4 words with dual ADCX/ADOX chains
; Phase 3 : final conditional subtraction of mp
; =====================================================================
mpn_mont_sqr_16 PROC
    push    rbx
    push    rsi
    push    rdi
    push    r12
    push    r13
    push    r14
    push    r15
    push    rbp
    sub     rsp, 328

    ; Frame:
    ;   [rsp+0..263]  P[0..32]  33-qword product buffer
    ;   [rsp+264]     rp
    ;   [rsp+272]     ap
    ;   [rsp+280]     mp
    ;   [rsp+288]     m_inv
    mov     [rsp+288], r9
    mov     [rsp+280], r8
    mov     [rsp+272], rdx
    mov     [rsp+264], rcx

    ; Clear all 33 words of P
    xor     eax, eax
    lea     rdi, [rsp]
_s16z = 0
    REPEAT 33
        mov     QWORD PTR [rdi+_s16z], rax
_s16z = _s16z + 8
    ENDM

    ; ===== Phase 1a: cross products a[i]*a[j], i<j =====
    mov     rsi, [rsp+272]              ; rsi = ap

    ; Row 0: a[0]*a[1..15] -> P[1..16]. P is still zero here, so the
    ; row is stored directly through a single ADC chain. The MULX hi
    ; halves alternate between rbx and rbp.
    mov     rdx, [rsi]
    mulx    rbx, rcx, [rsi+8]
    mov     [rsp+8], rcx
    mulx    rbp, rcx, [rsi+16]
    add     rcx, rbx
    mov     [rsp+16], rcx
_s16k = 2
    WHILE _s16k LE 13
_s16_off = (_s16k + 1) * 8
        IF (_s16k AND 1)
            mulx    rbp, rcx, [rsi + _s16_off]
            adc     rcx, rbx
        ELSE
            mulx    rbx, rcx, [rsi + _s16_off]
            adc     rcx, rbp
        ENDIF
        mov     [rsp + _s16_off], rcx
_s16k = _s16k + 1
    ENDM
    mulx    rbx, rcx, [rsi + 120]
    adc     rcx, rbp
    mov     [rsp + 120], rcx
    adc     rbx, 0
    mov     [rsp + 128], rbx

    ; Rows 1..13: a[i]*a[i+1..15] accumulated into P[2i+1..i+16].
    ; CF (ADCX) carries the MULX hi halves; OF (ADOX) carries the add of
    ; the existing P words. Whatever CF+OF remains is added to P[i+17].
_s16i = 1
    WHILE _s16i LE 13
_s16n = 15 - _s16i
        mov     rdx, [rsi + _s16i*8]
        xor     eax, eax
_s16_src = (_s16i + 1) * 8
_s16_dst = (2 * _s16i + 1) * 8
        mulx    rbx, rcx, [rsi + _s16_src]
        adox    rcx, [rsp + _s16_dst]
        mov     [rsp + _s16_dst], rcx
_s16j = 1
        WHILE _s16j LT _s16n
_s16_src = (_s16i + 1 + _s16j) * 8
_s16_dst = (2 * _s16i + 1 + _s16j) * 8
            IF (_s16j AND 1)
                mulx    rbp, rcx, [rsi + _s16_src]
                adcx    rcx, rbx
            ELSE
                mulx    rbx, rcx, [rsi + _s16_src]
                adcx    rcx, rbp
            ENDIF
            adox    rcx, [rsp + _s16_dst]
            mov     [rsp + _s16_dst], rcx
_s16j = _s16j + 1
        ENDM
_s16_tail = (_s16i + 16) * 8
        mov     rcx, rax
        IF ((_s16n - 1) AND 1)
            adcx    rcx, rbp
        ELSE
            adcx    rcx, rbx
        ENDIF
        adox    rcx, [rsp + _s16_tail]
        mov     [rsp + _s16_tail], rcx
        mov     rcx, rax
        adcx    rcx, rcx
        adox    rax, rax
        add     rcx, rax
        add     [rsp + _s16_tail + 8], rcx
_s16i = _s16i + 1
    ENDM

    ; Row 14: a[14]*a[15] -> P[29..31] (one product, plain add/adc)
    mov     rdx, [rsi+112]
    mulx    rbx, rcx, [rsi+120]
    add     [rsp+232], rcx
    adc     [rsp+240], rbx
    adc     QWORD PTR [rsp+248], 0

    ; ===== Phase 1b: P[1..31] <<= 1 =====
    mov     rax, [rsp+8]
    add     rax, rax
    mov     [rsp+8], rax
_s16d = 16
    REPEAT 30
        mov     rax, [rsp+_s16d]
        adc     rax, rax
        mov     [rsp+_s16d], rax
_s16d = _s16d + 8
    ENDM
    adc     QWORD PTR [rsp+256], 0      ; shifted-out bit -> P[32]

    ; ===== Phase 1c: P[2i..2i+1] += a[i]^2, one ADC chain =====
    mov     rcx, [rsp+272]              ; ap
    mov     rdx, [rcx]
    mulx    rbx, rax, rdx
    add     [rsp], rax
    adc     [rsp+8], rbx
_s16g = 8
    REPEAT 15
        mov     rdx, [rcx + _s16g]
        mulx    rbx, rax, rdx
        adc     [rsp + 2*_s16g], rax
        adc     [rsp + 2*_s16g + 8], rbx
_s16g = _s16g + 8
    ENDM
    adc     QWORD PTR [rsp+256], 0      ; final carry -> P[32]

    ; ===== Phase 2: 16 REDC iterations =====
    mov     rsi, [rsp+280]              ; rsi = mp for the rest of the routine
    lea     rbp, [rsp]                  ; rbp = &P[i], i = 0
    xor     r8d, r8d                    ; r8 = 0 constant

s16_redc_outer:
    mov     rax, [rbp]
    imul    rax, QWORD PTR [rsp+288]    ; q = P[i] * m_inv
    mov     rdx, rax
    mov     rbx, rbp                    ; destination = &P[i]
    xor     r9d, r9d                    ; inter-block carry = 0

    ; P[i..i+15] += q * mp[0..15], four 4-word blocks.
    ; Inside a block CF (ADCX) chains the MULX hi halves and OF (ADOX)
    ; chains the P words. Both residual flags fold into r9, which feeds
    ; the next block.
_s16b = 0
    REPEAT 4
        xor     eax, eax
        mulx    r10, r11, [rsi + _s16b]
        mulx    r12, r13, [rsi + _s16b + 8]
        mulx    r14, r15, [rsi + _s16b + 16]
        mulx    rdi, rax, [rsi + _s16b + 24]
        adcx    r11, r9
        adox    r11, [rbx + _s16b]
        mov     [rbx + _s16b], r11
        adcx    r13, r10
        adox    r13, [rbx + _s16b + 8]
        mov     [rbx + _s16b + 8], r13
        adcx    r15, r12
        adox    r15, [rbx + _s16b + 16]
        mov     [rbx + _s16b + 16], r15
        adcx    rax, r14
        adox    rax, [rbx + _s16b + 24]
        mov     [rbx + _s16b + 24], rax
        mov     r9, rdi
        adcx    r9, r8
        adox    r9, r8
_s16b = _s16b + 32
    ENDM

    ; carry -> P[i+16], rippling up to P[32]
    add     [rbp + 128], r9
    jnc     s16_redc_nc
    lea     rax, [rbp + 136]
    lea     rcx, [rsp + 256]            ; &P[32]
s16_redc_carry:
    cmp     rax, rcx
    ja      s16_redc_nc
    add     QWORD PTR [rax], 1
    lea     rax, [rax + 8]
    jc      s16_redc_carry

s16_redc_nc:
    lea     rbp, [rbp + 8]
    lea     rax, [rsp + 128]            ; &P[16]
    cmp     rbp, rax
    jb      s16_redc_outer

    ; ===== Phase 3: rp = P[16..31] - mp unless that borrows =====
    ; rbp = &P[16], rsi = mp
    mov     rbx, [rsp+264]              ; rp

    mov     rax, [rbp]
    sub     rax, [rsi]
    mov     [rbx], rax
_s16s = 8
    REPEAT 15
        mov     rax, [rbp+_s16s]
        sbb     rax, [rsi+_s16s]
        mov     [rbx+_s16s], rax
_s16s = _s16s + 8
    ENDM

    ; P[32] - borrow: CF = 1 means P[16..32] < mp
    mov     rax, [rbp+128]
    sbb     rax, 0
    jnc     s16_sqr_done

    ; borrow -> copy P[16..31] to rp unchanged
_s16c = 0
    REPEAT 16
        mov     rax, [rbp+_s16c]
        mov     [rbx+_s16c], rax
_s16c = _s16c + 8
    ENDM

s16_sqr_done:
    add     rsp, 328
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mont_sqr_16 ENDP


; =====================================================================
; n=4 specialised macros (256-bit Montgomery)
;
; ADDMUL_FIRST_4 / ADDMUL_REST_4: register-resident 4-word addmul.
;   Accumulator window: r8, r9, r10, r11 (low word to high word).
;   rdx = multiplier, rsi = 4-word source vector.
;   rax: zeroed by ADDMUL_FIRST_4 (its xor), then used as a 0 constant.
;   rbx, rbp: MULX hi/lo temporaries.
;
;   Together the two macros form (r8,r9,r10,r11) + rdx*src[0..3] as a
;   5-word value. The top word cannot overflow, because
;   (2^256-1) + (2^64-1)(2^256-1) < 2^320.
;   - FIRST finishes the lowest word in r8. The caller may store it
;     before REST runs.
;   - REST shifts the window down one word, leaving words 1..4 of the
;     sum in r8..r11.
;   - FIRST's OF carry is consumed by REST. Anything placed between the
;     two macros must therefore leave the flags alone (e.g. a plain mov).
;   - CF (ADCX) chains the MULX hi halves. OF (ADOX) chains the lo halves
;     into the window.
; =====================================================================
ADDMUL_FIRST_4 MACRO
    xor     eax, eax            ; rax = 0, CF = OF = 0
    mulx    rbx, rbp, [rsi]     ; rbx:rbp = rdx * src[0]
    adox    r8, rbp             ; word 0 done; carry pending in OF
ENDM

ADDMUL_REST_4 MACRO
    ; Each new r_k is built from the old r_(k+1): the window slides down.
    adcx    r9, rbx
    mulx    rbp, r8, [rsi+8]
    adox    r8, r9              ; r8  = word 1
    adcx    r10, rbp
    mulx    rbx, r9, [rsi+16]
    adox    r9, r10             ; r9  = word 2
    adcx    r11, rbx
    mulx    rbp, r10, [rsi+24]
    adox    r10, r11            ; r10 = word 3
    mov     r11, rax
    adcx    r11, rbp
    adox    r11, rax            ; r11 = word 4 = hi(src[3]) + CF + OF
ENDM

; REDC_ITER_4 iter_idx: one Montgomery reduction step for n=4.
;   q = r8 * m_inv, where m_inv is read from [rsp+88]. Callers must keep
;   it at that frame slot.
;   Then (r8..r11) += q * mp, with mp in rsi. The lowest word is dropped
;   by the window shift. (It becomes zero provided m_inv = -mp[0]^-1
;   mod 2^64; NOTE(review): that convention is assumed, not checked here.)
;   The new top word r11 absorbs P[4+iter_idx] from the buffer at rdi.
;   Any carry then ripples through P[5+iter_idx..8] in memory.
;   iter_idx must be an assemble-time constant in 0..3.
;   Clobbers: rax, rbx, rbp, rdx, r8-r11, flags, P[4+iter_idx+1..8].
REDC_ITER_4 MACRO iter_idx
LOCAL ri_nc
    mov     rax, r8
    imul    rax, QWORD PTR [rsp+88]         ; q = low word * m_inv (mod 2^64)
    mov     rdx, rax
    ADDMUL_FIRST_4
    ADDMUL_REST_4
    add     r11, QWORD PTR [rdi + (4 + iter_idx)*8]
    jnc     ri_nc
    ; unrolled carry ripple into the higher buffer words, up to P[8]
_prop_idx = 5 + iter_idx
    WHILE _prop_idx LE 8
        add     QWORD PTR [rdi + _prop_idx*8], 1
        jnc     ri_nc
_prop_idx = _prop_idx + 1
    ENDM
ri_nc:
ENDM


; =====================================================================
; void mpn_mont_mul_4(uint64_t* rp, const uint64_t* ap,
;                      const uint64_t* bp, const uint64_t* mp,
;                      uint64_t m_inv)
;
; Register-resident n=4 SOS Montgomery multiply:
;   rp = ap * bp * R^{-1} mod mp,  R = 2^256
;
; In (Windows x64): rcx = rp, rdx = ap, r8 = bp, r9 = mp,
;                   m_inv = 5th argument, passed on the stack
;
; Phase 1: full 4x4 product into P[0..7] (16 MULX)
; Phase 2: 4 REDC iterations via REDC_ITER_4 (16 MULX)
; Phase 3: final conditional subtraction of mp (CMOV select)
; =====================================================================
mpn_mont_mul_4 PROC
    push    rbx
    push    rbp
    push    rsi
    push    rdi
    sub     rsp, 104

    ; Frame: 4 pushes (32) + 104 = 136 bytes below the return address.
    ; The 5th argument sits above the return address (8) and the 32-byte
    ; shadow space: [rsp + 136 + 40] = [rsp+176].
    ;   [rsp+0..71]   P[0..8]  9-qword product buffer
    ;   [rsp+72]      rp
    ;   [rsp+80]      mp
    ;   [rsp+88]      m_inv   (REDC_ITER_4 reads it from this slot)
    mov     rax, [rsp+176]
    mov     [rsp+88], rax       ; m_inv
    mov     [rsp+80], r9        ; mp
    mov     [rsp+72], rcx       ; rp (saved before rcx is reused)
    mov     rcx, r8             ; rcx = bp
    mov     rsi, rdx            ; rsi = ap, source for ADDMUL_*_4
    mov     rdi, rsp            ; rdi = &P[0]

    ; ----- Phase 1: P = ap * bp, one row per word of bp -----
    ; r8..r11 slides along P. After each row its low word is final and
    ; is spilled to P[row].
    xor     r8d, r8d
    xor     r9d, r9d
    xor     r10d, r10d
    xor     r11d, r11d
_m4row = 0
    REPEAT 4
        mov     rdx, [rcx + _m4row]             ; b[row]
        ADDMUL_FIRST_4
        mov     QWORD PTR [rdi + _m4row], r8    ; P[row]; mov keeps OF for REST
        ADDMUL_REST_4
_m4row = _m4row + 8
    ENDM

    ; High half of the product. P[8] is the REDC overflow word.
    mov     [rdi+32], r8
    mov     [rdi+40], r9
    mov     [rdi+48], r10
    mov     [rdi+56], r11
    mov     QWORD PTR [rdi+64], 0

    ; ----- Phase 2: 4 reduction steps, window starts at P[0..3] -----
    mov     rsi, [rsp+80]       ; rsi = mp for the rest of the routine
    mov     r8, [rdi]
    mov     r9, [rdi+8]
    mov     r10, [rdi+16]
    mov     r11, [rdi+24]

    REDC_ITER_4 0
    REDC_ITER_4 1
    REDC_ITER_4 2
    REDC_ITER_4 3

    ; ----- Phase 3: rp = (P[4..8] < mp) ? P[4..7] : P[4..7] - mp -----
    ; r8..r11 = P[4..7]. The overflow word P[8] is still in the buffer.
    mov     rbx, [rsp+72]       ; rp
    mov     rax, r8
    sub     rax, [rsi]
    mov     rcx, r9
    sbb     rcx, [rsi+8]
    mov     rdx, r10
    sbb     rdx, [rsi+16]
    mov     rbp, r11
    sbb     rbp, [rsi+24]
    mov     rsi, [rdi+64]       ; P[8]; mov leaves the borrow intact
    sbb     rsi, 0              ; CF = 1  <=>  P[4..8] < mp
    cmovc   rax, r8             ; borrow: keep the unsubtracted value
    cmovc   rcx, r9
    cmovc   rdx, r10
    cmovc   rbp, r11
    mov     [rbx], rax
    mov     [rbx+8], rcx
    mov     [rbx+16], rdx
    mov     [rbx+24], rbp

    add     rsp, 104
    pop     rdi
    pop     rsi
    pop     rbp
    pop     rbx
    ret
mpn_mont_mul_4 ENDP


; =====================================================================
; void mpn_mont_sqr_4(uint64_t* rp, const uint64_t* ap,
;                      const uint64_t* mp, uint64_t m_inv)
;
; Register-resident n=4 fused square + REDC:
;   rp = ap^2 * R^{-1} mod mp,  R = 2^256
; Symmetry cuts the square to 6 cross products + 4 squares = 10 MULX,
; instead of 16 for a general 4x4 multiply.
;
; In (Windows x64): rcx = rp, rdx = ap, r8 = mp, r9 = m_inv
;
; Phase 1a: cross products a[i]*a[j], i<j (6 MULX)
; Phase 1b: double P[1..7]
; Phase 1c: add the diagonal squares a[i]^2 (4 MULX)
; Phase 2 : 4 REDC iterations (REDC_ITER_4)
; Phase 3 : final conditional subtraction of mp (CMOV select)
; =====================================================================
mpn_mont_sqr_4 PROC
    push    rbx
    push    rbp
    push    rsi
    push    rdi
    sub     rsp, 104

    ; Frame (same layout as mpn_mont_mul_4; REDC_ITER_4 reads [rsp+88]):
    ;   [rsp+0..71]   P[0..8]  9-qword product buffer
    ;   [rsp+72]      rp
    ;   [rsp+80]      mp
    ;   [rsp+88]      m_inv
    mov     [rsp+88], r9        ; m_inv
    mov     [rsp+80], r8        ; mp
    mov     [rsp+72], rcx       ; rp
    mov     rsi, rdx            ; rsi = ap
    mov     rdi, rsp            ; rdi = &P[0]

    ; Clear P[0..8]
    xor     eax, eax
_s4z = 0
    REPEAT 9
        mov     [rdi + _s4z], rax
_s4z = _s4z + 8
    ENDM

    ; ----- Phase 1a: cross products -----
    ; Row 0: a[0]*a[1..3] -> P[1..4], stored through one ADC chain
    mov     rdx, [rsi]
    mulx    rbx, rcx, [rsi+8]
    mov     [rdi+8], rcx
    mulx    rbp, rcx, [rsi+16]
    add     rcx, rbx
    mov     [rdi+16], rcx
    mulx    rbx, rcx, [rsi+24]
    adc     rcx, rbp
    mov     [rdi+24], rcx
    adc     rbx, 0
    mov     [rdi+32], rbx

    ; Row 1: a[1]*a[2..3] accumulated into P[3..5].
    ; CF (ADCX) chains the MULX hi halves; OF (ADOX) chains the P words.
    mov     rdx, [rsi+8]
    xor     eax, eax            ; rax = 0, CF = OF = 0
    mulx    rbx, rcx, [rsi+16]
    adox    rcx, [rdi+24]
    mov     [rdi+24], rcx
    mulx    rbp, rcx, [rsi+24]
    adcx    rcx, rbx
    adox    rcx, [rdi+32]
    mov     [rdi+32], rcx
    mov     rcx, rax
    adcx    rcx, rbp
    adox    rcx, [rdi+40]
    mov     [rdi+40], rcx
    ; leftover CF + OF -> P[6]
    mov     rcx, rax
    adcx    rcx, rcx            ; rcx = CF
    adox    rax, rax            ; rax = OF
    add     rcx, rax
    add     [rdi+48], rcx

    ; Row 2: a[2]*a[3] -> P[5..7] (single product, plain add/adc)
    mov     rdx, [rsi+16]
    mulx    rbx, rcx, [rsi+24]
    add     [rdi+40], rcx
    adc     [rdi+48], rbx
    adc     QWORD PTR [rdi+56], 0

    ; ----- Phase 1b: P[1..7] <<= 1 -----
    mov     rax, [rdi+8]
    add     rax, rax
    mov     [rdi+8], rax
_s4d = 16
    REPEAT 6
        mov     rax, [rdi + _s4d]
        adc     rax, rax
        mov     [rdi + _s4d], rax
_s4d = _s4d + 8
    ENDM
    adc     QWORD PTR [rdi+64], 0   ; shifted-out bit -> P[8] (still 0 here)

    ; ----- Phase 1c: P[2i..2i+1] += a[i]^2, one ADC chain -----
    mov     rdx, [rsi]
    mulx    rbx, rax, rdx
    add     [rdi], rax
    adc     [rdi+8], rbx
_s4q = 8
    REPEAT 3
        mov     rdx, [rsi + _s4q]
        mulx    rbx, rax, rdx
        adc     [rdi + 2*_s4q], rax
        adc     [rdi + 2*_s4q + 8], rbx
_s4q = _s4q + 8
    ENDM
    adc     QWORD PTR [rdi+64], 0

    ; ----- Phase 2: 4 reduction steps, window starts at P[0..3] -----
    mov     rsi, [rsp+80]       ; rsi = mp for the rest of the routine
    mov     r8, [rdi]
    mov     r9, [rdi+8]
    mov     r10, [rdi+16]
    mov     r11, [rdi+24]

    REDC_ITER_4 0
    REDC_ITER_4 1
    REDC_ITER_4 2
    REDC_ITER_4 3

    ; ----- Phase 3: rp = (P[4..8] < mp) ? P[4..7] : P[4..7] - mp -----
    mov     rbx, [rsp+72]       ; rp
    mov     rax, r8
    sub     rax, [rsi]
    mov     rcx, r9
    sbb     rcx, [rsi+8]
    mov     rdx, r10
    sbb     rdx, [rsi+16]
    mov     rbp, r11
    sbb     rbp, [rsi+24]
    mov     rsi, [rdi+64]       ; P[8]; mov leaves the borrow intact
    sbb     rsi, 0              ; CF = 1  <=>  P[4..8] < mp
    cmovc   rax, r8             ; borrow: keep the unsubtracted value
    cmovc   rcx, r9
    cmovc   rdx, r10
    cmovc   rbp, r11
    mov     [rbx], rax
    mov     [rbx+8], rcx
    mov     [rbx+16], rdx
    mov     [rbx+24], rbp

    add     rsp, 104
    pop     rdi
    pop     rsi
    pop     rbp
    pop     rbx
    ret
mpn_mont_sqr_4 ENDP


; =====================================================================
; n=2 specialised macros (128-bit Montgomery)
;
; ADDMUL_FIRST_2 / ADDMUL_REST_2: register-resident 2-word addmul.
;   Accumulator window: r8, r9 (low word to high word).
;   rdx = multiplier, rsi = 2-word source vector.
;   rax: zeroed by ADDMUL_FIRST_2 (its xor), then used as a 0 constant.
;   rbx, rbp: MULX hi/lo temporaries.
;
;   Together the two macros form (r8,r9) + rdx*src[0..1] as a 3-word
;   value. The top word cannot overflow, because
;   (2^128-1) + (2^64-1)(2^128-1) < 2^192.
;   - FIRST finishes the lowest word in r8. The caller may store it
;     before REST runs.
;   - REST shifts the window down one word, leaving words 1..2 in r8, r9.
;   - FIRST's OF carry is consumed by REST. Only flag-neutral
;     instructions (e.g. mov) may sit between the two macros.
; =====================================================================
ADDMUL_FIRST_2 MACRO
    xor     eax, eax            ; rax = 0, CF = OF = 0
    mulx    rbx, rbp, [rsi]     ; rbx:rbp = rdx * src[0]
    adox    r8, rbp             ; word 0 done; carry pending in OF
ENDM

ADDMUL_REST_2 MACRO
    adcx    r9, rbx
    mulx    rbp, r8, [rsi+8]
    adox    r8, r9              ; r8 = word 1
    mov     r9, rax
    adcx    r9, rbp
    adox    r9, rax             ; r9 = word 2 = hi(src[1]) + CF + OF
ENDM

; REDC_ITER_2 iter_idx: one Montgomery reduction step for n=2.
;   q = r8 * m_inv, where m_inv is read from [rsp+56]. Callers must keep
;   it at that frame slot.
;   Then (r8, r9) += q * mp, with mp in rsi. The lowest word is dropped
;   by the window shift. (It becomes zero provided m_inv = -mp[0]^-1
;   mod 2^64; NOTE(review): that convention is assumed, not checked here.)
;   The new top word r9 absorbs P[2+iter_idx] from the buffer at rdi.
;   Any carry then ripples through P[3+iter_idx..4] in memory.
;   iter_idx must be an assemble-time constant in 0..1.
;   Clobbers: rax, rbx, rbp, rdx, r8, r9, flags, P[3+iter_idx..4].
REDC_ITER_2 MACRO iter_idx
LOCAL ri_nc
    mov     rax, r8
    imul    rax, QWORD PTR [rsp+56]         ; q = low word * m_inv (mod 2^64)
    mov     rdx, rax
    ADDMUL_FIRST_2
    ADDMUL_REST_2
    add     r9, QWORD PTR [rdi + (2 + iter_idx)*8]
    jnc     ri_nc
    ; unrolled carry ripple into the higher buffer words, up to P[4]
_prop_idx = 3 + iter_idx
    WHILE _prop_idx LE 4
        add     QWORD PTR [rdi + _prop_idx*8], 1
        jnc     ri_nc
_prop_idx = _prop_idx + 1
    ENDM
ri_nc:
ENDM


; =====================================================================
; void mpn_mont_mul_2(uint64_t* rp, const uint64_t* ap,
;                      const uint64_t* bp, const uint64_t* mp,
;                      uint64_t m_inv)
;
; Register-resident n=2 SOS Montgomery multiply:
;   rp = ap * bp * R^{-1} mod mp,  R = 2^128
;
; In (Windows x64): rcx = rp, rdx = ap, r8 = bp, r9 = mp,
;                   m_inv = 5th argument, passed on the stack
;
; Phase 1: full 2x2 product into P[0..3] (4 MULX)
; Phase 2: 2 REDC iterations via REDC_ITER_2 (4 MULX)
; Phase 3: final conditional subtraction of mp (CMOV select)
; =====================================================================
mpn_mont_mul_2 PROC
    push    rbx
    push    rbp
    push    rsi
    push    rdi
    sub     rsp, 72

    ; Frame: 4 pushes (32) + 72 = 104 bytes below the return address.
    ; The 5th argument sits above the return address (8) and the 32-byte
    ; shadow space: [rsp + 104 + 40] = [rsp+144].
    ;   [rsp+0..39]   P[0..4]  5-qword product buffer
    ;   [rsp+40]      rp
    ;   [rsp+48]      mp
    ;   [rsp+56]      m_inv   (REDC_ITER_2 reads it from this slot)
    mov     rax, [rsp+144]
    mov     [rsp+56], rax       ; m_inv
    mov     [rsp+48], r9        ; mp
    mov     [rsp+40], rcx       ; rp (saved before rcx is reused)
    mov     rcx, r8             ; rcx = bp
    mov     rsi, rdx            ; rsi = ap, source for ADDMUL_*_2
    mov     rdi, rsp            ; rdi = &P[0]

    ; ----- Phase 1: P = ap * bp, one row per word of bp -----
    xor     r8d, r8d
    xor     r9d, r9d
_m2row = 0
    REPEAT 2
        mov     rdx, [rcx + _m2row]             ; b[row]
        ADDMUL_FIRST_2
        mov     QWORD PTR [rdi + _m2row], r8    ; P[row]; mov keeps OF for REST
        ADDMUL_REST_2
_m2row = _m2row + 8
    ENDM

    ; High half of the product. P[4] is the REDC overflow word.
    mov     [rdi+16], r8
    mov     [rdi+24], r9
    mov     QWORD PTR [rdi+32], 0

    ; ----- Phase 2: 2 reduction steps, window starts at P[0..1] -----
    mov     rsi, [rsp+48]       ; rsi = mp for the rest of the routine
    mov     r8, [rdi]
    mov     r9, [rdi+8]

    REDC_ITER_2 0
    REDC_ITER_2 1

    ; ----- Phase 3: rp = (P[2..4] < mp) ? P[2..3] : P[2..3] - mp -----
    mov     rbx, [rsp+40]       ; rp
    mov     rax, r8
    sub     rax, [rsi]
    mov     rcx, r9
    sbb     rcx, [rsi+8]
    mov     rdx, [rdi+32]       ; P[4]; mov leaves the borrow intact
    sbb     rdx, 0              ; CF = 1  <=>  P[2..4] < mp
    cmovc   rax, r8             ; borrow: keep the unsubtracted value
    cmovc   rcx, r9
    mov     [rbx], rax
    mov     [rbx+8], rcx

    add     rsp, 72
    pop     rdi
    pop     rsi
    pop     rbp
    pop     rbx
    ret
mpn_mont_mul_2 ENDP


; =====================================================================
; void mpn_mont_sqr_2(uint64_t* rp, const uint64_t* ap,
;                      const uint64_t* mp, uint64_t m_inv)
;
; Register-resident n=2 fused square + REDC:
;   rp = ap^2 * R^{-1} mod mp,  R = 2^128
; Symmetry cuts the square to 1 cross product + 2 squares = 3 MULX,
; instead of 4 for a general 2x2 multiply.
;
; In (Windows x64): rcx = rp, rdx = ap, r8 = mp, r9 = m_inv
;
; Phase 1: ap^2 = a0^2 + 2*a0*a1*2^64 + a1^2*2^128, formed in registers
; Phase 2: 2 REDC iterations (REDC_ITER_2)
; Phase 3: final conditional subtraction of mp (CMOV select)
; =====================================================================
mpn_mont_sqr_2 PROC
    push    rbx
    push    rbp
    push    rsi
    push    rdi
    sub     rsp, 72

    ; Frame (same layout as mpn_mont_mul_2; REDC_ITER_2 reads [rsp+56]):
    ;   [rsp+0..39]   P[0..4]  5-qword product buffer
    ;   [rsp+40]      rp
    ;   [rsp+48]      mp
    ;   [rsp+56]      m_inv
    mov     [rsp+56], r9        ; m_inv
    mov     [rsp+48], r8        ; mp
    mov     [rsp+40], rcx       ; rp
    mov     rsi, rdx            ; rsi = ap
    mov     rdi, rsp            ; rdi = &P[0]

    ; ----- Phase 1: square in registers -----
    ; Cross term: rax:rbx:rcx = 2*a0*a1 (word weights 2^192, 2^128, 2^64)
    mov     rdx, [rsi]          ; a0
    mulx    rbx, rcx, [rsi+8]   ; rbx:rcx = a0*a1
    xor     eax, eax
    add     rcx, rcx
    adc     rbx, rbx
    adc     rax, 0              ; rax = bit shifted out of the doubling

    ; Squares: r9:r8 = a0^2, r11:r10 = a1^2  (MULX leaves flags alone)
    mulx    r9, r8, rdx         ; rdx is still a0
    mov     rdx, [rsi+8]        ; a1
    mulx    r11, r10, rdx

    ; Sum into P[0..4]: r8 = P[0], r9 = P[1], rbx = P[2], rax = P[3],
    ; rcx = P[4]
    add     r9, rcx
    adc     rbx, r10
    adc     rax, r11
    mov     ecx, 0              ; mov (not xor) so CF survives
    adc     ecx, 0

    ; REDC_ITER_2 reads P[2..4] from the buffer; P[0..1] stay in r8/r9
    mov     [rdi+16], rbx
    mov     [rdi+24], rax
    mov     [rdi+32], rcx

    ; ----- Phase 2: 2 reduction steps, window starts at P[0..1] -----
    mov     rsi, [rsp+48]       ; rsi = mp for the rest of the routine

    REDC_ITER_2 0
    REDC_ITER_2 1

    ; ----- Phase 3: rp = (P[2..4] < mp) ? P[2..3] : P[2..3] - mp -----
    mov     rbx, [rsp+40]       ; rp
    mov     rax, r8
    sub     rax, [rsi]
    mov     rcx, r9
    sbb     rcx, [rsi+8]
    mov     rdx, [rdi+32]       ; P[4]; mov leaves the borrow intact
    sbb     rdx, 0              ; CF = 1  <=>  P[2..4] < mp
    cmovc   rax, r8             ; borrow: keep the unsubtracted value
    cmovc   rcx, r9
    mov     [rbx], rax
    mov     [rbx+8], rcx

    add     rsp, 72
    pop     rdi
    pop     rsi
    pop     rbp
    pop     rbx
    ret
mpn_mont_sqr_2 ENDP


END
