; Copyright (C) 2026 Kiyotsugu Arai
; SPDX-License-Identifier: LGPL-3.0-or-later
;
; mpn_x64_mul.asm — mul_basecase 小サイズ特化 (n×n, n=1..4)
;
; 関数:
;   mpn_mul_small_asm(rp, ap, bp, n) — n×n 完全展開乗算
;
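; Goal: eliminate the push/pop and loop-control overhead of the generic mul_basecase
;   n=2: 4 MULX, no push/pop (the generic version does 8 push + 8 pop)
;   n=3: 9 MULX, only 2 callee-saved registers pushed/popped
;   n=4: 16 MULX, fewer callee-saved pushes/pops than the generic version's 8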
;
; BMI2 必須 (MULX)
;
; Windows x64 calling convention:
;   rcx = rp, rdx = ap, r8 = bp, r9 = n (1..4)
;   戻り値: なし
;   非破壊: rbx, rbp, rdi, rsi, r12-r15
;   破壊可: rax, rcx, rdx, r8, r9, r10, r11

.code

; =====================================================================
; void mpn_mul_small_asm(uint64_t* rp, const uint64_t* ap,
;                         const uint64_t* bp, size_t n)
;
; rp[0..2n-1] = ap[0..n-1] * bp[0..n-1]
;
; Preconditions:
;   * 1 <= n <= 4. n is NOT range-checked: any other value jumps through
;     an out-of-bounds jump-table entry.
;   * BMI2-capable CPU (MULX).
;   * rp must not overlap ap or bp: low result limbs are stored before
;     the remaining input limbs are read.
;
; Dispatch: a jump table indexed by n selects a fully unrolled kernel.
;
; Carry handling (n = 3, 4): every row j >= 1 is accumulated in
; addmul_1 form. For each limb i:
;     (hi:lo) = a[i]*b[j];  lo += carry;  acc[i] += lo;  carry = hi + CFs
; This cannot overflow, because (B-1)^2 + (B-1) + (B-1) = B^2 - 1
; with B = 2^64. So no carry is ever discarded.
;
; Do NOT use an "adc reg, 0" carry propagation into a top limb whose
; register has not been loaded yet. That silently drops a carry; for
; example, all-ones 3x3 and 4x4 inputs then produce a wrong high limb.
;
; Unwind info: n = 1 and n = 2 touch only volatile registers and stay
; frameless leaf code. n = 3 and n = 4 must save callee-saved
; registers, so each lives in its own PROC FRAME with unwind info.
; The dispatcher reaches them by tail jump, so they see the caller's
; return address at [rsp], exactly as if they had been called directly.
; =====================================================================
mpn_mul_small_asm PROC
    lea     r10, [mul_s_jt]             ; RIP-relative table base
    jmp     qword ptr [r10 + r9*8 - 8]  ; n = 1..4 -> entry n-1 (tail jump)

mul_s_jt:
    dq      mul_s_1x1
    dq      mul_s_2x2
    dq      mul_s_3x3                   ; separate PROC FRAME below
    dq      mul_s_4x4                   ; separate PROC FRAME below

; --- 1x1: 1 MULX ---
    ALIGN   16
mul_s_1x1:
    mov     r9, rdx             ; r9 = ap (rdx is the implicit MULX source)
    mov     rdx, [r8]           ; rdx = b[0]
    mulx    r10, rax, [r9]      ; r10:rax = a[0]*b[0]
    mov     [rcx], rax
    mov     [rcx+8], r10
    ret

; --- 2x2: 4 MULX, no push/pop ---
;
; r[0..3] = a[0..1] * b[0..1]
;
; Registers:
;   r9 = ap, rdx = MULX multiplier, rcx = rp
;   r10, r11 = accumulators, rax, r8 = temporaries
    ALIGN   16
mul_s_2x2:
    mov     r9, rdx             ; r9 = ap

    ; Row 0: a * b[0]
    mov     rdx, [r8]           ; rdx = b[0]
    mulx    r11, r10, [r9]      ; r11:r10 = a[0]*b[0]
    mov     [rcx], r10          ; r[0]
    mulx    rax, r10, [r9+8]    ; rax:r10 = a[1]*b[0]
    add     r10, r11            ; partial r[1]
    adc     rax, 0              ; partial r[2] (hi <= B-2, cannot overflow)

    ; Row 1: a * b[1]
    mov     rdx, [r8+8]         ; rdx = b[1]; bp is not needed any more
    mulx    r11, r8, [r9]       ; r11:r8 = a[0]*b[1]
    add     r10, r8             ; r[1]
    mov     [rcx+8], r10
    mulx    r10, r8, [r9+8]     ; r10:r8 = a[1]*b[1] (MULX leaves CF intact)
    adc     rax, r11            ; partial r[2] += hi(a0*b1) + CF
    adc     r10, 0              ; carry lands in the already-loaded top limb
    add     rax, r8             ; r[2] += lo(a1*b1)
    mov     [rcx+16], rax
    adc     r10, 0              ; r[3]
    mov     [rcx+24], r10
    ret
mpn_mul_small_asm ENDP

; --- 3x3: 9 MULX, saves rbx/rdi ---
;
; r[0..5] = a[0..2] * b[0..2]   (entered by tail jump from mpn_mul_small_asm)
;
; Registers:
;   rcx = rp, r9 = ap, r8 = bp (never overwritten), rdx = b[j] (MULX source)
;   rax, r10, r11, rbx, rdi = product halves, running carry, accumulators
    ALIGN   16
mul_s_3x3 PROC PRIVATE FRAME
    push    rbx
    .pushreg rbx
    push    rdi
    .pushreg rdi
    .endprolog
    mov     r9, rdx                 ; r9 = ap

    ; Row 0: limbs 0..3 = a * b[0]
    mov     rdx, [r8]               ; rdx = b[0]
    mulx    r10, rax, [r9]          ; r10:rax = a[0]*b[0]
    mov     [rcx], rax              ; r[0] (final)
    mulx    r11, rax, [r9+8]        ; r11:rax = a[1]*b[0]
    add     r10, rax                ; limb 1
    mulx    rbx, rax, [r9+16]       ; rbx:rax = a[2]*b[0]
    adc     r11, rax                ; limb 2
    adc     rbx, 0                  ; limb 3 (hi <= B-2, cannot overflow)
    ; acc: r10 = limb1, r11 = limb2, rbx = limb3

    ; Row 1: limbs 1..4 += a * b[1]
    mov     rdx, [r8+8]             ; rdx = b[1]
    mulx    rdi, rax, [r9]          ; rdi:rax = a[0]*b[1]
    add     r10, rax
    adc     rdi, 0                  ; carry into limb 2
    mov     [rcx+8], r10            ; r[1] (final)
    mulx    r10, rax, [r9+8]        ; r10:rax = a[1]*b[1]
    add     rax, rdi
    adc     r10, 0
    add     r11, rax                ; limb 2
    adc     r10, 0                  ; carry into limb 3
    mulx    rdi, rax, [r9+16]       ; rdi:rax = a[2]*b[1]
    add     rax, r10
    adc     rdi, 0
    add     rbx, rax                ; limb 3
    adc     rdi, 0                  ; limb 4
    ; acc: r11 = limb2, rbx = limb3, rdi = limb4

    ; Row 2: limbs 2..5 += a * b[2]
    mov     rdx, [r8+16]            ; rdx = b[2]
    mulx    r10, rax, [r9]          ; r10:rax = a[0]*b[2]
    add     r11, rax
    adc     r10, 0                  ; carry into limb 3
    mov     [rcx+16], r11           ; r[2]
    mulx    r11, rax, [r9+8]        ; r11:rax = a[1]*b[2]
    add     rax, r10
    adc     r11, 0
    add     rbx, rax
    adc     r11, 0                  ; carry into limb 4
    mov     [rcx+24], rbx           ; r[3]
    mulx    r10, rax, [r9+16]       ; r10:rax = a[2]*b[2]
    add     rax, r11
    adc     r10, 0
    add     rdi, rax
    adc     r10, 0                  ; limb 5
    mov     [rcx+32], rdi           ; r[4]
    mov     [rcx+40], r10           ; r[5]

    pop     rdi
    pop     rbx
    ret
mul_s_3x3 ENDP

; --- 4x4: 16 MULX, saves rbx/rdi/rsi ---
;
; r[0..7] = a[0..3] * b[0..3]   (entered by tail jump from mpn_mul_small_asm)
;
; Registers:
;   rcx = rp, r9 = ap, r8 = bp (never overwritten), rdx = b[j] (MULX source)
;   rax, rsi = product low half / running carry
;   r10, r11, rbx, rdi = rotating 4-limb accumulator window
    ALIGN   16
mul_s_4x4 PROC PRIVATE FRAME
    push    rbx
    .pushreg rbx
    push    rdi
    .pushreg rdi
    push    rsi
    .pushreg rsi
    .endprolog
    mov     r9, rdx                 ; r9 = ap

    ; Row 0: limbs 0..4 = a * b[0]
    mov     rdx, [r8]               ; rdx = b[0]
    mulx    r10, rax, [r9]          ; a[0]*b[0]
    mov     [rcx], rax              ; r[0] (final)
    mulx    r11, rax, [r9+8]        ; a[1]*b[0]
    add     r10, rax                ; limb 1
    mulx    rbx, rax, [r9+16]       ; a[2]*b[0]
    adc     r11, rax                ; limb 2
    mulx    rdi, rax, [r9+24]       ; a[3]*b[0]
    adc     rbx, rax                ; limb 3
    adc     rdi, 0                  ; limb 4 (hi <= B-2, cannot overflow)
    ; acc: r10 = limb1, r11 = limb2, rbx = limb3, rdi = limb4

    ; Row 1: limbs 1..5 += a * b[1]
    mov     rdx, [r8+8]             ; rdx = b[1]
    mulx    rsi, rax, [r9]          ; a[0]*b[1]
    add     r10, rax
    adc     rsi, 0                  ; carry into limb 2
    mov     [rcx+8], r10            ; r[1] (final)
    mulx    r10, rax, [r9+8]        ; a[1]*b[1]
    add     rax, rsi
    adc     r10, 0
    add     r11, rax                ; limb 2
    adc     r10, 0                  ; carry into limb 3
    mulx    rsi, rax, [r9+16]       ; a[2]*b[1]
    add     rax, r10
    adc     rsi, 0
    add     rbx, rax                ; limb 3
    adc     rsi, 0                  ; carry into limb 4
    mulx    r10, rax, [r9+24]       ; a[3]*b[1]
    add     rax, rsi
    adc     r10, 0
    add     rdi, rax                ; limb 4
    adc     r10, 0                  ; limb 5
    ; acc: r11 = limb2, rbx = limb3, rdi = limb4, r10 = limb5

    ; Row 2: limbs 2..6 += a * b[2]
    mov     rdx, [r8+16]            ; rdx = b[2]
    mulx    rsi, rax, [r9]          ; a[0]*b[2]
    add     r11, rax
    adc     rsi, 0                  ; carry into limb 3
    mov     [rcx+16], r11           ; r[2] (final)
    mulx    r11, rax, [r9+8]        ; a[1]*b[2]
    add     rax, rsi
    adc     r11, 0
    add     rbx, rax                ; limb 3
    adc     r11, 0                  ; carry into limb 4
    mulx    rsi, rax, [r9+16]       ; a[2]*b[2]
    add     rax, r11
    adc     rsi, 0
    add     rdi, rax                ; limb 4
    adc     rsi, 0                  ; carry into limb 5
    mulx    r11, rax, [r9+24]       ; a[3]*b[2]
    add     rax, rsi
    adc     r11, 0
    add     r10, rax                ; limb 5
    adc     r11, 0                  ; limb 6
    ; acc: rbx = limb3, rdi = limb4, r10 = limb5, r11 = limb6

    ; Row 3: limbs 3..7 += a * b[3]
    mov     rdx, [r8+24]            ; rdx = b[3]
    mulx    rsi, rax, [r9]          ; a[0]*b[3]
    add     rbx, rax
    adc     rsi, 0                  ; carry into limb 4
    mov     [rcx+24], rbx           ; r[3]
    mulx    rbx, rax, [r9+8]        ; a[1]*b[3]
    add     rax, rsi
    adc     rbx, 0
    add     rdi, rax
    adc     rbx, 0                  ; carry into limb 5
    mov     [rcx+32], rdi           ; r[4]
    mulx    rsi, rax, [r9+16]       ; a[2]*b[3]
    add     rax, rbx
    adc     rsi, 0
    add     r10, rax
    adc     rsi, 0                  ; carry into limb 6
    mov     [rcx+40], r10           ; r[5]
    mulx    rbx, rax, [r9+24]       ; a[3]*b[3]
    add     rax, rsi
    adc     rbx, 0
    add     r11, rax
    adc     rbx, 0                  ; limb 7
    mov     [rcx+48], r11           ; r[6]
    mov     [rcx+56], rbx           ; r[7]

    pop     rsi
    pop     rdi
    pop     rbx
    ret
mul_s_4x4 ENDP

; =====================================================================
; uint64_t mpn_addmul_1_small_asm(uint64_t* rp, const uint64_t* ap,
;                                   size_t n, uint64_t b)
;
; rp[0..n-1] += ap[0..n-1] * b
; Returns: the carry-out limb in rax.
; Preconditions: 1 <= n <= 4 (not range-checked: any other n jumps
;   through an out-of-bounds table entry), BMI2-capable CPU.
;
; Avoids the 7 push/pop and the loop control of the generic
; mpn_addmul_1_mulx. Only volatile registers are used and rsp is never
; touched, so this stays a frameless leaf function.
;
; Windows x64: rcx = rp, rdx = ap, r8 = n, r9 = b
;
; Per limb i:  (hi:lo) = a[i]*b;  lo += rp[i];  lo += carry;
;              hi absorbs both CFs and becomes the next carry.
; No overflow: (B-1)^2 + (B-1) + (B-1) = B^2 - 1, B = 2^64.
; Memory order per limb: read a[i], read rp[i], write rp[i].
; =====================================================================
mpn_addmul_1_small_asm PROC
    lea     rax, [am1_s_jt]             ; rax is free until the result is formed
    jmp     qword ptr [rax + r8*8 - 8]  ; n = 1..4 -> entry n-1

am1_s_jt:
    dq      am1_s_1
    dq      am1_s_2
    dq      am1_s_3
    dq      am1_s_4

    ALIGN   16
am1_s_1:
    mov     r11, rdx                ; r11 = ap (rdx is the implicit MULX source)
    mov     rdx, r9                 ; rdx = b
    mulx    rax, r10, [r11]         ; rax:r10 = a[0]*b
    add     r10, [rcx]              ; lo += rp[0]
    adc     rax, 0                  ; carry-out
    mov     [rcx], r10
    ret

    ALIGN   16
am1_s_2:
    mov     r11, rdx                ; r11 = ap
    mov     rdx, r9                 ; rdx = b
    mulx    r8, r10, [r11]          ; r8:r10 = a[0]*b
    add     r10, [rcx]
    adc     r8, 0                   ; carry into limb 1
    mov     [rcx], r10

    mulx    rax, r10, [r11+8]       ; rax:r10 = a[1]*b
    add     r10, [rcx+8]
    adc     rax, 0
    add     r10, r8
    adc     rax, 0                  ; carry-out
    mov     [rcx+8], r10
    ret

    ALIGN   16
am1_s_3:
    mov     r11, rdx                ; r11 = ap
    mov     rdx, r9                 ; rdx = b (r9 is free from here on)
    mulx    r8, r10, [r11]          ; r8:r10 = a[0]*b
    add     r10, [rcx]
    adc     r8, 0                   ; carry into limb 1
    mov     [rcx], r10

    mulx    r9, r10, [r11+8]        ; r9:r10 = a[1]*b
    add     r10, [rcx+8]
    adc     r9, 0
    add     r10, r8
    adc     r9, 0                   ; carry into limb 2
    mov     [rcx+8], r10

    mulx    rax, r10, [r11+16]      ; rax:r10 = a[2]*b
    add     r10, [rcx+16]
    adc     rax, 0
    add     r10, r9
    adc     rax, 0                  ; carry-out
    mov     [rcx+16], r10
    ret

    ALIGN   16
am1_s_4:
    mov     r11, rdx                ; r11 = ap
    mov     rdx, r9                 ; rdx = b (r9 is free from here on)
    mulx    r8, r10, [r11]          ; r8:r10 = a[0]*b
    add     r10, [rcx]
    adc     r8, 0                   ; carry into limb 1
    mov     [rcx], r10

    mulx    r9, r10, [r11+8]        ; r9:r10 = a[1]*b
    add     r10, [rcx+8]
    adc     r9, 0
    add     r10, r8
    adc     r9, 0                   ; carry into limb 2
    mov     [rcx+8], r10

    mulx    r8, r10, [r11+16]       ; r8:r10 = a[2]*b
    add     r10, [rcx+16]
    adc     r8, 0
    add     r10, r9
    adc     r8, 0                   ; carry into limb 3
    mov     [rcx+16], r10

    mulx    rax, r10, [r11+24]      ; rax:r10 = a[3]*b
    add     r10, [rcx+24]
    adc     rax, 0
    add     r10, r8
    adc     rax, 0                  ; carry-out
    mov     [rcx+24], r10
    ret
mpn_addmul_1_small_asm ENDP

END
