// Copyright (C) 2026 Kiyotsugu Arai
// SPDX-License-Identifier: LGPL-3.0-or-later

// simd_traits.hpp
// Foundation for SIMD packet evaluation.
//
// PacketTraits: per-type SIMD packet size plus load/store/arithmetic functions.
// PacketOp:     dispatches std::plus and friends to the matching SIMD operation.
//
// Supported back ends:
//   AVX2:     double x4, float x8
//   SSE2:     double x2, float x4
//   fallback: scalar (size = 1)

#ifndef CALX_SIMD_TRAITS_HPP
#define CALX_SIMD_TRAITS_HPP

#include <cstddef>
#include <functional>

#ifdef __AVX2__
#include <immintrin.h>
// The fused multiply-add intrinsics belong to the FMA extension, which is
// separate from AVX2: GCC/Clang only accept them when FMA is enabled
// (__FMA__ defined). MSVC does not define __FMA__, but its AVX2 mode allows
// FMA code generation, so it is treated as FMA-capable here.
#if defined(__FMA__) || (defined(_MSC_VER) && !defined(__clang__))
#define CALX_SIMD_TRAITS_USE_FMA 1
#endif
#elif defined(__SSE2__) || defined(_M_X64) || defined(_M_AMD64)
#include <emmintrin.h>
#endif

namespace calx {

// ================================================================
// PacketTraits: per-type SIMD packet definition
// ================================================================

/// Scalar fallback used for every type without a SIMD specialization.
/// A "packet" is a single value of type T and every operation is the plain
/// scalar operator.
template<typename T>
struct PacketTraits {
    using PacketType = T;
    static constexpr std::size_t size = 1;

    /// Reads one element from @p p.
    static PacketType load(const T* p) { return *p; }
    /// Writes the packet @p v to @p p.
    static void store(T* p, PacketType v) { *p = v; }

    static PacketType add(PacketType a, PacketType b) { return a + b; }
    static PacketType sub(PacketType a, PacketType b) { return a - b; }
    static PacketType mul(PacketType a, PacketType b) { return a * b; }
    static PacketType div(PacketType a, PacketType b) { return a / b; }
    static PacketType negate(PacketType a) { return -a; }
    /// Broadcast; trivial for a one-element packet.
    static PacketType set1(T v) { return v; }
    /// Computes a * b + c as two separate scalar operations.
    static PacketType fmadd(PacketType a, PacketType b, PacketType c) { return a * b + c; }
    /// Sum of all lanes; a one-element packet is its own sum.
    static T reduce_add(PacketType a) { return a; }
};

#ifdef __AVX2__

// ---- double specialization (AVX2: 4 lanes) ----
template<>
struct PacketTraits<double> {
    using PacketType = __m256d;
    static constexpr std::size_t size = 4;

    // Unaligned load/store: no alignment requirement is placed on p.
    static PacketType load(const double* p) { return _mm256_loadu_pd(p); }
    static void store(double* p, PacketType v) { _mm256_storeu_pd(p, v); }

    static PacketType add(PacketType a, PacketType b) { return _mm256_add_pd(a, b); }
    static PacketType sub(PacketType a, PacketType b) { return _mm256_sub_pd(a, b); }
    static PacketType mul(PacketType a, PacketType b) { return _mm256_mul_pd(a, b); }
    static PacketType div(PacketType a, PacketType b) { return _mm256_div_pd(a, b); }
    /// Negates every lane by XOR-ing with -0.0 (only the sign bit is set).
    static PacketType negate(PacketType a) { return _mm256_xor_pd(a, _mm256_set1_pd(-0.0)); }
    /// Broadcasts @p v to all lanes.
    static PacketType set1(double v) { return _mm256_set1_pd(v); }

    /// a * b + c; fused when FMA is available, otherwise multiply then add.
    static PacketType fmadd(PacketType a, PacketType b, PacketType c) {
#ifdef CALX_SIMD_TRAITS_USE_FMA
        return _mm256_fmadd_pd(a, b, c);
#else
        return _mm256_add_pd(_mm256_mul_pd(a, b), c);
#endif
    }

