; Copyright (C) 2026 Kiyotsugu Arai
; SPDX-License-Identifier: LGPL-3.0-or-later
;
; mpn_x64.asm — x86-64 最適化ルーチン
;
; 関数 (BMI2+ADX 必須):
;   mpn_addmul_1_mulx(rp, ap, n, b) -> carry
;   mpn_mul_1_mulx(rp, ap, n, b)    -> carry
;   mpn_submul_1_mulx(rp, ap, n, b) -> borrow
;   mpn_mul_basecase_mulx(rp, ap, an, bp, bn)
;   mpn_sqr_basecase_mulx(rp, ap, n)
;
; 関数 (基本 x86-64 のみ):
;   mpn_add_n_asm(rp, ap, bp, n) -> carry
;   mpn_sub_n_asm(rp, ap, bp, n) -> borrow
;   mpn_lshift_asm(rp, ap, n, shift) -> overflow
;   mpn_rshift_asm(rp, ap, n, shift) -> underflow (shifted-out bits)
;
; 前提: CPU が BMI2 + ADX をサポートしていること (呼び出し元で CPUID チェック済み)
;
; Windows x64 calling convention:
;   rcx = 1st, rdx = 2nd, r8 = 3rd, r9 = 4th
;   戻り値: rax
;   非破壊: rbx, rbp, rdi, rsi, r12-r15
;   破壊可: rax, rcx, rdx, r8, r9, r10, r11
;
; MULX+ADCX/ADOX の利点:
;   MULX はフラグ非破壊。ADCX は CF のみ、ADOX は OF のみ使用。
;   2 本の独立キャリーチェーンでデータ依存を分離し、
;   1 反復あたりの critical path を短縮 (~2 cycles vs ~4 cycles)。
;
; ループ制御:
;   LEA (フラグ非破壊) と JRCXZ (フラグ非破壊) を使用し、
;   ループ内で CF/OF を一切破壊しない。

.code

; =====================================================================
; uint64_t mpn_addmul_1_mulx(uint64_t* rp, const uint64_t* ap,
;                             size_t n, uint64_t b)
;
; rp[0..n-1] += ap[0..n-1] * b
; 戻り値: キャリー
;
; 4x アンロール版:
;   4 つの独立な MULX を先行発行し、ADCX/ADOX の双キャリーチェーンで
;   結果を畳み込む。1x版の ~2 cycles/elem → ~1.75 cycles/elem に改善。
;   - ADCX chain (CF): lo += prev_hi (積のキャリー伝播)
;   - ADOX chain (OF): lo += rp[i]   (メモリ値の累加)
;   4 要素ごとにキャリーを回収し、CF/OF をリセットしてからループ継続。
;
; レジスタ割り当て (4x ループ):
;   rdx = b (MULX 暗黙), rsi = ap, rbx = rp, rcx = 残り要素数
;   r8  = 0 (定数), r9 = carry (グループ間)
;   MULX 出力: r10:r11 (elem0), r12:r13 (elem1), r14:r15 (elem2), rdi:rax (elem3)
; =====================================================================
mpn_addmul_1_mulx PROC
    push    rbx
    push    rsi
    push    rdi
    push    r12
    push    r13
    push    r14
    push    r15

    mov     rsi, rdx            ; rsi = ap
    mov     rbx, rcx            ; rbx = rp
    mov     rcx, r8             ; rcx = n
    mov     rdx, r9             ; rdx = b (MULX multiplier)

    xor     r9d, r9d            ; carry = 0
    xor     r8d, r8d            ; r8 = 0 (定数)

    test    rcx, rcx
    jz      am1_zero

    cmp     rcx, 8
    jb      am1_check_4x

    ; ---- 8x メインループ ----
    ; CF/OF を Group 1 → Group 2 に自然に流す (mid-group merge 不要)
    ; Group 2 末尾でのみ carry 回収
    ALIGN   16
am1_8x_entry:
    ; --- Group 1: elements [0..3] ---
    xor     eax, eax            ; CF = 0, OF = 0

    mulx    r10, r11, [rsi]
    mulx    r12, r13, [rsi + 8]
    mulx    r14, r15, [rsi + 16]
    mulx    rdi, rax, [rsi + 24]

    adcx    r11, r9
    adox    r11, [rbx]
    mov     QWORD PTR [rbx], r11

    adcx    r13, r10
    adox    r13, [rbx + 8]
    mov     QWORD PTR [rbx + 8], r13

    adcx    r15, r12
    adox    r15, [rbx + 16]
    mov     QWORD PTR [rbx + 16], r15

    adcx    rax, r14
    adox    rax, [rbx + 24]
    mov     QWORD PTR [rbx + 24], rax

    ; CF, OF はそのまま Group 2 に流す (rdi = hi3)

    ; --- Group 2: elements [4..7] ---
    mulx    r10, r11, [rsi + 32]
    mulx    r12, r13, [rsi + 40]
    mulx    r14, r15, [rsi + 48]
    mulx    r9, rax, [rsi + 56]

    adcx    r11, rdi
    adox    r11, [rbx + 32]
    mov     QWORD PTR [rbx + 32], r11

    adcx    r13, r10
    adox    r13, [rbx + 40]
    mov     QWORD PTR [rbx + 40], r13

    adcx    r15, r12
    adox    r15, [rbx + 48]
    mov     QWORD PTR [rbx + 48], r15

    adcx    rax, r14
    adox    rax, [rbx + 56]
    mov     QWORD PTR [rbx + 56], rax

    ; carry 回収 (r9 = hi7, CF/OF を回収)
    adcx    r9, r8
    adox    r9, r8

    lea     rsi, [rsi + 64]
    lea     rbx, [rbx + 64]
    sub     rcx, 8
    cmp     rcx, 8
    jae     am1_8x_entry

    ; ---- 残り 0-7 要素 ----
