forked from vectorclass/version2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
instrset.h
1476 lines (1362 loc) · 63.9 KB
/
instrset.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/**************************** instrset.h **********************************
* Author: Agner Fog
* Date created: 2012-05-30
* Last modified: 2023-12-02
* Version: 2.02.02
* Project: vector class library
* Description:
* Header file for various compiler-specific tasks as well as common
* macros and templates. This file contains:
*
* > Selection of the supported instruction set
* > Defines compiler version macros
* > Undefines certain macros that prevent function overloading
* > Helper functions that depend on instruction set, compiler, or platform
* > Common templates for permute, blend, etc.
*
* For instructions, see vcl_manual.pdf
*
* (c) Copyright 2012-2023 Agner Fog.
* Apache License version 2.0 or later.
******************************************************************************/
// Include guard for this header.
// NOTE(review): the value 20200 does not match "Version: 2.02.02" in the banner above;
// confirm whether this macro value is meant to track the library version.
#ifndef INSTRSET_H
#define INSTRSET_H 20200
// check if compiled for C++17 (the code below relies on C++17 features such as if constexpr)
#if defined(_MSVC_LANG) // MS compiler has its own version of __cplusplus with different value
#if _MSVC_LANG < 201703
#error Please compile for C++17 or higher
#endif
#else // all other compilers
#if __cplusplus < 201703
#error Please compile for C++17 or higher
#endif
#endif
// Allow the use of floating point permute instructions on integer vectors.
// Some CPU's have an extra latency of 1 or 2 clock cycles for this, but
// it may still be faster than alternative implementations:
#define ALLOW_FP_PERMUTE true
// Macro to indicate 64 bit mode.
// Compilers that signal x86-64 with _M_AMD64, _M_X64 or __amd64 get __x86_64__ defined as well,
// so the rest of the library only needs to test __x86_64__.
#if (defined(_M_AMD64) || defined(_M_X64) || defined(__amd64) ) && ! defined(__x86_64__)
#define __x86_64__ 1 // There are many different macros for this, decide on only one
#endif
// The following values of INSTRSET are currently defined:
// 0: no SSE (or not detectable from compiler macros)
// 1: SSE
// 2: SSE2
// 3: SSE3
// 4: SSSE3
// 5: SSE4.1
// 6: SSE4.2
// 7: AVX
// 8: AVX2
// 9: AVX512F
// 10: AVX512BW/DQ/VL
// In the future, INSTRSET = 11 may include AVX512VBMI and AVX512VBMI2, but this
// decision cannot be made before the market situation for CPUs with these
// instruction sets is better known
// Find instruction set from compiler macros if INSTRSET is not defined.
// A user-supplied INSTRSET (e.g. -DINSTRSET=8) takes precedence over detection.
// Note: Some of these macros are not defined in Microsoft compilers
#ifndef INSTRSET
#if defined ( __AVX512VL__ ) && defined ( __AVX512BW__ ) && defined ( __AVX512DQ__ )
#define INSTRSET 10
#elif defined ( __AVX512F__ ) || defined ( __AVX512__ )
#define INSTRSET 9
#elif defined ( __AVX2__ )
#define INSTRSET 8
#elif defined ( __AVX__ )
#define INSTRSET 7
#elif defined ( __SSE4_2__ )
#define INSTRSET 6
#elif defined ( __SSE4_1__ )
#define INSTRSET 5
#elif defined ( __SSSE3__ )
#define INSTRSET 4
#elif defined ( __SSE3__ )
#define INSTRSET 3
#elif defined ( __SSE2__ ) || defined ( __x86_64__ ) // every x86-64 target has SSE2
#define INSTRSET 2
#elif defined ( __SSE__ )
#define INSTRSET 1
#elif defined ( _M_IX86_FP ) // Defined in MS compiler. 1: SSE, 2: SSE2
#define INSTRSET _M_IX86_FP
#else
#define INSTRSET 0
#endif // instruction set defines
#endif // INSTRSET
// AVX2 or higher without FMA enabled by the compiler:
#if INSTRSET >= 8 && !defined(__FMA__)
// Assume that all processors that have AVX2 also have FMA3
#if defined (__GNUC__) && ! defined (__INTEL_COMPILER)
// Prevent error message in g++ and Clang when using FMA intrinsics with avx2:
// only a recommendation is printed here; define DISABLE_WARNING_AVX2_WITHOUT_FMA to silence it.
#if !defined(DISABLE_WARNING_AVX2_WITHOUT_FMA)
#pragma message "It is recommended to specify also option -mfma when using -mavx2 or higher"
#endif
#elif ! defined (__clang__)
// Other compilers (not gcc-compatible, not Clang): assume FMA3 is available and say so
#define __FMA__ 1
#endif
#endif
// Header files for non-vector intrinsic functions including _BitScanReverse(int), __cpuid(int[4],int), _xgetbv(int)
#ifdef _MSC_VER // Microsoft compiler or compatible Intel compiler
#include <intrin.h>
#pragma warning(disable: 6323 4514 4710 4711) // Disable annoying warnings
#else
#include <x86intrin.h> // Gcc or Clang compiler
#endif
#include <stdint.h> // Define integer types with known size
#include <limits.h> // Define INT_MAX
#include <stdlib.h> // define abs(int)
// Run-time CPU feature detection. These functions are only declared here;
// their definitions live in separate source files that must be compiled and linked in.
// functions in instrset_detect.cpp:
#ifdef VCL_NAMESPACE
namespace VCL_NAMESPACE {
#endif
int instrset_detect(void); // tells which instruction sets are supported (presumably using the INSTRSET numbering above - confirm in instrset_detect.cpp)
bool hasFMA3(void); // true if FMA3 instructions supported
bool hasFMA4(void); // true if FMA4 instructions supported
bool hasXOP(void); // true if XOP instructions supported
bool hasAVX512ER(void); // true if AVX512ER instructions supported
bool hasAVX512VBMI(void); // true if AVX512VBMI instructions supported
bool hasAVX512VBMI2(void); // true if AVX512VBMI2 instructions supported
bool hasF16C(void); // true if F16C instructions supported
bool hasAVX512FP16(void); // true if AVX512_FP16 instructions supported
// function in physical_processors.cpp:
int physicalProcessors(int * logical_processors = 0);
#ifdef VCL_NAMESPACE
}
#endif
// GCC version, encoded as major*10000 + minor*100 + patchlevel (not defined for Clang)
#if defined(__GNUC__) && !defined (GCC_VERSION) && !defined (__clang__)
#define GCC_VERSION ((__GNUC__) * 10000 + (__GNUC_MINOR__) * 100 + (__GNUC_PATCHLEVEL__))
#endif
// Clang version, encoded as major*10000 + minor*100 + patchlevel
#if defined (__clang__)
#define CLANG_VERSION ((__clang_major__) * 10000 + (__clang_minor__) * 100 + (__clang_patchlevel__))
// Problem: The version number is not consistent across platforms
// http://llvm.org/bugs/show_bug.cgi?id=12643
// Apple bug 18746972
#endif
// Fix problem with non-overloadable macros named min and max in WinDef.h
#ifdef _MSC_VER
#if defined (_WINDEF_) && defined(min) && defined(max)
#undef min
#undef max
#endif
#ifndef NOMINMAX
#define NOMINMAX
#endif
// warning for poor support for AVX512F in MS compiler (genuine MS compiler only, not Intel or clang-cl)
#if !defined(__INTEL_COMPILER) && !defined(__clang__)
#if INSTRSET == 9
#pragma message("Warning: MS compiler cannot generate code for AVX512F without AVX512DQ")
#endif
#if _MSC_VER < 1920 && INSTRSET > 8
#pragma message("Warning: Your compiler has poor support for AVX512. Code may be erroneous.\nPlease use a newer compiler version or a different compiler!")
#endif
#endif // !__INTEL_COMPILER && !__clang__
#endif // _MSC_VER
// Classic Intel compilers with __INTEL_COMPILER below 2021 are rejected
#if defined(__INTEL_COMPILER) && __INTEL_COMPILER < 2021
#error The Intel compiler version 19.00 cannot compile VCL version 2
#endif
/* Clang problem:
The Clang compiler treats the intrinsic vector types __m128, __m128i, and __m128d as identical.
See the bug report at https://bugs.llvm.org/show_bug.cgi?id=17164
Additional problem: The version number is not consistent across platforms. The Apple build has
different version numbers. We have to rely on __apple_build_version__ on the Mac platform:
http://llvm.org/bugs/show_bug.cgi?id=12643
We have to make switches here when - hopefully - the error some day has been fixed.
We need different version checks with and without __apple_build_version__
*/
#if (defined (__clang__) || defined(__apple_build_version__)) && !defined(__INTEL_COMPILER)
#define FIX_CLANG_VECTOR_ALIAS_AMBIGUITY
#endif
#if defined (__GNUC__) && __GNUC__ < 10 && !defined(__clang__)
// Gcc 9 and earlier do not have _mm256_zextsi128_si256 and similar functions for zero-extending vector registers
#define ZEXT_MISSING
#endif
// Open the optional user-selected namespace for the helper definitions that follow
#ifdef VCL_NAMESPACE
namespace VCL_NAMESPACE {
#endif
// Constant for indicating don't care in permute and blend functions.
// V_DC is -256 in Vector class library version 1.xx
// V_DC can be any value less than -1 in Vector class library version 2.00
// (-1 itself means "set to zero" in permutation index lists)
constexpr int V_DC = -256;
/*****************************************************************************
*
* Helper functions that depend on instruction set, compiler, or platform
*
*****************************************************************************/
// Define interface to cpuid instruction.
// input: functionnumber = leaf (eax), ecxleaf = subleaf(ecx)
// output: output[0] = eax, output[1] = ebx, output[2] = ecx, output[3] = edx
// The implementation is chosen at compile time: gcc/Clang extended inline assembly,
// the MS __cpuidex intrinsic (from intrin.h), or MASM-style inline assembly otherwise.
static inline void cpuid(int output[4], int functionnumber, int ecxleaf = 0) {
#if defined(__GNUC__) || defined(__clang__) // use inline assembly, Gnu/AT&T syntax
int a, b, c, d;
// eax and ecx are the inputs; eax, ebx, ecx and edx receive the results
__asm("cpuid" : "=a"(a), "=b"(b), "=c"(c), "=d"(d) : "a"(functionnumber), "c"(ecxleaf) : );
output[0] = a;
output[1] = b;
output[2] = c;
output[3] = d;
#elif defined (_MSC_VER) // Microsoft compiler, intrin.h included
__cpuidex(output, functionnumber, ecxleaf); // intrinsic function for CPUID
#else // unknown platform. try inline assembly with masm/intel syntax
// NOTE(review): the block below addresses the output array through the 32-bit register esi;
// presumably it is only usable with 32-bit MASM-style compilers - confirm before relying on it.
__asm {
mov eax, functionnumber
mov ecx, ecxleaf
cpuid;
mov esi, output
mov[esi], eax
mov[esi + 4], ebx
mov[esi + 8], ecx
mov[esi + 12], edx
}
#endif
}
// Define popcount function: vml_popcnt returns the number of 1-bits in its argument.
#if INSTRSET >= 6   // SSE4.2
// The popcnt instruction is not formally part of SSE4.2, but every known
// processor that supports SSE4.2 also supports popcnt.
static inline uint32_t vml_popcnt(uint32_t a) {
    return static_cast<uint32_t>(_mm_popcnt_u32(a)); // Intel intrinsic, also provided by gcc and clang
}
#ifdef __x86_64__
static inline int64_t vml_popcnt(uint64_t a) {
    return _mm_popcnt_u64(a);   // one 64-bit popcnt instruction
}
#else   // 32 bit mode: count the two halves separately
static inline int64_t vml_popcnt(uint64_t a) {
    uint32_t const high = uint32_t(a >> 32);
    uint32_t const low  = uint32_t(a);
    return _mm_popcnt_u32(high) + _mm_popcnt_u32(low);
}
#endif
#else   // no SSE4.2: portable SWAR (SIMD-within-a-register) bit count
static inline uint32_t vml_popcnt(uint32_t a) {
    uint32_t x = a;
    x -= (x >> 1) & 0x55555555;                       // 2-bit partial sums
    x = (x & 0x33333333) + ((x >> 2) & 0x33333333);   // 4-bit partial sums
    x = (x + (x >> 4)) & 0x0F0F0F0F;                  // 8-bit partial sums
    return (x * 0x01010101) >> 24;                    // add the four byte sums into the top byte
}
static inline int32_t vml_popcnt(uint64_t a) {
    uint32_t const upper = vml_popcnt(uint32_t(a >> 32));
    uint32_t const lower = vml_popcnt(uint32_t(a));
    return static_cast<int32_t>(upper + lower);
}
#endif
// Define bit-scan-forward function: returns the index of the lowest set bit.
// The x86 bsf instruction leaves its result undefined for a zero input,
// so callers must pass a nonzero value.
#if (defined (__GNUC__) || defined(__clang__)) && !defined (_MSC_VER)
// The _BitScanForward intrinsics exist only under Windows when _MSC_VER is defined,
// so gcc and Clang use inline assembly instead.
#if defined(__clang_major__) && __clang_major__ < 10
// Work around a bug in Clang version 6 (not observed in version 8 and later):
// when inlined from horizontal_find_first, Clang 6 used a k register for parameter a.
__attribute__((noinline))
#endif
static uint32_t bit_scan_forward(uint32_t a) {
    uint32_t index;
    __asm("bsfl %[src], %[dst]" : [dst] "=r"(index) : [src] "r"(a));
    return index;
}
static inline uint32_t bit_scan_forward(uint64_t a) {
    // scan the low half first; fall back to the high half, offset by 32
    uint32_t const low = uint32_t(a);
    return low != 0 ? bit_scan_forward(low) : bit_scan_forward(uint32_t(a >> 32)) + 32;
}
#else   // MS compatible compilers under Windows
static inline uint32_t bit_scan_forward(uint32_t a) {
    unsigned long index;
    _BitScanForward(&index, a);     // declared in intrin.h for MS and Intel compilers
    return uint32_t(index);
}
#ifdef __x86_64__
static inline uint32_t bit_scan_forward(uint64_t a) {
    unsigned long index;
    _BitScanForward64(&index, a);   // declared in intrin.h for MS and Intel compilers
    return uint32_t(index);
}
#else   // 32 bit mode: no 64-bit scan intrinsic, split into halves
static inline uint32_t bit_scan_forward(uint64_t a) {
    uint32_t const low = uint32_t(a);
    return low != 0 ? bit_scan_forward(low) : bit_scan_forward(uint32_t(a >> 32)) + 32;
}
#endif
#endif
// Define bit-scan-reverse function: returns the index of the highest set bit,
// i.e. floor(log2(a)). The x86 bsr instruction leaves its result undefined
// for a zero input, so callers must pass a nonzero value.
#if (defined (__GNUC__) || defined(__clang__)) && !defined (_MSC_VER)
// The _BitScanReverse intrinsics exist only under Windows when _MSC_VER is defined,
// so gcc and Clang use inline assembly instead.
static inline uint32_t bit_scan_reverse(uint32_t a) __attribute__((pure));
static inline uint32_t bit_scan_reverse(uint32_t a) {
    uint32_t index;
    __asm("bsrl %[src], %[dst]" : [dst] "=r"(index) : [src] "r"(a));
    return index;
}
#ifdef __x86_64__
static inline uint32_t bit_scan_reverse(uint64_t a) {
    uint64_t index;
    __asm("bsrq %[src], %[dst]" : [dst] "=r"(index) : [src] "r"(a));
    return uint32_t(index);
}
#else   // 32 bit mode: scan the high half first
static inline uint32_t bit_scan_reverse(uint64_t a) {
    uint32_t const high = uint32_t(a >> 32);
    return high != 0 ? bit_scan_reverse(high) + 32 : bit_scan_reverse(uint32_t(a));
}
#endif
#else   // MS compatible compilers under Windows
static inline uint32_t bit_scan_reverse(uint32_t a) {
    unsigned long index;
    _BitScanReverse(&index, a);     // declared in intrin.h for MS compatible compilers
    return uint32_t(index);
}
#ifdef __x86_64__
static inline uint32_t bit_scan_reverse(uint64_t a) {
    unsigned long index;
    _BitScanReverse64(&index, a);   // declared in intrin.h for MS compatible compilers
    return uint32_t(index);
}
#else   // 32 bit mode: scan the high half first
static inline uint32_t bit_scan_reverse(uint64_t a) {
    uint32_t const high = uint32_t(a >> 32);
    return high != 0 ? bit_scan_reverse(high) + 32 : bit_scan_reverse(uint32_t(a));
}
#endif
#endif
// Same function, for compile-time constants: returns floor(log2(n)),
// or -1 when n is zero. Implemented as a binary search over shift widths.
constexpr int bit_scan_reverse_const(uint64_t const n) {
    if (n == 0) {
        return -1;
    }
    uint64_t rest = n;   // remaining high part of n still to be scanned
    int pos = 0;         // accumulated bit index
    int step = 32;       // current shift width: 32, 16, 8, 4, 2, 1
    while (step > 0) {
        if ((rest >> step) != 0) {   // a bit is set at or above position 'step'
            rest >>= step;
            pos += step;
        }
        step >>= 1;
    }
    return pos;
}
/*****************************************************************************
*
* Common templates
*
*****************************************************************************/
// NAMESPACEPREFIX expands to "VCL_NAMESPACE::" when VCL_NAMESPACE is defined and to
// nothing otherwise, so the const_int/const_uint macros below name the right templates
// wherever they are expanded.
#ifdef VCL_NAMESPACE
#define NAMESPACEPREFIX VCL_NAMESPACE::
#else
#define NAMESPACEPREFIX
#endif
// Empty tag types: the constant is carried in the type, so a compile-time value can be
// passed as a function argument and selected on by overloading/template deduction.
template <int32_t n> class Const_int_t {}; // represent compile-time signed integer constant
template <uint32_t n> class Const_uint_t {}; // represent compile-time unsigned integer constant
#define const_int(n) (NAMESPACEPREFIX Const_int_t <n>()) // n must be compile-time integer constant
#define const_uint(n) (NAMESPACEPREFIX Const_uint_t<n>()) // n must be compile-time unsigned integer constant
// template for producing quiet NAN
// VTYPE must provide a static constexpr member function elementtype()
// (17 = double, 16 = float) and a constructor from its scalar element type.
// payload: value stored in the NAN mantissa. For float only the low 22 bits are kept;
// for double the payload is shifted left by 29 bits so that it survives a later
// conversion to float.
// Any other element type is rejected at compile time. Without this check, such an
// instantiation would fall off the end of a non-void function (undefined behavior).
template <class VTYPE>
static inline VTYPE nan_vec(uint32_t payload = 0x100) {
    if constexpr (VTYPE::elementtype() == 17) {        // double
        union {
            uint64_t q;
            double f;
        } ud;
        // payload is left justified to avoid loss of NAN payload when converting to float
        ud.q = 0x7FF8000000000000 | uint64_t(payload) << 29;
        return VTYPE(ud.f);
    }
    else if constexpr (VTYPE::elementtype() == 16) {   // float
        union {
            uint32_t i;
            float f;
        } uf;
        uf.i = 0x7FC00000 | (payload & 0x003FFFFF);
        return VTYPE(uf.f);
    }
    else {
        /* The _Float16 version (elementtype 15) is defined in vectorfp16.h:
        union {
            uint16_t i;
            _Float16 f;  // error if _Float16 not defined
        } uf;
        uf.i = 0x7C00 | (payload & 0x03FF);
        return VTYPE(uf.f);
        */
        static_assert(VTYPE::elementtype() == 16 || VTYPE::elementtype() == 17,
            "nan_vec: VTYPE must be a vector of float or double");
    }
}
// Test if a parameter is a compile-time constant
/* Unfortunately, this works only for macro parameters, not for inline function parameters.
I hope that some solution will appear in the future, but for now it appears to be
impossible to check if a function parameter is a compile-time constant.
This would be useful in operator / and in function pow:
#if defined(__GNUC__) || defined (__clang__)
#define is_constant(a) __builtin_constant_p(a)
#else
#define is_constant(a) false
#endif
*/
/*****************************************************************************
*
* Helper functions for permute and blend functions
*
******************************************************************************
Rules for constexpr functions:
> All variable declarations must include initialization
> Do not put variable declarations inside a for-clause, e.g. avoid: for (int i=0; ..
Instead, you have to declare the loop counter before the for-loop.
> Do not make constexpr functions that return vector types. This requires type
punning with a union, which is not allowed in constexpr functions under C++17.
*****************************************************************************/
// Define type for encapsulated array to use as return type:
// a raw array cannot be returned from a function, but an aggregate wrapping one can,
// which lets the constexpr helpers below return a whole list of N values of type T.
template <typename T, int N>
struct EList {
T a[N]; // the encapsulated elements
};
// get_inttype: returns the value -1 as a signed integer whose size matches the
// element size of vector class V (sizeof(V) / V::size()): int8_t, int16_t,
// int32_t or int64_t. Callers use both its type (via decltype) and its value.
template <typename V>
constexpr auto get_inttype() {
    constexpr int bytes_per_element = int(sizeof(V) / V::size());
    if constexpr (bytes_per_element < 2) {
        return int8_t(-1);
    }
    else if constexpr (bytes_per_element < 4) {
        return int16_t(-1);
    }
    else if constexpr (bytes_per_element < 8) {
        return int32_t(-1);
    }
    else {
        return int64_t(-1);
    }
}
// zero_mask: returns a compact bit mask for zeroing with an AVX512 mask register.
// Bit i is 1 when permutation index a[i] is non-negative (element kept) and 0 when
// it is negative (-1 = zero, V_DC = don't care).
// Parameter a is a reference to a constexpr int array of permutation indexes.
// Result type: uint8_t for N <= 8, uint16_t for N <= 16, uint32_t for N <= 32, else uint64_t.
template <int N>
constexpr auto zero_mask(int const (&a)[N]) {
    uint64_t bits = 0;
    int k = N;
    while (k > 0) {                 // walk from the last index down to the first
        k--;
        bits = (bits << 1) | uint64_t(a[k] >= 0);
    }
    if constexpr (N <= 8) {
        return uint8_t(bits);
    }
    else if constexpr (N <= 16) {
        return uint16_t(bits);
    }
    else if constexpr (N <= 32) {
        return uint32_t(bits);
    }
    else {
        return bits;
    }
}
// zero_mask_broad: return a broad byte mask for zeroing.
// Parameter a is a reference to a constexpr int array of permutation indexes
template <typename V>
constexpr auto zero_mask_broad(int const (&A)[V::size()]) {
constexpr int N = V::size(); // number of vector elements
typedef decltype(get_inttype<V>()) Etype; // element type
EList <Etype, N> u = {{0}}; // list for return
int i = 0;
for (i = 0; i < N; i++) {
u.a[i] = A[i] >= 0 ? get_inttype<V>() : 0;
}
return u; // return encapsulated array
}
// make_bit_mask: returns a compact mask of bits from a list of N indexes.
// B contains options indicating how to gather the mask:
// bit 0-7 in B indicates which bit in each index to collect
// bit 8 = 0x100: set 1 in the lower half of the bit mask if the indicated bit is 1.
// bit 8 = 0    : set 1 in the lower half of the bit mask if the indicated bit is 0.
// bit 9 = 0x200: set 1 in the upper half of the bit mask if the indicated bit is 1.
// bit 9 = 0    : set 1 in the upper half of the bit mask if the indicated bit is 0.
// bit 10 = 0x400: set 1 in the bit mask if the corresponding index is -1 or V_DC
// Parameter a is a reference to a constexpr int array of permutation indexes.
template <int N, int B>
constexpr uint64_t make_bit_mask(int const (&a)[N]) {
    constexpr uint32_t bitpos    = uint32_t(uint8_t(B & 0xFF)); // which bit of each index to read
    constexpr uint64_t negvalue  = (B >> 10) & 1;               // result bit for negative indexes
    constexpr uint64_t lowflip   = ((B >> 8) & 1) ^ 1;          // 1 = invert bits in the lower half
    constexpr uint64_t highflip  = ((B >> 9) & 1) ^ 1;          // 1 = invert bits in the upper half
    uint64_t mask = 0;
    int k = 0;
    for (k = 0; k < N; k++) {
        uint64_t bit = negvalue;
        if (a[k] >= 0) {
            uint64_t const raw = ((uint32_t)a[k] >> bitpos) & 1;
            bit = raw ^ (k < N / 2 ? lowflip : highflip);
        }
        mask |= bit << k;
    }
    return mask;
}
// make_broad_mask: Convert a bit mask m to a broad mask
// The return value will be a broad boolean mask with elementsize matching vector class V
template <typename V>
constexpr auto make_broad_mask(uint64_t const m) {
constexpr int N = V::size(); // number of vector elements
typedef decltype(get_inttype<V>()) Etype; // element type
EList <Etype, N> u = {{0}}; // list for returning
int i = 0;
for (i = 0; i < N; i++) {
u.a[i] = ((m >> i) & 1) != 0 ? get_inttype<V>() : 0;
}
return u; // return encapsulated array
}
// perm_mask_broad: return a mask for permutation by a vector register index.
// Parameter A is a reference to a constexpr int array of permutation indexes
template <typename V>
constexpr auto perm_mask_broad(int const (&A)[V::size()]) {
constexpr int N = V::size(); // number of vector elements
typedef decltype(get_inttype<V>()) Etype; // vector element type
EList <Etype, N> u = {{0}}; // list for returning
int i = 0;
for (i = 0; i < N; i++) {
u.a[i] = Etype(A[i]);
}
return u; // return encapsulated array
}
// perm_flags: returns information about how a permute can be implemented.
// The return value is composed of these flag bits:
constexpr int perm_zeroing      = 1 << 0;   // needs zeroing
constexpr int perm_perm         = 1 << 1;   // permutation needed
constexpr int perm_allzero      = 1 << 2;   // all is zero or don't care
constexpr int perm_largeblock   = 1 << 3;   // fits permute with a larger block size (e.g permute Vec2q instead of Vec4i)
constexpr int perm_addz         = 1 << 4;   // additional zeroing needed after permute with larger block size or shift
constexpr int perm_addz2        = 1 << 5;   // additional zeroing needed after perm_zext, perm_compress, or perm_expand
constexpr int perm_cross_lane   = 1 << 6;   // permutation crossing 128-bit lanes
constexpr int perm_same_pattern = 1 << 7;   // same permute pattern in all 128-bit lanes
constexpr int perm_punpckh      = 1 << 8;   // permutation pattern fits punpckh instruction
constexpr int perm_punpckl      = 1 << 9;   // permutation pattern fits punpckl instruction
constexpr int perm_rotate       = 1 << 10;  // fits 128-bit rotation within lanes. 4 bit byte count returned in bits from perm_rot_count
constexpr int perm_swap         = 1 << 11;  // permutation pattern fits swap of adjacent vector elements
constexpr int perm_shright      = 1 << 12;  // fits shift right within lanes. 4 bit count returned in bits from perm_rot_count
constexpr int perm_shleft       = 1 << 13;  // fits shift left within lanes. negative count returned in bits from perm_rot_count
constexpr int perm_rotate_big   = 1 << 14;  // fits rotation across lanes. 6 bit count returned in bits from perm_rot_count
constexpr int perm_broadcast    = 1 << 15;  // permutation pattern fits broadcast of a single element.
constexpr int perm_zext         = 1 << 16;  // permutation pattern fits zero extension
constexpr int perm_compress     = 1 << 17;  // permutation pattern fits vpcompress instruction
constexpr int perm_expand       = 1 << 18;  // permutation pattern fits vpexpand instruction
constexpr int perm_outofrange   = 1 << 28;  // index out of range
// Bit positions of the multi-bit fields in the value returned by perm_flags:
constexpr int perm_rot_count    = 32;       // rotate or shift count is in bits perm_rot_count to perm_rot_count+3
constexpr int perm_ipattern     = 40;       // pshufd pattern in bits perm_ipattern to perm_ipattern+7 if perm_same_pattern and elementsize >= 4
// perm_flags: analyzes a constant permutation index list for vector class V.
// Parameter a: reference to a constexpr array of V::size() permutation indexes
//   (-1 = set to zero, V_DC = don't care, non-negative = source element index).
// Returns: a combination of the perm_* flag bits. Depending on the flags, a count or index
//   is also stored from bit perm_rot_count (shift/rotate byte count, big-rotate element count,
//   or broadcast index), and for same-pattern permutes with element size >= 4 a pshufd
//   pattern is stored from bit perm_ipattern.
template <typename V>
constexpr uint64_t perm_flags(int const (&a)[V::size()]) {
// a is a reference to a constexpr array of permutation indexes
// V is a vector class
constexpr int N = V::size(); // number of elements
uint64_t r = perm_largeblock | perm_same_pattern | perm_allzero; // return value. optimistic flags are cleared when disproved
uint32_t i = 0; // loop counter
int j = 0; // loop counter
int ix = 0; // index number i
const uint32_t nlanes = sizeof(V) / 16; // number of 128-bit lanes
const uint32_t lanesize = N / nlanes; // elements per lane
const uint32_t elementsize = sizeof(V) / N; // size of each vector element
uint32_t lane = 0; // current lane
uint32_t rot = 999; // rotate left count (999 = not yet determined)
int32_t broadc = 999; // index to broadcasted element (999 = none yet, 1000 = does not fit)
uint32_t patfail = 0; // remember certain patterns that do not fit (bit 0: zext, bit 1: compress, bit 2: expand)
uint32_t addz2 = 0; // remember certain patterns need extra zeroing (same bit assignment as patfail)
int32_t compresslasti = -1; // last index in perm_compress fit
int32_t compresslastp = -1; // last position in perm_compress fit
int32_t expandlasti = -1; // last index in perm_expand fit
int32_t expandlastp = -1; // last position in perm_expand fit
int lanepattern[lanesize] = {0}; // pattern in each lane
for (i = 0; i < N; i++) { // loop through indexes
ix = a[i]; // current index
// meaning of ix: -1 = set to zero, V_DC = don't care, non-negative value = permute.
if (ix == -1) {
r |= perm_zeroing; // zeroing requested
}
else if (ix != V_DC && uint32_t(ix) >= N) {
r |= perm_outofrange; // index out of range
}
if (ix >= 0) {
r &= ~ perm_allzero; // not all zero
if (ix != (int)i) r |= perm_perm; // needs permutation
if (broadc == 999) broadc = ix; // remember broadcast index
else if (broadc != ix) broadc = 1000; // does not fit broadcast
}
// check if pattern fits a larger block size:
// even indexes must be even, odd indexes must fit the preceding even index + 1
if ((i & 1) == 0) { // even index
if (ix >= 0 && (ix & 1)) r &= ~perm_largeblock;// not even. does not fit larger block size
int iy = a[i + 1]; // next odd index
if (iy >= 0 && (iy & 1) == 0) r &= ~ perm_largeblock; // not odd. does not fit larger block size
if (ix >= 0 && iy >= 0 && iy != ix+1) r &= ~ perm_largeblock; // does not fit preceding index + 1
if (ix == -1 && iy >= 0) r |= perm_addz; // needs additional zeroing at current block size
if (iy == -1 && ix >= 0) r |= perm_addz; // needs additional zeroing at current block size
}
lane = i / lanesize; // current lane
if (lane == 0) { // first lane: record its pattern. negative entries may be filled in from later lanes
lanepattern[i] = ix; // save pattern
}
// check if crossing lanes
if (ix >= 0) {
uint32_t lanei = (uint32_t)ix / lanesize; // source lane
if (lanei != lane) r |= perm_cross_lane; // crossing lane
}
// check if same pattern in all lanes
if (lane != 0 && ix >= 0) { // not first lane
int j1 = int(i - int(lane * lanesize)); // index into lanepattern
int jx = int(ix - int(lane * lanesize)); // pattern within lane
if (jx < 0 || jx >= (int)lanesize) r &= ~perm_same_pattern; // source is in another lane
if (lanepattern[j1] < 0) {
lanepattern[j1] = jx; // pattern not known from previous lane
}
else {
if (lanepattern[j1] != jx) r &= ~perm_same_pattern; // not same pattern
}
}
if (ix >= 0) {
// check if pattern fits zero extension (perm_zext): element k must land at position 2k
if (uint32_t(ix*2) != i) {
patfail |= 1; // does not fit zero extension
}
// check if pattern fits compress (perm_compress): indexes increasing, gaps in indexes >= gaps in positions
if (ix > compresslasti && ix - compresslasti >= (int)i - compresslastp) {
if ((int)i - compresslastp > 1) addz2 |= 2;// perm_compress may need additional zeroing
compresslasti = ix; compresslastp = int(i);
}
else {
patfail |= 2; // does not fit perm_compress
}
// check if pattern fits expand (perm_expand): indexes increasing, gaps in indexes <= gaps in positions
if (ix > expandlasti && ix - expandlasti <= (int)i - expandlastp) {
if (ix - expandlasti > 1) addz2 |= 4; // perm_expand may need additional zeroing
expandlasti = ix; expandlastp = int(i);
}
else {
patfail |= 4; // does not fit perm_expand
}
}
else if (ix == -1) {
if ((i & 1) == 0) addz2 |= 1; // zero extension needs additional zeroing
}
}
if (!(r & perm_perm)) return r; // more checks are superfluous
if (!(r & perm_largeblock)) r &= ~ perm_addz; // remove irrelevant flag
if (r & perm_cross_lane) r &= ~ perm_same_pattern; // remove irrelevant flag
// at most one of perm_zext, perm_compress, perm_expand is reported, in that order of preference
if ((patfail & 1) == 0) {
r |= perm_zext; // fits zero extension
if ((addz2 & 1) != 0) r |= perm_addz2;
}
else if ((patfail & 2) == 0) {
r |= perm_compress; // fits compression
if ((addz2 & 2) != 0) { // check if additional zeroing needed
for (j = 0; j < compresslastp; j++) {
if (a[j] == -1) r |= perm_addz2;
}
}
}
else if ((patfail & 4) == 0) {
r |= perm_expand; // fits expansion
if ((addz2 & 4) != 0) { // check if additional zeroing needed
for (j = 0; j < expandlastp; j++) {
if (a[j] == -1) r |= perm_addz2;
}
}
}
if (r & perm_same_pattern) {
// same pattern in all lanes. check if it fits specific patterns
bool fit = true; // fits perm_rotate
bool fitswap = true; // fits perm_swap
// fit shift or rotate: every defined lanepattern[i] must equal (i + rot) mod lanesize for one common rot
for (i = 0; i < lanesize; i++) {
if (lanepattern[i] >= 0) {
uint32_t rot1 = uint32_t(lanepattern[i] + lanesize - i) % lanesize;
if (rot == 999) {
rot = rot1;
}
else { // check if fit
if (rot != rot1) fit = false;
}
if ((uint32_t)lanepattern[i] != (i ^ 1)) fitswap = false;
}
}
rot &= lanesize-1; // prevent out of range values
if (fitswap) r |= perm_swap;
if (fit) { // fits rotate, and possibly shift
uint64_t rot2 = (rot * elementsize) & 0xF; // rotate right count in bytes
r |= rot2 << perm_rot_count; // put shift/rotate count in output bits perm_rot_count to perm_rot_count+3
#if INSTRSET >= 4 // SSSE3
r |= perm_rotate; // allow palignr
#endif
// fit shift left
fit = true;
for (i = 0; i < lanesize-rot; i++) { // check if the first (lanesize-rot) elements are zero or don't care
if (lanepattern[i] >= 0) fit = false;
}
if (fit) {
r |= perm_shleft;
for (; i < lanesize; i++) if (lanepattern[i] == -1) r |= perm_addz; // additional zeroing needed
}
// fit shift right
fit = true;
for (i = lanesize-(uint32_t)rot; i < lanesize; i++) { // check if the last rot elements are zero or don't care
if (lanepattern[i] >= 0) fit = false;
}
if (fit) {
r |= perm_shright;
for (i = 0; i < lanesize-rot; i++) {
if (lanepattern[i] == -1) r |= perm_addz; // additional zeroing needed
}
}
}
// fit punpckhi: expected pattern is lanesize/2, lanesize/2, lanesize/2+1, lanesize/2+1, ...
fit = true;
uint32_t j2 = lanesize / 2;
for (i = 0; i < lanesize; i++) {
if (lanepattern[i] >= 0 && lanepattern[i] != (int)j2) fit = false;
if ((i & 1) != 0) j2++;
}
if (fit) r |= perm_punpckh;
// fit punpcklo: expected pattern is 0, 0, 1, 1, 2, 2, ...
fit = true;
j2 = 0;
for (i = 0; i < lanesize; i++) {
if (lanepattern[i] >= 0 && lanepattern[i] != (int)j2) fit = false;
if ((i & 1) != 0) j2++;
}
if (fit) r |= perm_punpckl;
// fit pshufd: build the 8-bit dword shuffle pattern (negative entries give arbitrary values)
if constexpr (elementsize >= 4) {
uint32_t p = 0;
for (i = 0; i < lanesize; i++) {
if constexpr (lanesize == 4) {
p |= (lanepattern[i] & 3) << 2 * i;
}
else { // lanesize = 2: each 64-bit element maps to a dword pair, 4 = (0,1), 14 = (2,3)
p |= ((lanepattern[i] & 1) * 10 + 4) << 4 * i;
}
}
r |= (uint64_t)p << perm_ipattern;
}
}
#if INSTRSET >= 7
else { // not same pattern in all lanes
if constexpr (nlanes > 1) { // Try if it fits big rotate (whole-vector rotation, count in elements)
for (i = 0; i < N; i++) {
ix = a[i];
if (ix >= 0) {
uint32_t rot2 = (ix + N - i) % N; // rotate count
if (rot == 999) {
rot = rot2; // save rotate count
}
else if (rot != rot2) {
rot = 1000; break; // does not fit big rotate
}
}
}
if (rot < N) { // fits big rotate
r |= perm_rotate_big | (uint64_t)rot << perm_rot_count;
}
}
}
#endif
// broadcast is reported only when no rotate/shift pattern already uses the perm_rot_count field
if (broadc < 999 && (r & (perm_rotate|perm_shright|perm_shleft|perm_rotate_big)) == 0) {
r |= perm_broadcast | (uint64_t)broadc << perm_rot_count; // fits broadcast
}
return r;
}
// compress_mask: build the source-element bit mask for a compression instruction.
// Precondition: perm_flags() reports perm_compress for this pattern.
// Additional zeroing is still required if perm_flags() reports perm_addz2.
// Parameter a is a reference to a constexpr array of permutation indexes
// (a[dest] = source index, negative = zero / don't care).
// Bit k of the result is set when source element k must be fed to the compressor.
// That is each selected source index, plus "filler" sources taken right after the
// previous selected source, one for every destination slot skipped since then.
template <int N>
constexpr uint64_t compress_mask(int const (&a)[N]) {
    uint64_t mask = 0;
    int prev_src = -1;                      // source index of the last selected element
    int prev_dst = -1;                      // destination position of the last selected element
    for (int dst = 0; dst < N; ++dst) {
        const int src = a[dst];
        if (src < 0) continue;              // zero / don't care: nothing selected here
        mask |= uint64_t(1) << src;         // the real source element
        // destinations skipped since prev_dst are filled with dummy sources
        const int gap = dst - prev_dst - 1;
        for (int f = 1; f <= gap; ++f) {
            mask |= uint64_t(1) << (prev_src + f);
        }
        prev_src = src;
        prev_dst = dst;
    }
    return mask;
}
// expand_mask: build the destination-element bit mask for an expansion instruction.
// Precondition: perm_flags() reports perm_expand for this pattern.
// Additional zeroing is still required if perm_flags() reports perm_addz2.
// Parameter a is a reference to a constexpr array of permutation indexes
// (a[dest] = source index, negative = zero / don't care).
// Bit k of the result is set when destination position k receives an element.
// That is each destination with a non-negative index, plus "filler" destinations
// right after the previous one, one for every source element skipped since then.
template <int N>
constexpr uint64_t expand_mask(int const (&a)[N]) {
    uint64_t mask = 0;
    int prev_src = -1;                      // source index of the last selected element
    int prev_dst = -1;                      // destination position of the last selected element
    for (int dst = 0; dst < N; ++dst) {
        const int src = a[dst];
        if (src < 0) continue;              // zero / don't care: no destination bit
        mask |= uint64_t(1) << dst;         // the real destination slot
        // sources skipped since prev_src are absorbed by dummy destinations
        const int gap = src - prev_src - 1;
        for (int f = 1; f <= gap; ++f) {
            mask |= uint64_t(1) << (prev_dst + f);
        }
        prev_src = src;
        prev_dst = dst;
    }
    return mask;
}
// perm16_flags: describe how to permute a vector of 16-bit integers in terms of
// moves between the 64-bit halves of a 128-bit lane.
// Precondition: perm_flags() reports perm_same_pattern, so one 8-element lane
// pattern describes every lane.
// The return value is composed of these bits:
//   1: data from low 64 bits to low 64 bits.   pattern in bits 32-39
//   2: data from high 64 bits to high 64 bits. pattern in bits 40-47
//   4: data from high 64 bits to low 64 bits.  pattern in bits 48-55
//   8: data from low 64 bits to high 64 bits.  pattern in bits 56-63
// Each 8-bit pattern holds one 2-bit source index (within the source half) per
// destination element. Negative entries set no flag and leave their bits 0.
template <typename V>
constexpr uint64_t perm16_flags(int const (&a)[V::size()]) {
    // a is a reference to a constexpr array of permutation indexes
    // V is a vector class
    constexpr int N = V::size();                 // number of elements
    const uint32_t lanesize = 8;                 // 16-bit elements per 128-bit lane
    int lp[lanesize] = {0};                      // combined pattern for one lane

    // Lane 0 supplies the starting pattern, negative entries included.
    for (uint32_t k = 0; k < lanesize && k < uint32_t(N); k++) {
        lp[k] = a[k];
    }
    // Later lanes only fill positions that are still negative.
    for (uint32_t k = lanesize; k < uint32_t(N); k++) {
        const int ix = a[k];
        if (ix < 0) continue;
        const uint32_t lane = k / lanesize;
        const uint32_t pos = k - lane * lanesize;        // position inside the lane
        if (lp[pos] < 0) {
            lp[pos] = int(ix - lane * lanesize);         // index relative to the lane
        }
    }

    // pat[s] is the pattern that belongs to flag bit (1 << s):
    // s = 0 low2low, 1 high2high, 2 high2low, 3 low2high
    uint64_t flags = 0;
    uint32_t pat[4] = {0, 0, 0, 0};
    for (uint32_t k = 0; k < 4; k++) {
        const int lo = lp[k];                    // destination in the low half
        const int hi = lp[k + 4];                // destination in the high half
        if (lo >= 0) {
            const uint32_t slot = (lo < 4) ? 0u : 2u;    // low2low : high2low
            flags |= uint64_t(1) << slot;
            pat[slot] |= uint32_t(lo & 3) << (2 * k);
        }
        if (hi >= 0) {
            const uint32_t slot = (hi < 4) ? 3u : 1u;    // low2high : high2high
            flags |= uint64_t(1) << slot;
            pat[slot] |= uint32_t(hi & 3) << (2 * k);
        }
    }
    // place the four 8-bit patterns into bits 32-63
    for (uint32_t s = 0; s < 4; s++) {
        flags |= uint64_t(pat[s]) << (32 + 8 * s);
    }
    return flags;
}
// pshufb_mask: build the byte-level control mask for an in-lane permutation with
// the pshufb instruction (_mm..._shuffle_epi8).
// pshufb permutes and zeroes bytes quickly and allows a different pattern in each
// 128-bit lane, but it cannot move data across lane boundaries.
// Template parameters:
//   V      vector class
//   oppos  1 for data from the opposite 128-bit lane in 256-bit vectors (each
//          non-negative index has its lane bit flipped before use); 0 otherwise
// Parameter A is a reference to a constexpr array of permutation indexes.
// Returns an EList of sizeof(V) bytes. Each byte is the source byte offset within
// its own lane. It is -1 when the element index is negative or, after the lane
// adjustment, does not point into the lane being processed.
template <typename V, int oppos = 0>
constexpr auto pshufb_mask(int const (&A)[V::size()]) {
    constexpr uint32_t N = uint32_t(V::size());              // number of vector elements
    constexpr uint32_t elementsize = sizeof(V) / N;          // bytes per element
    constexpr uint32_t nlanes = sizeof(V) / 16;              // number of 128-bit lanes
    constexpr uint32_t elements_per_lane = N / nlanes;       // elements per lane
    EList <int8_t, sizeof(V)> u = {{0}};                     // result bytes
    uint32_t out = 0;                                        // next byte position in u
    for (uint32_t lane = 0; lane < nlanes; lane++) {
        const uint32_t lane_base = lane * elements_per_lane; // first element of this lane
        for (uint32_t e = 0; e < elements_per_lane; e++) {
            int src = A[lane_base + e];
            if (src >= 0) {
                src ^= oppos * elements_per_lane;            // pick the opposite lane if requested
            }
            src -= int(lane_base);                           // make index lane-relative
            const bool inside = src >= 0 && src < (int)elements_per_lane;
            const int8_t first_byte = inside ? int8_t(src * elementsize) : int8_t(-1);
            for (uint32_t b = 0; b < elementsize; b++) {     // one control byte per element byte
                u.a[out++] = first_byte < 0 ? int8_t(-1) : int8_t(first_byte + b);
            }
        }
    }
    return u;                                                // encapsulated array
}
// largeblock_perm: return indexes for replacing a permute or blend with
// a certain block size by a permute or blend with the double block size.
// Note: it is presupposed that perm_flags() indicates perm_largeblock
// It is required that additional zeroing is added if perm_flags() indicates perm_addz
template <int N>
constexpr EList<int, N/2> largeblock_perm(int const (&a)[N]) {
// Parameter a is a reference to a constexpr array of permutation indexes
EList<int, N/2> list = {{0}}; // result indexes