    /// Horizontal add: returns the sum of all four lanes.
    static double reduce_add(PacketType a) {
        __m128d lo = _mm256_castpd256_pd128(a);    // [a0, a1]
        __m128d hi = _mm256_extractf128_pd(a, 1);  // [a2, a3]
        lo = _mm_add_pd(lo, hi);                   // [a0+a2, a1+a3]
        hi = _mm_unpackhi_pd(lo, lo);              // [a1+a3, a1+a3]
        lo = _mm_add_sd(lo, hi);                   // [a0+a1+a2+a3, ...]
        return _mm_cvtsd_f64(lo);
    }
};

// ---- float specialization (AVX2: 8 lanes) ----
template<>
struct PacketTraits<float> {
    using PacketType = __m256;
    static constexpr std::size_t size = 8;

    // Unaligned load/store: no alignment requirement is placed on p.
    static PacketType load(const float* p) { return _mm256_loadu_ps(p); }
    static void store(float* p, PacketType v) { _mm256_storeu_ps(p, v); }

    static PacketType add(PacketType a, PacketType b) { return _mm256_add_ps(a, b); }
    static PacketType sub(PacketType a, PacketType b) { return _mm256_sub_ps(a, b); }
    static PacketType mul(PacketType a, PacketType b) { return _mm256_mul_ps(a, b); }
    static PacketType div(PacketType a, PacketType b) { return _mm256_div_ps(a, b); }
    /// Negates every lane by XOR-ing with -0.0f (only the sign bit is set).
    static PacketType negate(PacketType a) { return _mm256_xor_ps(a, _mm256_set1_ps(-0.0f)); }
    /// Broadcasts @p v to all lanes.
    static PacketType set1(float v) { return _mm256_set1_ps(v); }

    /// a * b + c; fused when FMA is available, otherwise multiply then add.
    static PacketType fmadd(PacketType a, PacketType b, PacketType c) {
#ifdef CALX_SIMD_TRAITS_USE_FMA
        return _mm256_fmadd_ps(a, b, c);
#else
        return _mm256_add_ps(_mm256_mul_ps(a, b), c);
#endif
    }

    /// Horizontal add: returns the sum of all eight lanes.
    static float reduce_add(PacketType a) {
        __m128 lo = _mm256_castps256_ps128(a);     // [a0, a1, a2, a3]
        __m128 hi = _mm256_extractf128_ps(a, 1);   // [a4, a5, a6, a7]
        lo = _mm_add_ps(lo, hi);                   // [s0, s1, s2, s3], s_i = a_i + a_(i+4)
        hi = _mm_movehl_ps(lo, lo);                // [s2, s3, s2, s3]
        lo = _mm_add_ps(lo, hi);                   // [s0+s2, s1+s3, ...]
        hi = _mm_shuffle_ps(lo, lo, 1);            // lane 0 = s1+s3
        lo = _mm_add_ss(lo, hi);                   // lane 0 = s0+s1+s2+s3
        return _mm_cvtss_f32(lo);
    }
};

#elif defined(__SSE2__) || defined(_M_X64) || defined(_M_AMD64)

// ---- double specialization (SSE2: 2 lanes) ----
template<>
struct PacketTraits<double> {
    using PacketType = __m128d;
    static constexpr std::size_t size = 2;

    // Unaligned load/store: no alignment requirement is placed on p.
    static PacketType load(const double* p) { return _mm_loadu_pd(p); }
    static void store(double* p, PacketType v) { _mm_storeu_pd(p, v); }

    static PacketType add(PacketType a, PacketType b) { return _mm_add_pd(a, b); }
    static PacketType sub(PacketType a, PacketType b) { return _mm_sub_pd(a, b); }
    static PacketType mul(PacketType a, PacketType b) { return _mm_mul_pd(a, b); }
    static PacketType div(PacketType a, PacketType b) { return _mm_div_pd(a, b); }
    /// Negates every lane by XOR-ing with -0.0 (only the sign bit is set).
    static PacketType negate(PacketType a) { return _mm_xor_pd(a, _mm_set1_pd(-0.0)); }
    /// Broadcasts @p v to all lanes.
    static PacketType set1(double v) { return _mm_set1_pd(v); }

    /// a * b + c as multiply followed by add (SSE2 has no FMA instruction).
    static PacketType fmadd(PacketType a, PacketType b, PacketType c) {
        return _mm_add_pd(_mm_mul_pd(a, b), c);
    }