am1_check_4x:
    cmp     rcx, 4
    jb      am1_tail_check

    ; ---- 4x 単発 ----
    xor     eax, eax
    mulx    r10, r11, [rsi]
    mulx    r12, r13, [rsi + 8]
    mulx    r14, r15, [rsi + 16]
    mulx    rdi, rax, [rsi + 24]

    adcx    r11, r9
    adox    r11, [rbx]
    mov     QWORD PTR [rbx], r11
    adcx    r13, r10
    adox    r13, [rbx + 8]
    mov     QWORD PTR [rbx + 8], r13
    adcx    r15, r12
    adox    r15, [rbx + 16]
    mov     QWORD PTR [rbx + 16], r15
    adcx    rax, r14
    adox    rax, [rbx + 24]
    mov     QWORD PTR [rbx + 24], rax

    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8
    lea     rsi, [rsi + 32]
    lea     rbx, [rbx + 32]
    sub     rcx, 4

am1_tail_check:
    test    rcx, rcx
    jz      am1_done

    ; ---- 1x テールループ ----
am1_tail_loop:
    mulx    r10, r11, [rsi]
    add     r11, r9
    adc     r10, 0
    add     QWORD PTR [rbx], r11
    adc     r10, 0
    mov     r9, r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 8]
    dec     rcx
    jnz     am1_tail_loop

am1_done:
    mov     rax, r9

    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret

am1_zero:
    xor     eax, eax
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_addmul_1_mulx ENDP


; =====================================================================
; uint64_t mpn_mul_1_mulx(uint64_t* rp, const uint64_t* ap,
;                          size_t n, uint64_t b)
;
; rp[0..n-1] = ap[0..n-1] * b
; 戻り値: キャリー
;
; 4x アンロール版:
;   addmul_1 と異なり rp[i] への加算が不要なため ADOX 不使用。
;   MULX + ADCX の単一キャリーチェーンのみ (rdi/r12-r15 は不要)。
;   4 つの独立な MULX を先行発行し、ADCX で積のキャリーを伝播。
;   4 要素ごとにキャリーを回収してループ継続。
; =====================================================================
mpn_mul_1_mulx PROC
    push    rbx
    push    rsi
    push    rdi

    mov     rsi, rdx            ; rsi = ap
    mov     rbx, rcx            ; rbx = rp
    mov     rcx, r8             ; rcx = n
    mov     rdx, r9             ; rdx = b

    xor     r9d, r9d            ; carry = 0

    test    rcx, rcx
    jz      m1_zero

    cmp     rcx, 8
    jb      m1_check_4x

    ; ---- 8x メインループ ----
m1_8x_entry:
    xor     eax, eax            ; CF = 0

    ; Group 1: elements [0..3]
    mulx    r10, r11, [rsi]
    mulx    rax, r8, [rsi + 8]
    adcx    r11, r9
    mov     [rbx], r11
    adcx    r8, r10
    mov     [rbx + 8], r8
    mulx    r10, r11, [rsi + 16]
    mulx    r9, r8, [rsi + 24]
    adcx    r11, rax
    mov     [rbx + 16], r11
    adcx    r8, r10
    mov     [rbx + 24], r8

    ; Group 2: elements [4..7]
    mulx    r10, r11, [rsi + 32]
    mulx    rax, rdi, [rsi + 40]
    adcx    r11, r9
    mov     [rbx + 32], r11
    adcx    rdi, r10
    mov     [rbx + 40], rdi
    mulx    r10, r11, [rsi + 48]
    mulx    r9, rdi, [rsi + 56]
    adcx    r11, rax
    mov     [rbx + 48], r11
    adcx    rdi, r10
    mov     [rbx + 56], rdi

    ; キャリー回収
    mov     r8d, 0              ; フラグ非破壊
    adcx    r9, r8

    lea     rsi, [rsi + 64]
    lea     rbx, [rbx + 64]
    sub     rcx, 8
    cmp     rcx, 8
    jae     m1_8x_entry

    ; ---- 残り 4-7 要素: 4x ブロック ----
m1_check_4x:
    cmp     rcx, 4
    jb      m1_tail_check

    xor     eax, eax            ; CF = 0
    mulx    r10, r11, [rsi]
    mulx    rax, r8, [rsi + 8]
    adcx    r11, r9
    mov     [rbx], r11
    adcx    r8, r10
    mov     [rbx + 8], r8
    mulx    r10, r11, [rsi + 16]
    mulx    r9, r8, [rsi + 24]
    adcx    r11, rax
    mov     [rbx + 16], r11
    adcx    r8, r10
    mov     [rbx + 24], r8
    mov     r8d, 0
    adcx    r9, r8
    lea     rsi, [rsi + 32]
    lea     rbx, [rbx + 32]
    sub     rcx, 4

    ; ---- Tail: 残り 0-3 要素 ----
m1_tail_check:
    test    rcx, rcx
    jz      m1_done

m1_tail_loop:
    mulx    r10, r11, [rsi]
    add     r11, r9
    adc     r10, 0
    mov     [rbx], r11
    mov     r9, r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 8]
    dec     rcx
    jnz     m1_tail_loop

m1_done:
    mov     rax, r9
    pop     rdi
    pop     rsi
    pop     rbx
    ret

m1_zero:
    xor     eax, eax
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mul_1_mulx ENDP


; =====================================================================
; uint64_t mpn_submul_1_mulx(uint64_t* rp, const uint64_t* ap,
;                              size_t n, uint64_t b)
;
; rp[0..n-1] -= ap[0..n-1] * b
; 戻り値: borrow (キャリー兼ボロー)
;
; ADCX/ADOX の双キャリーチェーンは使えない (SUB が OF を破壊するため)。
; 代わりに MULX + ADD/ADC/SUB/ADC の単一チェーンを 4x/8x アンロール。
; MULX はフラグ非破壊なので、先行発行して OOO パイプラインを活用。
;
; 各要素の演算 (4 依存命令):
;   (hi, lo) = a[i] * b       ; MULX (フラグ不変、パイプライン可能)
;   lo += carry; c1 = CF       ; ADD
;   hi += c1                   ; ADC
;   rp[i] -= lo; c2 = CF       ; SUB
;   carry = hi + c2             ; ADC
;
; 8x アンロールでループオーバーヘッドを 0.5 cycles/elem に削減。
; MULX のペア先行発行で OOO が carry チェーンと並行して乗算を実行。
; =====================================================================
mpn_submul_1_mulx PROC
    push    rbx
    push    rsi

    mov     rsi, rdx            ; rsi = ap
    mov     rbx, rcx            ; rbx = rp
    mov     rcx, r8             ; rcx = n
    mov     rdx, r9             ; rdx = b (MULX multiplier)

    xor     r9d, r9d            ; carry = 0

    test    rcx, rcx
    jz      submul1_zero

    cmp     rcx, 8
    jb      submul1_check_4x

    ; ---- 8x メインループ ----
submul1_8x_entry:
    ; --- Group 1: elements 0-3 ---
    mulx    r10, r11, [rsi]           ; hi0:lo0
    mulx    rax, r8, [rsi + 8]        ; hi1:lo1

    add     r11, r9                   ; lo0 += carry
    adc     r10, 0
    sub     QWORD PTR [rbx], r11
    adc     r10, 0

    add     r8, r10                   ; lo1 += carry1
    adc     rax, 0
    sub     QWORD PTR [rbx + 8], r8
    adc     rax, 0

    mulx    r10, r11, [rsi + 16]      ; hi2:lo2
    mulx    r9, r8, [rsi + 24]        ; hi3:lo3

    add     r11, rax                  ; lo2 += carry2
    adc     r10, 0
    sub     QWORD PTR [rbx + 16], r11
    adc     r10, 0

    add     r8, r10                   ; lo3 += carry3
    adc     r9, 0
    sub     QWORD PTR [rbx + 24], r8
    adc     r9, 0

    ; --- Group 2: elements 4-7 ---
    mulx    r10, r11, [rsi + 32]
    mulx    rax, r8, [rsi + 40]

    add     r11, r9
    adc     r10, 0
    sub     QWORD PTR [rbx + 32], r11
    adc     r10, 0

    add     r8, r10
    adc     rax, 0
    sub     QWORD PTR [rbx + 40], r8
    adc     rax, 0

    mulx    r10, r11, [rsi + 48]
    mulx    r9, r8, [rsi + 56]

    add     r11, rax
    adc     r10, 0
    sub     QWORD PTR [rbx + 48], r11
    adc     r10, 0

    add     r8, r10
    adc     r9, 0
    sub     QWORD PTR [rbx + 56], r8
    adc     r9, 0

    lea     rsi, [rsi + 64]
    lea     rbx, [rbx + 64]
    sub     rcx, 8
    cmp     rcx, 8
    jae     submul1_8x_entry

    ; ---- 残り 0-7 要素 ----
submul1_check_4x:
    cmp     rcx, 4
    jb      submul1_tail_check

    ; ---- 4x 単発 ----
    mulx    r10, r11, [rsi]
    mulx    rax, r8, [rsi + 8]

    add     r11, r9
    adc     r10, 0
    sub     QWORD PTR [rbx], r11
    adc     r10, 0

    add     r8, r10
    adc     rax, 0
    sub     QWORD PTR [rbx + 8], r8
    adc     rax, 0

    mulx    r10, r11, [rsi + 16]
    mulx    r9, r8, [rsi + 24]

    add     r11, rax
    adc     r10, 0
    sub     QWORD PTR [rbx + 16], r11
    adc     r10, 0

    add     r8, r10
    adc     r9, 0
    sub     QWORD PTR [rbx + 24], r8
    adc     r9, 0

    lea     rsi, [rsi + 32]
    lea     rbx, [rbx + 32]
    sub     rcx, 4

submul1_tail_check:
    test    rcx, rcx
    jz      submul1_done

    ; ---- 1x テールループ ----
submul1_tail_loop:
    mulx    r10, r11, [rsi]
    add     r11, r9
    adc     r10, 0
    sub     QWORD PTR [rbx], r11
    adc     r10, 0
    mov     r9, r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 8]
    dec     rcx
    jnz     submul1_tail_loop

submul1_done:
    mov     rax, r9
    pop     rsi
    pop     rbx
    ret

submul1_zero:
    xor     eax, eax
    pop     rsi
    pop     rbx
    ret
mpn_submul_1_mulx ENDP


; =====================================================================
; void mpn_mul_basecase_mulx(uint64_t* rp, const uint64_t* ap,
;                             size_t an, const uint64_t* bp, size_t bn)
;
; rp[0..an+bn-1] = ap[0..an-1] * bp[0..bn-1]
; 前提: an >= bn >= 1, rp は ap,bp と重なってはならない
;
; C++ 版 mul_basecase は mul_1 + addmul_1×(bn-1) を個別に呼ぶため、
; 各呼び出しのプロローグ/エピローグ (addmul_1 は 7 push/pop) が
; 小サイズで支配的になる。この関数は全行を 1 回の push/pop で処理。
;
; Windows x64 calling convention:
;   rcx = rp, rdx = ap, r8 = an, r9 = bp, [rsp+40] = bn
;
; Register allocation:
;   Inner loop (mul_1/addmul_1 body):
;     rbx = rp working pointer, rsi = ap working pointer
;     rdx = multiplier b[j], rcx = inner counter
;     r9 = carry, r8 = 0 constant
;     r10:r11, r12:r13, r14:r15, rdi:rax = 4x MULX hi:lo pairs
;   Outer loop state (saved to stack during inner loop):
;     [rsp+0]  = rp_base
;     [rsp+8]  = ap_base
;     [rsp+16] = an
;     [rsp+24] = bp_base
;     [rsp+32] = row counter (remaining addmul rows)
; =====================================================================
mpn_mul_basecase_mulx PROC
    push    rbx
    push    rsi
    push    rdi
    push    r12
    push    r13
    push    r14
    push    r15
    push    rbp
    sub     rsp, 40                 ; ローカル変数 (5 qwords)
    ; スタック配置:
    ;   8 push × 8 = 64 bytes + sub 40 = 104 bytes
    ;   5th arg (bn): [rsp + 40 + 64 + 40] = [rsp + 144]

    ; パラメータ保存
    mov     [rsp+0], rcx            ; rp_base
    mov     [rsp+8], rdx            ; ap_base
    mov     [rsp+16], r8            ; an
    mov     [rsp+24], r9            ; bp_base
    mov     rax, [rsp+144]          ; bn
    mov     [rsp+32], rax           ; row_counter

    ; ========== Row 0: mul_1 (r[0..an] = a * b[0]) ==========
    mov     rbx, rcx                ; rbx = rp
    mov     rsi, rdx                ; rsi = ap
    mov     rdx, [r9]               ; rdx = b[0] (MULX multiplier)
    mov     rcx, r8                 ; rcx = an (inner counter)
    xor     r9d, r9d                ; carry = 0
    xor     r8d, r8d                ; r8 = 0

    cmp     rcx, 8
    jb      bc_r0_check_4x

    ALIGN   16