    /// Horizontal add: returns the sum of both lanes.
    static double reduce_add(PacketType a) {
        __m128d hi = _mm_unpackhi_pd(a, a);        // [a1, a1]
        return _mm_cvtsd_f64(_mm_add_sd(a, hi));   // a0 + a1
    }
};

// ---- float specialization (SSE2: 4 lanes) ----
template<>
struct PacketTraits<float> {
    using PacketType = __m128;
    static constexpr std::size_t size = 4;

    // Unaligned load/store: no alignment requirement is placed on p.
    static PacketType load(const float* p) { return _mm_loadu_ps(p); }
    static void store(float* p, PacketType v) { _mm_storeu_ps(p, v); }

    static PacketType add(PacketType a, PacketType b) { return _mm_add_ps(a, b); }
    static PacketType sub(PacketType a, PacketType b) { return _mm_sub_ps(a, b); }
    static PacketType mul(PacketType a, PacketType b) { return _mm_mul_ps(a, b); }
    static PacketType div(PacketType a, PacketType b) { return _mm_div_ps(a, b); }
    /// Negates every lane by XOR-ing with -0.0f (only the sign bit is set).
    static PacketType negate(PacketType a) { return _mm_xor_ps(a, _mm_set1_ps(-0.0f)); }
    /// Broadcasts @p v to all lanes.
    static PacketType set1(float v) { return _mm_set1_ps(v); }

    /// a * b + c as multiply followed by add (SSE2 has no FMA instruction).
    static PacketType fmadd(PacketType a, PacketType b, PacketType c) {
        return _mm_add_ps(_mm_mul_ps(a, b), c);
    }

    /// Horizontal add: returns the sum of all four lanes.
    static float reduce_add(PacketType a) {
        __m128 hi = _mm_movehl_ps(a, a);           // [a2, a3, a2, a3]
        a = _mm_add_ps(a, hi);                     // [a0+a2, a1+a3, ...]
        hi = _mm_shuffle_ps(a, a, 1);              // lane 0 = a1+a3
        return _mm_cvtss_f32(_mm_add_ss(a, hi));   // a0+a1+a2+a3
    }
};

#endif // __AVX2__ / __SSE2__

// ================================================================
// PacketOp: maps std::plus and friends onto SIMD operations
// ================================================================

/// Generic case: applies the functor directly to the two packets.
/// NOTE(review): for SIMD packet types this only compiles if Op can be
/// invoked on the raw packet type; the four arithmetic functors are handled
/// by the specializations below.
template<typename Op, typename T>
struct PacketOp {
    using PT = PacketTraits<T>;
    static auto apply(typename PT::PacketType a, typename PT::PacketType b) {
        return Op{}(a, b);
    }
};

/// std::plus<T> -> PacketTraits<T>::add
template<typename T>
struct PacketOp<std::plus<T>, T> {
    using PT = PacketTraits<T>;
    static auto apply(typename PT::PacketType a, typename PT::PacketType b) {
        return PT::add(a, b);
    }
};

/// std::minus<T> -> PacketTraits<T>::sub
template<typename T>
struct PacketOp<std::minus<T>, T> {
    using PT = PacketTraits<T>;
    static auto apply(typename PT::PacketType a, typename PT::PacketType b) {
        return PT::sub(a, b);
    }
};

/// std::multiplies<T> -> PacketTraits<T>::mul
template<typename T>
struct PacketOp<std::multiplies<T>, T> {
    using PT = PacketTraits<T>;
    static auto apply(typename PT::PacketType a, typename PT::PacketType b) {
        return PT::mul(a, b);
    }
};

/// std::divides<T> -> PacketTraits<T>::div
template<typename T>
struct PacketOp<std::divides<T>, T> {
    using PT = PacketTraits<T>;
    static auto apply(typename PT::PacketType a, typename PT::PacketType b) {
        return PT::div(a, b);
    }
};

// ================================================================
// has_simd_packet_v: whether SIMD packet evaluation is enabled for T
// ================================================================

/// True when PacketTraits<T> holds more than one element per packet.
template<typename T>
inline constexpr bool has_simd_packet_v = (PacketTraits<T>::size > 1);

} // namespace calx

// Internal helper macro; do not leak it to includers.
#undef CALX_SIMD_TRAITS_USE_FMA

#endif // CALX_SIMD_TRAITS_HPP