bc_r0_8x:
    ; --- Group 1: elements [0..3] ---
    xor     eax, eax                ; CF = 0
    mulx    r10, r11, [rsi]
    mulx    rax, r8, [rsi + 8]
    adcx    r11, r9
    mov     [rbx], r11
    adcx    r8, r10
    mov     [rbx + 8], r8
    mulx    r10, r11, [rsi + 16]
    mulx    r9, r8, [rsi + 24]
    adcx    r11, rax
    mov     [rbx + 16], r11
    adcx    r8, r10
    mov     [rbx + 24], r8
    ; --- Group 2: elements [4..7] ---
    mulx    r10, r11, [rsi + 32]
    mulx    rax, rdi, [rsi + 40]
    adcx    r11, r9
    mov     [rbx + 32], r11
    adcx    rdi, r10
    mov     [rbx + 40], rdi
    mulx    r10, r11, [rsi + 48]
    mulx    r9, rdi, [rsi + 56]
    adcx    r11, rax
    mov     [rbx + 48], r11
    adcx    rdi, r10
    mov     [rbx + 56], rdi
    mov     r8d, 0                  ; フラグ非破壊 (xor は CF を破壊するため不可)
    adcx    r9, r8
    lea     rsi, [rsi + 64]
    lea     rbx, [rbx + 64]
    sub     rcx, 8
    cmp     rcx, 8
    jae     bc_r0_8x

bc_r0_check_4x:
    cmp     rcx, 4
    jb      bc_r0_tail_check

    xor     eax, eax                ; CF = 0
    mulx    r10, r11, [rsi]
    mulx    rax, r8, [rsi + 8]
    adcx    r11, r9
    mov     [rbx], r11
    adcx    r8, r10
    mov     [rbx + 8], r8
    mulx    r10, r11, [rsi + 16]
    mulx    r9, r8, [rsi + 24]
    adcx    r11, rax
    mov     [rbx + 16], r11
    adcx    r8, r10
    mov     [rbx + 24], r8
    mov     r8d, 0
    adcx    r9, r8
    lea     rsi, [rsi + 32]
    lea     rbx, [rbx + 32]
    sub     rcx, 4

bc_r0_tail_check:
    test    rcx, rcx
    jz      bc_r0_done

bc_r0_tail:
    mulx    r10, r11, [rsi]
    add     r11, r9
    adc     r10, 0
    mov     [rbx], r11
    mov     r9, r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 8]
    dec     rcx
    jnz     bc_r0_tail

bc_r0_done:
    mov     [rbx], r9               ; rp[an] = carry

    ; bn == 1 なら完了
    dec     QWORD PTR [rsp+32]
    jz      bc_finish

    ; ========== Rows 1..bn-1: addmul_1 (outer loop) ==========
    mov     rbp, 1                  ; j = 1 (行インデックス)

bc_outer:
    ; ap, an を復元
    mov     rsi, [rsp+8]            ; ap_base
    mov     rcx, [rsp+16]           ; an

    ; rbx = rp + j*8 (この行の内側ベースアドレス)
    mov     rbx, [rsp+0]
    lea     rbx, [rbx + rbp*8]

    ; b[j] をロード
    mov     rax, [rsp+24]           ; bp_base
    mov     rdx, [rax + rbp*8]      ; b[j] → MULX multiplier

    ; b[j] == 0 のスキップ
    test    rdx, rdx
    jz      bc_zero_row

    ; ---- addmul_1 inner loop (8x アンロール) ----
    xor     r9d, r9d                ; carry = 0
    xor     r8d, r8d                ; 0 constant

    cmp     rcx, 8
    jb      bc_am_check_4x

    ALIGN   16
bc_am_8x:
    ; --- Group 1: elements [0..3] ---
    xor     eax, eax
    mulx    r10, r11, [rsi]
    mulx    r12, r13, [rsi + 8]
    mulx    r14, r15, [rsi + 16]
    mulx    rdi, rax, [rsi + 24]
    adcx    r11, r9
    adox    r11, [rbx]
    mov     [rbx], r11
    adcx    r13, r10
    adox    r13, [rbx + 8]
    mov     [rbx + 8], r13
    adcx    r15, r12
    adox    r15, [rbx + 16]
    mov     [rbx + 16], r15
    adcx    rax, r14
    adox    rax, [rbx + 24]
    mov     [rbx + 24], rax
    ; CF, OF はそのまま Group 2 に流す (rdi = hi3)

    ; --- Group 2: elements [4..7] ---
    mulx    r10, r11, [rsi + 32]
    mulx    r12, r13, [rsi + 40]
    mulx    r14, r15, [rsi + 48]
    mulx    r9, rax, [rsi + 56]
    adcx    r11, rdi
    adox    r11, [rbx + 32]
    mov     [rbx + 32], r11
    adcx    r13, r10
    adox    r13, [rbx + 40]
    mov     [rbx + 40], r13
    adcx    r15, r12
    adox    r15, [rbx + 48]
    mov     [rbx + 48], r15
    adcx    rax, r14
    adox    rax, [rbx + 56]
    mov     [rbx + 56], rax
    ; carry 回収 (r9 = hi7)
    adcx    r9, r8
    adox    r9, r8

    lea     rsi, [rsi + 64]
    lea     rbx, [rbx + 64]
    sub     rcx, 8
    cmp     rcx, 8
    jae     bc_am_8x

bc_am_check_4x:
    cmp     rcx, 4
    jb      bc_am_tail_check

    xor     eax, eax
    mulx    r10, r11, [rsi]
    mulx    r12, r13, [rsi + 8]
    mulx    r14, r15, [rsi + 16]
    mulx    rdi, rax, [rsi + 24]
    adcx    r11, r9
    adox    r11, [rbx]
    mov     [rbx], r11
    adcx    r13, r10
    adox    r13, [rbx + 8]
    mov     [rbx + 8], r13
    adcx    r15, r12
    adox    r15, [rbx + 16]
    mov     [rbx + 16], r15
    adcx    rax, r14
    adox    rax, [rbx + 24]
    mov     [rbx + 24], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8
    lea     rsi, [rsi + 32]
    lea     rbx, [rbx + 32]
    sub     rcx, 4

bc_am_tail_check:
    test    rcx, rcx
    jz      bc_am_done

bc_am_tail:
    mulx    r10, r11, [rsi]
    add     r11, r9
    adc     r10, 0
    add     QWORD PTR [rbx], r11
    adc     r10, 0
    mov     r9, r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 8]
    dec     rcx
    jnz     bc_am_tail

bc_am_done:
    ; carry を rp[an+j] に書き込み
    ; (rbx は rp + j*8 + an*8 = rp + (an+j)*8 に到達している)
    mov     [rbx], r9
    jmp     bc_next_row

bc_zero_row:
    ; b[j] == 0: rp[an+j] = 0
    mov     rax, [rsp+0]            ; rp_base
    mov     rcx, [rsp+16]           ; an
    add     rcx, rbp                ; an + j
    mov     QWORD PTR [rax + rcx*8], 0

bc_next_row:
    inc     rbp
    dec     QWORD PTR [rsp+32]
    jnz     bc_outer

bc_finish:
    add     rsp, 40
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_mul_basecase_mulx ENDP

; =====================================================================
; uint64_t mpn_add_n_asm(uint64_t* rp, const uint64_t* ap,
;                         const uint64_t* bp, size_t n)
;
; rp[0..n-1] = ap[0..n-1] + bp[0..n-1]
; 戻り値: キャリー (0 or 1)
; 前提: n >= 1
;
; BMI2/ADX 不要。基本 x86-64 の ADC 命令のみ使用。
; MSVC の _addcarry_u64 が毎反復 CF↔汎用レジスタ変換を生成するのに対し、
; 純粋な ADC チェーンでキャリーをハードウェアフラグに保持し続ける。
;
; 8x アンロール。ループ制御:
;   - DEC (CF 非破壊) + JNZ でカウンタ管理
;   - LEA (フラグ非破壊) でポインタ進行
;   - JRCXZ (フラグ非破壊) でメインループ判定
;
; レジスタ割り当て:
;   rbx = rp (callee-saved), rdx = ap, r8 = bp
;   r10 = n%8 (残余カウンタ), rcx = n/8 (メインループ回数)
; =====================================================================
mpn_add_n_asm PROC
    push    rbx
    mov     rbx, rcx                ; rbx = rp (rcx を解放)

    mov     r10, r9
    and     r10, 7                  ; r10 = n % 8 (残余要素数)
    mov     rcx, r9
    shr     rcx, 3                  ; rcx = n / 8 (メインループ回数)

    clc                             ; CF = 0

    test    r10, r10                ; CF は clc で 0 なので破壊されても問題なし
    jz      addn_rem_done

addn_rem_loop:
    mov     rax, [rdx]
    adc     rax, [r8]
    mov     [rbx], rax
    lea     rbx, [rbx + 8]         ; lea はフラグ非破壊
    lea     rdx, [rdx + 8]
    lea     r8,  [r8  + 8]
    dec     r10                     ; dec は CF を保存
    jnz     addn_rem_loop

addn_rem_done:
    ; CF は残余ループから引き継がれている (残余なしの場合は 0)
    jrcxz   addn_done               ; rcx == 0 なら終了 (フラグ非破壊!)

addn_main_loop:
    mov     rax, [rdx]
    adc     rax, [r8]
    mov     [rbx], rax

    mov     rax, [rdx + 8]
    adc     rax, [r8  + 8]
    mov     [rbx + 8], rax

    mov     rax, [rdx + 16]
    adc     rax, [r8  + 16]
    mov     [rbx + 16], rax

    mov     rax, [rdx + 24]
    adc     rax, [r8  + 24]
    mov     [rbx + 24], rax

    mov     rax, [rdx + 32]
    adc     rax, [r8  + 32]
    mov     [rbx + 32], rax

    mov     rax, [rdx + 40]
    adc     rax, [r8  + 40]
    mov     [rbx + 40], rax

    mov     rax, [rdx + 48]
    adc     rax, [r8  + 48]
    mov     [rbx + 48], rax

    mov     rax, [rdx + 56]
    adc     rax, [r8  + 56]
    mov     [rbx + 56], rax

    lea     rbx, [rbx + 64]
    lea     rdx, [rdx + 64]
    lea     r8,  [r8  + 64]
    dec     rcx                     ; dec は CF を保存
    jnz     addn_main_loop

addn_done:
    setc    al
    movzx   eax, al
    pop     rbx
    ret
mpn_add_n_asm ENDP

; =====================================================================
; uint64_t mpn_sub_n_asm(uint64_t* rp, const uint64_t* ap,
;                         const uint64_t* bp, size_t n)
;
; rp[0..n-1] = ap[0..n-1] - bp[0..n-1]
; 戻り値: ボロー (0 or 1)
; 前提: n >= 1
;
; mpn_add_n_asm と同一構造。ADC → SBB に変更。8x アンロール。
; =====================================================================
mpn_sub_n_asm PROC
    push    rbx
    mov     rbx, rcx

    mov     r10, r9
    and     r10, 7
    mov     rcx, r9
    shr     rcx, 3

    clc

    test    r10, r10
    jz      subn_rem_done

subn_rem_loop:
    mov     rax, [rdx]
    sbb     rax, [r8]
    mov     [rbx], rax
    lea     rbx, [rbx + 8]
    lea     rdx, [rdx + 8]
    lea     r8,  [r8  + 8]
    dec     r10
    jnz     subn_rem_loop

subn_rem_done:
    jrcxz   subn_done

subn_main_loop:
    mov     rax, [rdx]
    sbb     rax, [r8]
    mov     [rbx], rax

    mov     rax, [rdx + 8]
    sbb     rax, [r8  + 8]
    mov     [rbx + 8], rax

    mov     rax, [rdx + 16]
    sbb     rax, [r8  + 16]
    mov     [rbx + 16], rax

    mov     rax, [rdx + 24]
    sbb     rax, [r8  + 24]
    mov     [rbx + 24], rax

    mov     rax, [rdx + 32]
    sbb     rax, [r8  + 32]
    mov     [rbx + 32], rax

    mov     rax, [rdx + 40]
    sbb     rax, [r8  + 40]
    mov     [rbx + 40], rax

    mov     rax, [rdx + 48]
    sbb     rax, [r8  + 48]
    mov     [rbx + 48], rax

    mov     rax, [rdx + 56]
    sbb     rax, [r8  + 56]
    mov     [rbx + 56], rax

    lea     rbx, [rbx + 64]
    lea     rdx, [rdx + 64]
    lea     r8,  [r8  + 64]
    dec     rcx
    jnz     subn_main_loop

subn_done:
    setc    al
    movzx   eax, al
    pop     rbx
    ret
mpn_sub_n_asm ENDP

; =====================================================================
; uint64_t mpn_lshift_asm(uint64_t* rp, const uint64_t* ap,
;                          size_t n, unsigned shift)
;
; rp[0..n-1] = ap[0..n-1] << shift
; 戻り値: 溢れた最上位ビット群
; 前提: n >= 1, 1 <= shift <= 63
;
; アルゴリズム:
;   HIGH→LOW 方向に SHLD 命令で処理。
;   SHLD dst, src, cl: dst = (dst << cl) | (src >> (64-cl))
;   各要素で ap[i] と ap[i-1] を連結シフトし、結果を rp[i] に格納。
;   4x アンロール + 残余ループ。in-place 安全 (HIGH→LOW)。
;
; レジスタ割り当て:
;   rbx = rp (callee-saved), rdx = ap
;   cl = shift, rdi = 前回の ap[i] (callee-saved)
;   r10 = overflow (戻り値), r9 = 残余カウンタ, r8 = メインループ回数
; =====================================================================
mpn_lshift_asm PROC
    push    rbx
    push    rdi
    mov     rbx, rcx                ; rbx = rp
    mov     rcx, r9                 ; cl = shift

    ; ポインタを末尾にセット (HIGH→LOW で処理)
    lea     rbx, [rbx + r8*8 - 8]   ; rbx = &rp[n-1]
    lea     rdx, [rdx + r8*8 - 8]   ; rdx = &ap[n-1]

    ; overflow = ap[n-1] >> (64-shift) を SHLD で計算
    mov     rdi, [rdx]               ; rdi = ap[n-1]
    xor     r10d, r10d
    shld    r10, rdi, cl             ; r10 = rdi >> (64-shift) = overflow

    ; n-1 ペアを処理
    dec     r8                       ; r8 = n-1
    jz      lsh_last                 ; n==1 → 最後の要素だけ

    ; 残余 = (n-1) % 4
    mov     r9, r8
    and     r9, 3                    ; r9 = 残余カウンタ
    shr     r8, 2                    ; r8 = メインループ回数

    test    r9, r9
    jz      lsh_rem_done

lsh_rem_loop:
    mov     rax, [rdx - 8]          ; rax = ap[i-1]
    shld    rdi, rax, cl            ; rdi = (ap[i]<<shift)|(ap[i-1]>>(64-shift))
    mov     [rbx], rdi              ; rp[i] = result
    mov     rdi, rax                ; rdi = ap[i-1] for next iteration
    lea     rbx, [rbx - 8]
    lea     rdx, [rdx - 8]
    dec     r9
    jnz     lsh_rem_loop

lsh_rem_done:
    test    r8, r8
    jz      lsh_last

lsh_main_loop:
    mov     rax, [rdx - 8]
    shld    rdi, rax, cl
    mov     [rbx], rdi
    mov     rdi, rax

    mov     rax, [rdx - 16]
    shld    rdi, rax, cl
    mov     [rbx - 8], rdi
    mov     rdi, rax

    mov     rax, [rdx - 24]
    shld    rdi, rax, cl
    mov     [rbx - 16], rdi
    mov     rdi, rax

    mov     rax, [rdx - 32]
    shld    rdi, rax, cl
    mov     [rbx - 24], rdi
    mov     rdi, rax

    lea     rbx, [rbx - 32]
    lea     rdx, [rdx - 32]
    dec     r8
    jnz     lsh_main_loop

lsh_last:
    ; rp[0] = rdi << shift (rdi は最後に読んだ ap[0])
    shl     rdi, cl
    mov     [rbx], rdi

    mov     rax, r10                ; return overflow
    pop     rdi
    pop     rbx
    ret
mpn_lshift_asm ENDP

; =====================================================================
; uint64_t mpn_rshift_asm(uint64_t* rp, const uint64_t* ap,
;                          size_t n, unsigned shift)
;
; rp[0..n-1] = ap[0..n-1] >> shift
; 戻り値: 最下位から溢れたビット (ap[0] << (64-shift))
; 前提: n >= 1, 1 <= shift <= 63
;
; アルゴリズム:
;   LOW→HIGH 方向に SHRD 命令で処理。
;   SHRD dst, src, cl: dst = (src << (64-cl)) | (dst >> cl)
;   4x アンロール + 残余ループ。in-place 安全 (LOW→HIGH)。
; =====================================================================
mpn_rshift_asm PROC
    push    rbx
    push    rdi
    mov     rbx, rcx                ; rbx = rp
    mov     rcx, r9                 ; cl = shift

    ; underflow = ap[0] << (64-shift) を SHRD で計算
    mov     rdi, [rdx]               ; rdi = ap[0]
    xor     r10d, r10d
    shrd    r10, rdi, cl             ; r10 = rdi << (64-shift) = underflow

    ; n-1 ペアを処理
    dec     r8                       ; r8 = n-1
    jz      rsh_last

    mov     r9, r8
    and     r9, 3
    shr     r8, 2

    test    r9, r9
    jz      rsh_rem_done

rsh_rem_loop:
    mov     rax, [rdx + 8]          ; rax = ap[i+1]
    shrd    rdi, rax, cl            ; rdi = (ap[i+1]<<(64-shift))|(ap[i]>>shift)
    mov     [rbx], rdi
    mov     rdi, rax
    lea     rbx, [rbx + 8]
    lea     rdx, [rdx + 8]
    dec     r9
    jnz     rsh_rem_loop

rsh_rem_done:
    test    r8, r8
    jz      rsh_last

rsh_main_loop:
    mov     rax, [rdx + 8]
    shrd    rdi, rax, cl
    mov     [rbx], rdi
    mov     rdi, rax

    mov     rax, [rdx + 16]
    shrd    rdi, rax, cl
    mov     [rbx + 8], rdi
    mov     rdi, rax

    mov     rax, [rdx + 24]
    shrd    rdi, rax, cl
    mov     [rbx + 16], rdi
    mov     rdi, rax

    mov     rax, [rdx + 32]
    shrd    rdi, rax, cl
    mov     [rbx + 24], rdi
    mov     rdi, rax

    lea     rbx, [rbx + 32]
    lea     rdx, [rdx + 32]
    dec     r8
    jnz     rsh_main_loop

rsh_last:
    ; rp[n-1] = rdi >> shift
    shr     rdi, cl
    mov     [rbx], rdi

    mov     rax, r10                ; return underflow
    pop     rdi
    pop     rbx
    ret
mpn_rshift_asm ENDP

; =====================================================================
; void mpn_sqr_basecase_mulx(uint64_t* rp, const uint64_t* ap, size_t n)
;
; rp[0..2n-1] = ap[0..n-1]²
; 前提: n >= 1, rp と ap は重ならない, rp は事前にゼロクリア不要
;
; アルゴリズム (3 ステップ):
;   Step 1: Off-diagonal — 上三角 Σ_{i<j} a[i]*a[j]*B^(i+j)
;           Row 0: mul_1, Rows 1..n-2: addmul_1 (全行インライン)
;   Step 2: 2 倍 — ADC チェーンによるインライン lshift 1
;   Step 3: Diagonal — a[i]² を r[2i..2i+1] に MULX+ADCX で加算
;
; 対称性により乗算回数 = n(n-1)/2 + n (vs 通常乗算の n²)
; 全行インライン化により addmul_1 の push/pop (14 命令/行) を除去
;
; rcx = rp, rdx = ap, r8 = n
; =====================================================================
mpn_sqr_basecase_mulx PROC
    push    rbx
    push    rsi
    push    rdi
    push    r12
    push    r13
    push    r14
    push    r15
    push    rbp
    sub     rsp, 40                 ; 5 local qwords
    ; Stack: 8 push * 8 = 64 + sub 40 = 104 bytes

    mov     [rsp+0], rcx            ; rp_base
    mov     [rsp+8], rdx            ; ap_base
    mov     [rsp+16], r8            ; n

    ; ===== Step 0: Zero r[0..2n-1] (rep stosq) =====
    mov     rdi, rcx                ; rdi = rp
    mov     rcx, r8
    shl     rcx, 1                  ; rcx = 2n
    xor     eax, eax
    rep     stosq

    ; ===== n == 1: r[0..1] = a[0]² =====
    mov     rcx, [rsp+16]
    cmp     rcx, 1
    ja      sqb_step1

    mov     rsi, [rsp+8]
    mov     rdx, [rsi]
    mulx    r10, r11, rdx
    mov     rbx, [rsp+0]
    mov     [rbx], r11
    mov     [rbx+8], r10
    jmp     sqb_done

    ; ========== Step 1: Off-diagonal ==========
sqb_step1:
    ; --- Row 0: mul_1 (r[1..n-1] = a[1..n-1]*a[0], carry → r[n]) ---
    mov     rbx, [rsp+0]
    lea     rbx, [rbx + 8]         ; rbx = rp + 1
    mov     rsi, [rsp+8]
    mov     rdx, [rsi]              ; a[0]
    lea     rsi, [rsi + 8]          ; ap + 1
    mov     rcx, [rsp+16]
    dec     rcx                     ; n - 1
    xor     r9d, r9d
    xor     r8d, r8d

    cmp     rcx, 4
    jb      sqb_r0_tail

sqb_r0_4x:
    xor     eax, eax
    mulx    r10, r11, [rsi]
    mulx    rax, r8, [rsi + 8]
    adcx    r11, r9
    mov     [rbx], r11
    adcx    r8, r10
    mov     [rbx + 8], r8
    mulx    r10, r11, [rsi + 16]
    mulx    r9, r8, [rsi + 24]
    adcx    r11, rax
    mov     [rbx + 16], r11
    adcx    r8, r10
    mov     [rbx + 24], r8
    mov     r8d, 0
    adcx    r9, r8
    lea     rsi, [rsi + 32]
    lea     rbx, [rbx + 32]
    sub     rcx, 4
    cmp     rcx, 4
    jae     sqb_r0_4x

    test    rcx, rcx
    jz      sqb_r0_done

sqb_r0_tail:
    mulx    r10, r11, [rsi]
    add     r11, r9
    adc     r10, 0
    mov     [rbx], r11
    mov     r9, r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 8]
    dec     rcx
    jnz     sqb_r0_tail

sqb_r0_done:
    mov     [rbx], r9               ; r[n] = carry

    ; n == 2: no more rows
    mov     rax, [rsp+16]
    cmp     rax, 2
    jbe     sqb_step2

    ; --- Rows 1..n-2: addmul_1 (inline, push/pop 不要) ---
    mov     rbp, 1                  ; i = 1
    mov     rax, [rsp+16]
    sub     rax, 2                  ; row_counter = n - 2
    mov     [rsp+32], rax

sqb_outer:
    ; dst = rp + (2*i + 1)*8
    mov     rbx, [rsp+0]
    lea     rax, [rbp*2 + 1]
    lea     rbx, [rbx + rax*8]

    ; src = ap + (i+1)*8
    mov     rsi, [rsp+8]
    lea     rsi, [rsi + rbp*8 + 8]

    ; multiplier = a[i]
    mov     rax, [rsp+8]
    mov     rdx, [rax + rbp*8]

    ; inner count = n - i - 1
    mov     rcx, [rsp+16]
    sub     rcx, rbp
    dec     rcx

    ; a[i] == 0 をスキップ
    test    rdx, rdx
    jz      sqb_zero_row

    xor     r9d, r9d
    xor     r8d, r8d

    cmp     rcx, 8
    jb      sqb_am_4x_check

    ; ---- 8x unrolled addmul_1 (MULX+ADCX/ADOX) ----
sqb_am_8x:
    xor     eax, eax
    mulx    r10, r11, [rsi]
    mulx    r12, r13, [rsi + 8]
    mulx    r14, r15, [rsi + 16]
    mulx    rdi, rax, [rsi + 24]
    adcx    r11, r9
    adox    r11, [rbx]
    mov     [rbx], r11
    adcx    r13, r10
    adox    r13, [rbx + 8]
    mov     [rbx + 8], r13
    adcx    r15, r12
    adox    r15, [rbx + 16]
    mov     [rbx + 16], r15
    adcx    rax, r14
    adox    rax, [rbx + 24]
    mov     [rbx + 24], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8

    mulx    r10, r11, [rsi + 32]
    mulx    r12, r13, [rsi + 40]
    mulx    r14, r15, [rsi + 48]
    mulx    rdi, rax, [rsi + 56]
    adcx    r11, r9
    adox    r11, [rbx + 32]
    mov     [rbx + 32], r11
    adcx    r13, r10
    adox    r13, [rbx + 40]
    mov     [rbx + 40], r13
    adcx    r15, r12
    adox    r15, [rbx + 48]
    mov     [rbx + 48], r15
    adcx    rax, r14
    adox    rax, [rbx + 56]
    mov     [rbx + 56], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8

    lea     rsi, [rsi + 64]
    lea     rbx, [rbx + 64]
    sub     rcx, 8
    cmp     rcx, 8
    jae     sqb_am_8x

sqb_am_4x_check:
    cmp     rcx, 4
    jb      sqb_am_tail_check

    xor     eax, eax
    mulx    r10, r11, [rsi]
    mulx    r12, r13, [rsi + 8]
    mulx    r14, r15, [rsi + 16]
    mulx    rdi, rax, [rsi + 24]
    adcx    r11, r9
    adox    r11, [rbx]
    mov     [rbx], r11
    adcx    r13, r10
    adox    r13, [rbx + 8]
    mov     [rbx + 8], r13
    adcx    r15, r12
    adox    r15, [rbx + 16]
    mov     [rbx + 16], r15
    adcx    rax, r14
    adox    rax, [rbx + 24]
    mov     [rbx + 24], rax
    mov     r9, rdi
    adcx    r9, r8
    adox    r9, r8
    lea     rsi, [rsi + 32]
    lea     rbx, [rbx + 32]
    sub     rcx, 4

sqb_am_tail_check:
    test    rcx, rcx
    jz      sqb_am_done

sqb_am_tail:
    mulx    r10, r11, [rsi]
    add     r11, r9
    adc     r10, 0
    add     QWORD PTR [rbx], r11
    adc     r10, 0
    mov     r9, r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 8]
    dec     rcx
    jnz     sqb_am_tail

sqb_am_done:
    mov     [rbx], r9               ; carry → r[n+i]
    jmp     sqb_next_row

sqb_zero_row:
    ; a[i] == 0: r[n+i] は Step 0 でゼロ済み

sqb_next_row:
    inc     rbp
    dec     QWORD PTR [rsp+32]
    jnz     sqb_outer

    ; ========== Step 2: lshift by 1 (inline ADC chain) ==========
sqb_step2:
    mov     rbx, [rsp+0]           ; rp
    mov     rcx, [rsp+16]          ; n (2 words/iter × n = 2n words)
    clc                             ; CF = 0

sqb_lsh:
    mov     rax, [rbx]
    adc     rax, rax
    mov     [rbx], rax
    mov     rax, [rbx + 8]
    adc     rax, rax
    mov     [rbx + 8], rax
    lea     rbx, [rbx + 16]        ; フラグ非破壊
    dec     rcx                     ; DEC は CF 非破壊
    jnz     sqb_lsh

    ; ========== Step 3: Diagonal — a[i]² を r[2i..2i+1] に加算 ==========
    mov     rbx, [rsp+0]           ; rp
    mov     rsi, [rsp+8]           ; ap
    mov     rcx, [rsp+16]          ; n

    ; n/2 groups + 奇数の場合 1 要素先行処理
    mov     r8, rcx
    shr     rcx, 1                  ; rcx = n / 2
    and     r8, 1                   ; r8 = n % 2

    xor     eax, eax                ; CF = 0 (ADCX チェーン開始)

    test    r8, r8                  ; CF = 0 (test は CF をクリア)
    jz      sqb_diag_even

    ; n が奇数: 先頭 1 要素
    mov     rdx, [rsi]
    mulx    r10, r11, rdx
    adcx    r11, [rbx]
    mov     [rbx], r11
    adcx    r10, [rbx + 8]
    mov     [rbx + 8], r10
    lea     rsi, [rsi + 8]
    lea     rbx, [rbx + 16]

sqb_diag_even:
    jrcxz   sqb_done                ; JRCXZ はフラグ非破壊

sqb_diag_2x:
    mov     rdx, [rsi]
    mulx    r10, r11, rdx           ; a[i]²
    adcx    r11, [rbx]
    mov     [rbx], r11
    adcx    r10, [rbx + 8]
    mov     [rbx + 8], r10

    mov     rdx, [rsi + 8]
    mulx    r10, r11, rdx           ; a[i+1]²
    adcx    r11, [rbx + 16]
    mov     [rbx + 16], r11
    adcx    r10, [rbx + 24]
    mov     [rbx + 24], r10

    lea     rsi, [rsi + 16]
    lea     rbx, [rbx + 32]
    dec     rcx                     ; DEC は CF 非破壊
    jnz     sqb_diag_2x

sqb_done:
    add     rsp, 40
    pop     rbp
    pop     r15
    pop     r14
    pop     r13
    pop     r12
    pop     rdi
    pop     rsi
    pop     rbx
    ret
mpn_sqr_basecase_mulx ENDP

END
