diff --git a/3rdparty/llama.cpp b/3rdparty/llama.cpp index 5eb47b721..1f86f058d 160000 --- a/3rdparty/llama.cpp +++ b/3rdparty/llama.cpp @@ -1 +1 @@ -Subproject commit 5eb47b72106e3b35f10e8befa616a9241242b226 +Subproject commit 1f86f058de0c3f4098dedae2ae8653c335c868a1 diff --git a/README.md b/README.md index 4af4626b6..0f494bd55 100644 --- a/README.md +++ b/README.md @@ -4,16 +4,16 @@ [BitNet Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) -Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or [build and run](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) it on your own CPU. +Try it out via this [demo](https://bitnet-demo.azurewebsites.net/), or build and run it on your own [CPU](https://github.com/microsoft/BitNet?tab=readme-ov-file#build-from-source) or [GPU](https://github.com/microsoft/BitNet/blob/main/gpu/README.md). -bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels, that support **fast** and **lossless** inference of 1.58-bit models on CPU (with NPU and GPU support coming next). +bitnet.cpp is the official inference framework for 1-bit LLMs (e.g., BitNet b1.58). It offers a suite of optimized kernels that support **fast** and **lossless** inference of 1.58-bit models on CPU and GPU (NPU support is coming next). The first release of bitnet.cpp is to support inference on CPUs. bitnet.cpp achieves speedups of **1.37x** to **5.07x** on ARM CPUs, with larger models experiencing greater performance gains. Additionally, it reduces energy consumption by **55.4%** to **70.0%**, further boosting overall efficiency. On x86 CPUs, speedups range from **2.37x** to **6.17x** with energy reductions between **71.9%** to **82.2%**. Furthermore, bitnet.cpp can run a 100B BitNet b1.58 model on a single CPU, achieving speeds comparable to human reading (5-7 tokens per second), significantly enhancing the potential for running LLMs on local devices. Please refer to the [technical report](https://arxiv.org/abs/2410.16144) for more details. -m2_performance -m2_performance +**Latest optimization** introduces parallel kernel implementations with configurable tiling and embedding quantization support, achieving **1.15x to 2.1x** additional speedup over the original implementation across different hardware platforms and workloads. For detailed technical information, see the [optimization guide](src/README.md). + +performance_comparison ->The tested models are dummy setups used in a research context to demonstrate the inference performance of bitnet.cpp. 
## Demo @@ -22,7 +22,9 @@ A demo of bitnet.cpp running a BitNet b1.58 3B model on Apple M2: https://github.com/user-attachments/assets/7f46b736-edec-4828-b809-4be780a3e5b1 ## What's New: -- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) ![NEW](https://img.shields.io/badge/NEW-red) +- 01/15/2026 [BitNet CPU Inference Optimization](https://github.com/microsoft/BitNet/blob/main/src/README.md) ![NEW](https://img.shields.io/badge/NEW-red) +- 05/20/2025 [BitNet Official GPU inference kernel](https://github.com/microsoft/BitNet/blob/main/gpu/README.md) +- 04/14/2025 [BitNet Official 2B Parameter Model on Hugging Face](https://huggingface.co/microsoft/BitNet-b1.58-2B-4T) - 02/18/2025 [Bitnet.cpp: Efficient Edge Inference for Ternary LLMs](https://arxiv.org/abs/2502.11880) - 11/08/2024 [BitNet a4.8: 4-bit Activations for 1-bit LLMs](https://arxiv.org/abs/2411.04965) - 10/21/2024 [1-bit AI Infra: Part 1.1, Fast and Lossless BitNet b1.58 Inference on CPUs](https://arxiv.org/abs/2410.16144) @@ -136,6 +138,20 @@ This project is based on the [llama.cpp](https://github.com/ggerganov/llama.cpp) ✅ ❌ + + Falcon-E Family + 1B-3B + x86 + ✅ + ❌ + ✅ + + + ARM + ✅ + ✅ + ❌ + @@ -277,6 +293,17 @@ python utils/generate-dummy-bitnet-model.py models/bitnet_b1_58-large --outfile # Run benchmark with the generated model, use -m to specify the model path, -p to specify the prompt processed, -n to specify the number of token to generate python utils/e2e_benchmark.py -m models/dummy-bitnet-125m.tl1.gguf -p 512 -n 128 ``` + +### Convert from `.safetensors` Checkpoints + +```sh +# Prepare the .safetensors model file +huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./models/bitnet-b1.58-2B-4T-bf16 + +# Convert to gguf model +python ./utils/convert-helper-bitnet.py ./models/bitnet-b1.58-2B-4T-bf16 +``` + ### FAQ (Frequently Asked Questions)📌 #### Q1: The build dies with errors building llama.cpp due to issues with std::chrono in log.cpp? 
diff --git a/assets/intel_performance.jpg b/assets/intel_performance.jpg deleted file mode 100644 index 38a1bcf5c..000000000 Binary files a/assets/intel_performance.jpg and /dev/null differ diff --git a/assets/m2_performance.jpg b/assets/m2_performance.jpg deleted file mode 100644 index 9b5934868..000000000 Binary files a/assets/m2_performance.jpg and /dev/null differ diff --git a/assets/performance.png b/assets/performance.png new file mode 100644 index 000000000..078fd3f73 Binary files /dev/null and b/assets/performance.png differ diff --git a/gpu/README.md b/gpu/README.md index 3fcc59524..da4b25925 100755 --- a/gpu/README.md +++ b/gpu/README.md @@ -73,7 +73,9 @@ It significantly improves GEMV throughput when processing quantized weights and ## Performance -Kernel performance (tested on NVIDIA A100 40GB GPU): +### Kernel Benchmarks + +Tested on NVIDIA A100 40GB GPU, our custom W2A8 kernel shows significant speedups over standard BF16 implementations: | Shape (N×K) | W2A8 Latency (us) | BF16 Latency (us) | Speedup Ratio | |---------------------|-------------------|-------------------|----------------------| @@ -86,8 +88,20 @@ Kernel performance (tested on NVIDIA A100 40GB GPU): | 3200 × 10240 | 19.64 | 60.79 | 3.10 | | 20480 × 3200 | 30.99 | 112.39 | 3.63 | -Generation throughput: +### End-to-End Generation Latency + +Compared to a similarly-sized BF16 model (Gemma-2-2B using vLLM), BitNet-b1.58-2B with our kernel achieves consistent speedups across workloads: + +| Input Length | Output Length | BF16 Latency (ms) | W2A8 Latency (ms) | Speedup Ratio | +| --- | --- | --- | --- | --- | +| 64 | 16 | 187.64 | 57.40 | 3.27 | +| 64 | 32 | 353.50 | 112.22 | 3.15 | +| 64 | 64 | 683.23 | 221.08 | 3.09 | +| 256 | 16 | 183.14 | 61.24 | 2.99 | +| 256 | 32 | 353.14 | 115.47 | 3.06 | +| 256 | 64 | 684.24 | 224.16 | 3.05 | +| 512 | 16 | 208.99 | 68.06 | 3.07 | +| 512 | 32 | 354.33 | 122.72 | 2.89 | +| 512 | 64 | 709.65 | 231.82 | 3.06 | -| BF16 (tokens/s) | W2A8 (tokens/s) | Speedup Ratio | -|---|---|---| -| 10.9 | 213.3 | 19.6 | \ No newline at end of file +*Note: Comparison uses equivalent-sized models (2B parameters) on NVIDIA A100 40GB GPU.* \ No newline at end of file diff --git a/include/gemm-config.h b/include/gemm-config.h new file mode 100644 index 000000000..6a88c4248 --- /dev/null +++ b/include/gemm-config.h @@ -0,0 +1,35 @@ +#define ACT_PARALLEL +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +#if defined(ACT_PARALLEL) + #define ROW_BLOCK_SIZE 4 + #define COL_BLOCK_SIZE 128 + #define PARALLEL_SIZE 4 +#else + #define ROW_BLOCK_SIZE 128 + #define COL_BLOCK_SIZE 32 + #define PARALLEL_SIZE 8 +#endif // ACT_PARALLEL +#elif defined(__ARM_NEON) +#if defined(__ARM_FEATURE_DOTPROD) +#if defined(ACT_PARALLEL) + #define ROW_BLOCK_SIZE 8 + #define COL_BLOCK_SIZE 256 + #define PARALLEL_SIZE 8 +#else + #define ROW_BLOCK_SIZE 64 + #define COL_BLOCK_SIZE 16 + #define PARALLEL_SIZE 2 +#endif // ACT_PARALLEL +#else +#if defined(ACT_PARALLEL) + #define ROW_BLOCK_SIZE 8 + #define COL_BLOCK_SIZE 256 + #define PARALLEL_SIZE 4 +#else + #define ROW_BLOCK_SIZE 128 + #define COL_BLOCK_SIZE 32 + #define PARALLEL_SIZE 4 +#endif // ACT_PARALLEL +#endif // __ARM_FEATURE_DOTPROD +#endif // __AVX__ + diff --git a/setup_env.py b/setup_env.py index dfad6c3e7..3bf5fb8f7 100644 --- a/setup_env.py +++ b/setup_env.py @@ -44,6 +44,18 @@ "microsoft/BitNet-b1.58-2B-4T": { "model_name": "BitNet-b1.58-2B-4T", }, + "tiiuae/Falcon-E-3B-Instruct": { + "model_name": "Falcon-E-3B-Instruct", + }, + 
"tiiuae/Falcon-E-1B-Instruct": { + "model_name": "Falcon-E-1B-Instruct", + }, + "tiiuae/Falcon-E-3B-Base": { + "model_name": "Falcon-E-3B-Base", + }, + "tiiuae/Falcon-E-1B-Base": { + "model_name": "Falcon-E-1B-Base", + }, } SUPPORTED_QUANT_TYPES = { @@ -52,8 +64,8 @@ } COMPILER_EXTRA_ARGS = { - "arm64": ["-DBITNET_ARM_TL1=ON"], - "x86_64": ["-DBITNET_X86_TL2=ON"] + "arm64": ["-DBITNET_ARM_TL1=OFF"], + "x86_64": ["-DBITNET_X86_TL2=OFF"] } OS_EXTRA_ARGS = { @@ -144,7 +156,7 @@ def setup_gguf(): def gen_code(): _, arch = system_info() - llama3_f3_models = set([model['model_name'] for model in SUPPORTED_HF_MODELS.values() if model['model_name'].startswith("Falcon3") or model['model_name'].startswith("Llama")]) + llama3_f3_models = set([model['model_name'] for model in SUPPORTED_HF_MODELS.values() if model['model_name'].startswith("Falcon") or model['model_name'].startswith("Llama")]) if arch == "arm64": if args.use_pretuned: diff --git a/src/README.md b/src/README.md new file mode 100644 index 000000000..f713b9ab2 --- /dev/null +++ b/src/README.md @@ -0,0 +1,205 @@ +# BitNet CPU Inference Optimization + +This update provides significant performance improvements for BitNet inference on CPU through paralleled kernel implementations, native I2_S GEMM/GEMV support, configurable tiling block size and embedding quantization. + +## Update + +- **Parallel Weight & Activation Computation** + Implemented parallel processing of weights and activations in the W2A8 vet_dot kernel, achieving improved throughput on both x86 and ARM architectures. + +- **Native I2_S GEMM & GEMV Support** + Integrated I2_S GEMM and GEMV operations into ggml library, making them fully compatible with the llama.cpp architecture. This enables seamless integration with existing inference pipelines. + +- **Configurable Tiling & Parallelism** + Introduced configurable GEMM & GEMV block sizes and parallelism levels, allowing performance fine-tuning for different CPU architectures. + +- **Embedding Quantization** + Added support for embedding layer quantization with Q6_K format, reducing memory footprint and improving inference speed while maintaining high accuracy. + +## Usage + +### Configuration Options + +The `include/gemm-config.h` file controls kernel behavior: + +```c +#define ROW_BLOCK_SIZE 4 +#define COL_BLOCK_SIZE 128 +#define PARALLEL_SIZE 4 +``` + +Modify these values based on your CPU cache size and architecture for optimal performance. Users can fine-tune performance on their machine through `include/gemm-config.h`. + +### Enabling Embedding Quantization + +To use embedding quantization for additional speedup: + +**Using setup_env.py:** +```bash +python setup_env.py --quant-embd +``` +This automatically converts embeddings to Q6_K format. + +**Manual conversion:** +```bash +build/bin/llama-quantize --token-embedding-type Q6_K models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf I2_S 1 1 +``` + +## Optimizations + +### 1. Weight & Activation Parallelism + +The kernel implements two parallelization strategies: + +- **Weight Parallel:** Processes multiple weight rows/columns in a single kernel call, reducing kernel launch overhead. + +- **Activation Parallel:** Built on top of weight parallel, amortizes the I2_S weight unpacking cost across multiple activation elements. + +**Recommendation:** For I2_S quantization format, activation parallel is recommended due to the unpack operation benefits. The current kernel defaults to activation parallel. 
+ +**Kernel Performance Comparison:** + +
+ +Test configuration: AMD EPYC 7V13 (x86), 1 thread, time in milliseconds (mean±std) + +| Matrix Size | No Parallel | Weight Parallel | Activation Parallel | +|:---:|:---:|:---:|:---:| +| [1, 2048] × [2048, 2048] | 0.075±0.012 | **0.058±0.007** | 0.076±0.011 | +| [32, 2048] × [2048, 2048] | 2.400±0.041 | 1.599±0.020 | **1.202±0.018** | +| [128, 2048] × [2048, 2048] | 10.820±0.039 | 6.458±0.168 | **5.805±0.039** | +| [256, 2048] × [2048, 2048] | 21.669±0.080 | 12.739±0.183 | **11.882±0.040** | +| [512, 2048] × [2048, 2048] | 43.257±0.083 | 25.680±0.335 | **23.342±0.082** | +| [2048, 2048] × [2048, 2048] | 173.175±0.214 | 103.112±0.552 | **93.276±0.612** | +| [128, 2048] × [2048, 8192] | 43.345±0.090 | 25.541±0.239 | **23.528±0.052** | +| [128, 8192] × [8192, 2048] | 38.085±0.162 | 23.866±0.096 | **22.569±0.132** | +
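+The three columns above correspond to the `_1x1`, `_1xN` (weight parallel), and `_Nx1` (activation parallel) variants of the vec_dot kernel. The dispatcher below is condensed from the new `ggml_vec_dot_i2_i8_s` in `src/ggml-bitnet-mad.cpp` and shows how a variant is selected at run time:
+
+```c
+/* Condensed from src/ggml-bitnet-mad.cpp: batched calls whose row/column
+   count is a multiple of PARALLEL_SIZE take a parallel kernel; everything
+   else falls back to the single-row kernel. */
+void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx,
+                          const void * vy, size_t by, int nrc) {
+    if (nrc % PARALLEL_SIZE == 0) {
+#if defined(ACT_PARALLEL)
+        ggml_vec_dot_i2_i8_s_Nx1(n, s, bs, vx, bx, vy, by, nrc);  /* activation parallel */
+#else
+        ggml_vec_dot_i2_i8_s_1xN(n, s, bs, vx, bx, vy, by, nrc);  /* weight parallel */
+#endif
+    } else {
+        ggml_vec_dot_i2_i8_s_1x1(n, s, bs, vx, bx, vy, by, nrc);  /* fallback */
+    }
+}
+```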
+ +### 2. GEMM/GEMV Integration with llama.cpp + +The I2_S quantization format is integrated into llama.cpp's compute graph: + +- **GEMV Operations:** Optimized matrix-vector multiplication for token generation. +- **GEMM Operations:** Efficient matrix-matrix multiplication for prompt processing. +- **Tiling Strategy:** Configurable block sizes for optimal cache utilization (see the sketch after the fine-tuning results below). + +### 3. Configuration Fine-tuning + +Kernel parameters can be fine-tuned for optimal performance on specific hardware: + +**Example Configuration (x86, AMD EPYC 7V13):** +- Method: Activation Parallel +- Threads: 8 +- Workload: 128 prompt tokens (pp128) + +**Fine-tuning Parameters:** +- **Parallelism Degree:** [2, 4, 8] +- **Row Block Size:** [2, 4, 8, 16, 32] +- **Column Block Size:** [32, 64, 128, 256, 512, 1024] + +**Fine-tuning Results:** + +
+ +fine_tune_result + +*Shows throughput (tokens/s) for various configurations.* + +
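+To relate the swept parameters to the code, the sketch below is a generic illustration of blocked GEMM using the macros from `include/gemm-config.h`: the output is computed tile by tile so that one block of weights and activations stays cache-resident. It is a simplified scalar picture only; the 2-bit packing, SIMD, and exact loop order of the real kernels are omitted, and the mapping of the block sizes onto tensor dimensions is schematic.
+
+```c
+#include <stdint.h>
+#include "gemm-config.h"     /* ROW_BLOCK_SIZE, COL_BLOCK_SIZE */
+
+#ifndef ROW_BLOCK_SIZE       /* fallback when no SIMD macros are set */
+#define ROW_BLOCK_SIZE 4
+#define COL_BLOCK_SIZE 128
+#endif
+
+static int imin(int a, int b) { return a < b ? a : b; }
+
+/* Simplified blocked matmul: C[MxN] = W[MxK] * A[NxK]^T, computed in
+   ROW_BLOCK_SIZE x COL_BLOCK_SIZE output tiles to improve cache reuse. */
+void gemm_tiled_ref(int M, int N, int K,
+                    const int8_t *W, const int8_t *A, int32_t *C) {
+    for (int i0 = 0; i0 < M; i0 += ROW_BLOCK_SIZE)
+    for (int j0 = 0; j0 < N; j0 += COL_BLOCK_SIZE)
+        for (int i = i0; i < imin(i0 + ROW_BLOCK_SIZE, M); i++)
+        for (int j = j0; j < imin(j0 + COL_BLOCK_SIZE, N); j++) {
+            int32_t acc = 0;
+            for (int k = 0; k < K; k++)
+                acc += (int32_t)W[(int64_t)i * K + k] * A[(int64_t)j * K + k];
+            C[(int64_t)i * N + j] = acc;
+        }
+}
+```
+
+Which tile shape wins depends on cache sizes, which is exactly what the sweep above explores.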
+ +**Optimal Configuration:** Under this setup (x86, 8 threads, pp128), the best performance is achieved with parallelism degree = 4, row block size = 4, and column block size = 128. + +### 4. Embedding Quantization + +We evaluated multiple embedding quantization formats to balance memory usage, model quality, and inference speed: + +**Perplexity Comparison:** + +
+ +Test configuration: BitNet-b1.58-2B-4T, TG128 + +| Embedding Type | Wikitext | PTB | LAMBADA | IMDB | AG NEWS | +|:---:|:---:|:---:|:---:|:---:|:---:| +| **F32** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 | +| **F16** | 17.1090±0.1278 | 33.0858±0.4886 | 43.2850±0.6363 | 29.3016±0.2890 | 36.7686±0.3920 | +| **Q8_0** | 17.1197±0.1280 | 33.1181±0.4893 | 43.2891±0.6364 | 29.3133±0.2892 | 36.7740±0.3920 | +| **Q6_K** | 17.1487±0.1282 | 33.2203±0.4914 | 43.3046±0.6362 | 29.3491±0.2897 | 36.7972±0.3921 | +| **Q5_0** | 17.2379±0.1288 | 33.2439±0.4907 | 43.4631±0.6379 | 29.5481±0.2920 | 36.8539±0.3924 | +| **Q4_0** | 17.3529±0.1300 | 33.7754±0.5001 | 44.4552±0.6559 | 30.1044±0.2978 | 37.3985±0.3997 | +| **Q3_K** | 17.6434±0.1320 | 34.3914±0.5089 | 45.4591±0.6735 | 30.8476±0.3069 | 39.5692±0.4259 | +| **I2_S** | N/A | N/A | N/A | N/A | N/A | + +**N/A indicates model failure due to extreme quantization.* + +
+ +**Inference Speed Comparison:** + +
+ +embedding_throughput + +*Token generation throughput (tg128) for different embedding quantization types.* + +
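+As a back-of-the-envelope check on the memory side of this trade-off, the snippet below estimates the embedding table size at several bits-per-weight values. The vocabulary and hidden sizes are placeholder assumptions for illustration, not values taken from this repository; the bits-per-weight figures include per-block scale overhead (Q8_0 stores 32 weights in 34 bytes, Q6_K stores 256 weights in 210 bytes).
+
+```c
+#include <stdio.h>
+
+int main(void) {
+    const double vocab  = 128256;   /* assumed vocabulary size */
+    const double hidden = 2560;     /* assumed embedding width */
+    const struct { const char *name; double bpw; } fmts[] = {
+        { "F32",  32.0    },
+        { "F16",  16.0    },
+        { "Q8_0",  8.5    },   /* 34 bytes per 32 weights   */
+        { "Q6_K",  6.5625 },   /* 210 bytes per 256 weights */
+    };
+    for (int i = 0; i < 4; i++) {
+        double mib = vocab * hidden * fmts[i].bpw / 8.0 / (1024.0 * 1024.0);
+        printf("%-4s embedding table ~ %7.1f MiB\n", fmts[i].name, mib);
+    }
+    return 0;
+}
+```
+
+Going from F16 to Q6_K cuts the embedding footprint by more than half, which is where the memory savings weighed in the recommendation below come from.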
+ +**Recommendation:** Based on comprehensive evaluation of memory footprint, perplexity preservation, and inference speed, **Q6_K** is selected as the optimal embedding quantization format. + +## Performance + +Comparison of optimized parallel kernels vs. original implementation: + +**Test Configuration:** +- Model: BitNet-b1.58-2B-4T +- Hardware: AMD EPYC 7V13 +- Threads: 1 / 2 / 4 / 8 / 12 / 16 +- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128) +- Method: Activation Parallel + +
+ +performance_comparison_amd_epyc + +
+ +**Test Configuration:** +- Model: BitNet-b1.58-2B-4T +- Hardware: Intel i7-13800H +- Threads: 1 / 2 / 4 / 6 +- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128) +- Method: Activation Parallel + +
+ +performance_comparison_i7-13800h + +
+ +**Test Configuration:** +- Model: BitNet-b1.58-2B-4T +- Hardware: Cobalt 100 +- Threads: 1 / 2 / 4 / 8 +- Test: 128 prompt tokens (pp128) + 128 generated tokens (tg128) +- Method: Activation Parallel + +
+ +performance_comparison_cobalt100_dotprod + +
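+The thread counts in these runs scale because the rows of each weight matrix are split across threads; ggml's scheduler does the actual chunking, but the effect is roughly the simplified picture below (an illustrative OpenMP sketch, not the real scheduling code). Without `-fopenmp` the pragma is ignored and the loop simply runs single-threaded.
+
+```c
+#include <stddef.h>
+
+/* Each thread computes a contiguous slice of output rows; throughput
+   grows with threads until memory bandwidth becomes the bottleneck. */
+void mat_vec_threaded(int rows, int cols, const float *W, const float *x, float *y) {
+    #pragma omp parallel for schedule(static)
+    for (int r = 0; r < rows; r++) {
+        float acc = 0.0f;
+        for (int c = 0; c < cols; c++)
+            acc += W[(size_t)r * cols + c] * x[c];
+        y[r] = acc;
+    }
+}
+```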
+ +## Technical Details + +### Key Files Modified + +- `src/ggml-bitnet-mad.cpp`: Parallel kernel implementations +- `3rdparty/llama.cpp/ggml/src/ggml.c`: GEMM/GEMV integration +- `include/gemm-config.h`: Configuration file + +### Supported Architectures + +- ✅ x86-64 with AVX2 +- ✅ ARM with NEON +- ✅ ARM with DOTPROD extension diff --git a/src/assets/embedding_throughput.png b/src/assets/embedding_throughput.png new file mode 100644 index 000000000..b3ebb82f8 Binary files /dev/null and b/src/assets/embedding_throughput.png differ diff --git a/src/assets/fine_tuning_result.png b/src/assets/fine_tuning_result.png new file mode 100644 index 000000000..5bcab69c8 Binary files /dev/null and b/src/assets/fine_tuning_result.png differ diff --git a/src/assets/performance_comparison_amd_epyc.png b/src/assets/performance_comparison_amd_epyc.png new file mode 100644 index 000000000..6ebdb3dbf Binary files /dev/null and b/src/assets/performance_comparison_amd_epyc.png differ diff --git a/src/assets/performance_comparison_cobalt100_dotprod.png b/src/assets/performance_comparison_cobalt100_dotprod.png new file mode 100644 index 000000000..4d0ef8c7f Binary files /dev/null and b/src/assets/performance_comparison_cobalt100_dotprod.png differ diff --git a/src/assets/performance_comparison_i7-13800h.png b/src/assets/performance_comparison_i7-13800h.png new file mode 100644 index 000000000..e486d669f Binary files /dev/null and b/src/assets/performance_comparison_i7-13800h.png differ diff --git a/src/ggml-bitnet-mad.cpp b/src/ggml-bitnet-mad.cpp index eeca82b1a..4ba9d6509 100644 --- a/src/ggml-bitnet-mad.cpp +++ b/src/ggml-bitnet-mad.cpp @@ -1,13 +1,18 @@ #include #include - +#include #include "ggml-bitnet.h" #include "ggml-quants.h" +#include "gemm-config.h" +#include "ggml-cpu-impl.h" #include #include +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #define QK_I2_S 128 -#define QK_I2 128 +#elif defined(__ARM_NEON) +#define QK_I2_S 64 +#endif #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) #include @@ -44,8 +49,8 @@ static inline int hsum_i32_8(const __m256i a) { #endif size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - // 2 bits per weight - +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) +#if defined(ACT_PARALLEL) size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); int n = nrow * n_per_row; @@ -73,11 +78,11 @@ size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_ // -1, 0, 1 uint8_t* i2_weight = (uint8_t*)dst; - for (int i = 0; i < n / QK_I2; i++) { - for (int j = 0; j < QK_I2; j++) { + for (int i = 0; i < n / QK_I2_S; i++) { + for (int j = 0; j < QK_I2_S; j++) { int group_idx = j / 32; int group_pos = j % 32; - uint8_t temp = (q8[i * QK_I2 + j] << (6 - 2 * group_idx)); + uint8_t temp = (q8[i * QK_I2_S + j] << (6 - 2 * group_idx)); i2_weight[i * 32 + group_pos] |= temp; } } @@ -89,9 +94,109 @@ size_t quantize_i2_s(const float * src, void * dst, int64_t nrow, int64_t n_per_ // 32B for alignment return nrow * row_size / 4 + 32; +#else + assert((nrow % 4) == 0 && "quantize_i2_s_1x4 requires nrow % 4 == 0"); + + size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); + int64_t n = nrow * n_per_row; + + double max = 0; + for (int64_t i = 0; i < n; ++i) { + max = fmax(max, (double)fabs((double)src[i])); + } + double i2_scale = max; + + uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t)); + 
for (int64_t i=0; i 0 ? 2 : 0; + } + + uint8_t* out = (uint8_t*)dst; + memset(out, 0, (size_t)(n / 4)); + + // for each group of 4 rows, for each column, write one byte + int64_t nrow4 = nrow / 4; + for (int64_t rg = 0; rg < nrow4; rg++) { + int64_t r0 = rg * 4 + 0; + int64_t r1 = rg * 4 + 1; + int64_t r2 = rg * 4 + 2; + int64_t r3 = rg * 4 + 3; + + int64_t base = rg * n_per_row; + + for (int64_t col = 0; col < n_per_row; col++) { + uint8_t q0 = q8[r0 * n_per_row + col]; + uint8_t q1 = q8[r1 * n_per_row + col]; + uint8_t q2 = q8[r2 * n_per_row + col]; + uint8_t q3 = q8[r3 * n_per_row + col]; + + uint8_t packed = (uint8_t)((q0 << 6) | (q1 << 4) | (q2 << 2) | (q3 << 0)); + out[base + col] = packed; + } + } + + // store scale at the end of quantized data (same location pattern as quantize_i2_s) + float* scale_ptr = (float*)((char*)out + n / 4); + scale_ptr[0] = (float)i2_scale; + + free(q8); + + // return size (keep same formula as quantize_i2_s) + return nrow * row_size / 4 + 32; +#endif +#elif defined(__ARM_NEON) + size_t row_size = ggml_row_size(GGML_TYPE_I2_S, n_per_row); + + int n = nrow * n_per_row; + + // f32 -> q8 + double max = 0; + for (int i = 0; i < n; ++i) { + max = fmax(max, (double)fabs((double)src[i])); + } + double i2_scale = max; + + uint8_t* q8 = (uint8_t*)malloc(n * sizeof(uint8_t)); + for (int i=0; i 0 ? 2 : 0; + } + + memset(dst, 0, n * sizeof(uint8_t) / 4); + + // q8 -> 0, 1, 2 + // | | | + // -1, 0, 1 + + uint8_t* i2_weight = (uint8_t*)dst; + for (int i = 0; i < n / QK_I2_S; i++) { + for (int j = 0; j < QK_I2_S; j++) { + int group_idx = j / 16; + int group_pos = j % 16; + uint8_t temp = (q8[i * QK_I2_S + j] << (6 - 2 * group_idx)); + i2_weight[i * 16 + group_pos] |= temp; + } + } + + float* scale_ptr = (float*)((char*)i2_weight + n / 4); + scale_ptr[0] = i2_scale; + + free(q8); + + // 32B for alignment + return nrow * row_size / 4 + 32; +#endif } -void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +void ggml_vec_dot_i2_i8_s_1x1(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if defined(__AVX2__) const uint8_t * x = (uint8_t *)vx; const int8_t * y = (int8_t *)vy; @@ -99,265 +204,853 @@ void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t b const int group32_num = nb / 32; const int la_num = nb % 32; const int groupla_num = nb % 32 != 0 ? 
1 : 0; + + __m256i mask = _mm256_set1_epi8(0x03); + __m256i one16 = _mm256_set1_epi16(1); -#if defined(__AVX2__) + // 处理多行,nrc表示要处理的行数 + for (int row = 0; row < nrc; row++) { + __m256i accu = _mm256_setzero_si256(); + + // 计算当前行的x指针偏移 + const uint8_t * x_row = x + row * bx / 4; + + for (int i = 0; i < group32_num; i++) { + const uint8_t *px = x_row + i * 1024; // 32 * 32 + const int8_t *py = y + i * 4096; // 32 * 128 + __m256i accu32 = _mm256_setzero_si256(); + + for (int j = 0; j < 32; j++) { + // 128 index + __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px)); + __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); + __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); + __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); - __m256i mask = _mm256_set1_epi8(0x03); - __m256i accu = _mm256_setzero_si256(); - - for (int i=0; i < group32_num; i++){ - __m256i accu32 = _mm256_setzero_si256(); - for (int j=0; j < 32; j++) { - // 128 index - __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + i * 32 * 32 + j * 32)); - __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); - __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); - __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); - - // each 32 index - xq8_3 = _mm256_and_si256(xq8_3, mask); - xq8_2 = _mm256_and_si256(xq8_2, mask); - xq8_1 = _mm256_and_si256(xq8_1, mask); - xq8_0 = _mm256_and_si256(xq8_0, mask); - - // each 32 index - __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 0)); - __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 32)); - __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 64)); - __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + i * 128 * 32 + j * 128 + 96)); - - // 128 index accumulation add - // split into 32 accumulation block - // each block each 128 index accumulated 4index - // each index maximum 256 - // each block maximum 4 * 256 - // each block accumulation maximum 127 * 256 - // each 32 group index (128 index in one group) needs cast to int32 - xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); - xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); - xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); - xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); - - accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1)); - accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3)); + // each 32 index + xq8_3 = _mm256_and_si256(xq8_3, mask); + xq8_2 = _mm256_and_si256(xq8_2, mask); + xq8_1 = _mm256_and_si256(xq8_1, mask); + xq8_0 = _mm256_and_si256(xq8_0, mask); + + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py)); + __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32)); + __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64)); + __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96)); + + xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); + xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); + xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); + xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); + + accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_0, xq8_1)); + accu32 = _mm256_add_epi16(accu32, _mm256_add_epi16(xq8_2, xq8_3)); + + px += 32; + py += 128; + } + accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, one16), accu); } - accu = _mm256_add_epi32(_mm256_madd_epi16(accu32, _mm256_set1_epi16(1)), accu); + + for (int i = 0; i < groupla_num; i++) { + __m256i accula = _mm256_setzero_si256(); + const uint8_t *px = x_row + group32_num * 1024; // 32 * 32 + const int8_t *py = y + group32_num * 4096; // 32 * 128 + + for (int j = 0; j < la_num; j++) { + // 128 
index + __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px)); + __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); + __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); + __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); + + // each 32 index + xq8_3 = _mm256_and_si256(xq8_3, mask); + xq8_2 = _mm256_and_si256(xq8_2, mask); + xq8_1 = _mm256_and_si256(xq8_1, mask); + xq8_0 = _mm256_and_si256(xq8_0, mask); + + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py)); + __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32)); + __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64)); + __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96)); + + xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); + xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); + xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); + xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); + + accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1)); + accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3)); + + px += 32; + py += 128; + } + accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, one16)); + } + + int sumi = hsum_i32_8(accu); + s[row] = (float)sumi; } +#elif defined(__ARM_NEON) + const uint8_t * x = (uint8_t *)vx; + const int8_t * y = (int8_t *)vy; + + const int nb = n / QK_I2_S; + const int group32_num = nb / 32; + const int la_num = nb % 32; + const int groupla_num = nb % 32 != 0 ? 1 : 0; + + const uint8x16_t mask = vdupq_n_u8(3); + + // 处理多列,nrc表示要处理的列数 + for (int row = 0; row < nrc; row++) { + int32x4_t accu = vdupq_n_s32(0); + + // 计算当前行的x指针偏移 + const uint8_t * x_row = x + row * bx / 4; + + for (int i=0; i < group32_num; i++) { + +#if defined(__ARM_FEATURE_DOTPROD) + +#else + int16x8_t accu32 = vdupq_n_s16(0); +#endif + for (int j=0; j < 32; j++) { + uint8x16_t xq8_3 = vld1q_u8(x_row + i * 32 * 16 + j * 16); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6); + + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + + const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0); + const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16); + const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32); + const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48); + +#if defined(__ARM_FEATURE_DOTPROD) + accu = vdotq_s32(accu, q8_0, yq8_0); + accu = vdotq_s32(accu, q8_1, yq8_1); + accu = vdotq_s32(accu, q8_2, yq8_2); + accu = vdotq_s32(accu, q8_3, yq8_3); +#else + accu32 = vmlal_s8(accu32, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accu32 = vmlal_s8(accu32, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accu32 = vmlal_s8(accu32, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accu32 = vmlal_s8(accu32, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accu32 = vmlal_s8(accu32, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accu32 = vmlal_s8(accu32, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accu32 = vmlal_s8(accu32, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accu32 = vmlal_s8(accu32, vget_high_s8(q8_3), vget_high_s8(yq8_3)); +#endif + } + +#if defined(__ARM_FEATURE_DOTPROD) - for (int i = 0; i < groupla_num; i++){ - __m256i accula = _mm256_setzero_si256(); - for (int j = 0; j < la_num; j++) { - // 128 index - __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(x + group32_num * 32 * 32 + j * 32)); - __m256i xq8_2 = 
_mm256_srli_epi16(xq8_3, 2); - __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); - __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); - - // each 32 index - xq8_3 = _mm256_and_si256(xq8_3, mask); - xq8_2 = _mm256_and_si256(xq8_2, mask); - xq8_1 = _mm256_and_si256(xq8_1, mask); - xq8_0 = _mm256_and_si256(xq8_0, mask); - - // each 32 index - __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 0)); - __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 32)); - __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 64)); - __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(y + group32_num * 128 * 32 + j * 128 + 96)); - - // 128 index accumulation add - // split into 32 accumulation block - // each block each 128 index accumulated 4index - // each index maximum 256 - // each block maximum 4 * 256 - // each block accumulation maximum 127 * 256 - // each 32 group index (128 index in one group) needs cast to int32 - xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); - xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); - xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); - xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); - - accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_0, xq8_1)); - accula = _mm256_add_epi16(accula, _mm256_add_epi16(xq8_2, xq8_3)); +#else + accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accu32))); + accu = vaddq_s32(accu, vmovl_high_s16(accu32)); +#endif } - accu = _mm256_add_epi32(accu, _mm256_madd_epi16(accula, _mm256_set1_epi16(1))); + + for (int i = 0; i < groupla_num; i++){ +#if defined(__ARM_FEATURE_DOTPROD) + +#else + int16x8_t accula = vdupq_n_s16(0); +#endif + for (int j = 0; j < la_num; j++) { + uint8x16_t xq8_3 = vld1q_u8(x_row + group32_num * 32 * 16 + j * 16); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6); + + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + + const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 0); + const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16); + const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32); + const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48); + +#if defined(__ARM_FEATURE_DOTPROD) + accu = vdotq_s32(accu, q8_0, yq8_0); + accu = vdotq_s32(accu, q8_1, yq8_1); + accu = vdotq_s32(accu, q8_2, yq8_2); + accu = vdotq_s32(accu, q8_3, yq8_3); +#else + accula = vmlal_s8(accula, vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accula = vmlal_s8(accula, vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accula = vmlal_s8(accula, vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accula = vmlal_s8(accula, vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accula = vmlal_s8(accula, vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accula = vmlal_s8(accula, vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accula = vmlal_s8(accula, vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accula = vmlal_s8(accula, vget_high_s8(q8_3), vget_high_s8(yq8_3)); +#endif + } +#if defined(__ARM_FEATURE_DOTPROD) + +#else + accu = vaddq_s32(accu, vmovl_s16(vget_low_s16(accula))); + accu = vaddq_s32(accu, vmovl_high_s16(accula)); +#endif + } + int sumi = vaddlvq_s32(accu); + s[row] = (float)sumi; } - int sumi = hsum_i32_8(accu); - *s = (float)sumi; +#endif +} + +void 
ggml_vec_dot_i2_i8_s_1x4_32W(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if defined(__AVX2__) + const uint8_t * x = (uint8_t *)vx; + const int8_t * y = (int8_t *)vy; + const int nb = n / QK_I2_S; + const int group32_num = nb / 32; + const int la_num = nb % 32; + const int groupla_num = nb % 32 != 0 ? 1 : 0; + + const __m256i mask = _mm256_set1_epi8(0x03); + const __m256i one16 = _mm256_set1_epi16(1); + + // 处理多行,nrc表示要处理的行数 + for (int row = 0; row < nrc; row+=4) { + __m256i accu[4]; + for(int rb = 0; rb < 4; rb++) { + accu[rb] = _mm256_setzero_si256(); + } + const uint8_t * x_row = x + (row) * bx / 4; + // 计算当前行的x指针偏移 + + for (int i = 0; i < group32_num; i++) { + const uint8_t * px = x_row + i * 1024 * 4; + __m256i accu32[4]; + for(int rb = 0; rb < 4; rb++) { + accu32[rb] = _mm256_setzero_si256(); + } + const int8_t *py = y + i * 4096; + + for (int j = 0; j < 32 * 4; j++) { + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py)); + __m256i xq8[4]; + xq8[3] = _mm256_loadu_si256((const __m256i*)(px)); + xq8[2] = _mm256_srli_epi16(xq8[3], 2); + xq8[1] = _mm256_srli_epi16(xq8[3], 4); + xq8[0] = _mm256_srli_epi16(xq8[3], 6); + xq8[3] = _mm256_and_si256(xq8[3], mask); + xq8[2] = _mm256_and_si256(xq8[2], mask); + xq8[1] = _mm256_and_si256(xq8[1], mask); + xq8[0] = _mm256_and_si256(xq8[0], mask); + for (int rb = 0; rb < 4; rb++) + { + xq8[rb] = _mm256_maddubs_epi16(xq8[rb], yq8_0); + accu32[rb] = _mm256_add_epi16(accu32[rb], xq8[rb]); + } + px += 32; + py += 32; + } + for(int rb = 0; rb < 4; rb++) { + accu[rb] = _mm256_add_epi32(_mm256_madd_epi16(accu32[rb], one16), accu[rb]); + } + } + + for (int i = 0; i < groupla_num; i++) { + const int8_t *py = y + group32_num * 4096; // 32 * 128 + __m256i accula[4]; + for(int rb = 0; rb < 4; rb++) { + accula[rb] = _mm256_setzero_si256(); + } + const uint8_t * px = x_row + group32_num * 1024 * 4; + + for (int j = 0; j < la_num * 4; j++) { + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py)); + __m256i xq8[4]; + xq8[3] = _mm256_loadu_si256((const __m256i*)(px)); + xq8[2] = _mm256_srli_epi16(xq8[3], 2); + xq8[1] = _mm256_srli_epi16(xq8[3], 4); + xq8[0] = _mm256_srli_epi16(xq8[3], 6); + xq8[3] = _mm256_and_si256(xq8[3], mask); + xq8[2] = _mm256_and_si256(xq8[2], mask); + xq8[1] = _mm256_and_si256(xq8[1], mask); + xq8[0] = _mm256_and_si256(xq8[0], mask); + + for (int rb = 0; rb < 4; rb++) { + xq8[rb] = _mm256_maddubs_epi16(xq8[rb], yq8_0); + accula[rb] = _mm256_add_epi16(accula[rb], xq8[rb]); + } + px += 32; + py += 32; + } + for(int rb = 0; rb < 4; rb++) { + accu[rb] = _mm256_add_epi32(accu[rb], _mm256_madd_epi16(accula[rb], one16)); + } + } + + for(int rb = 0; rb < 4; rb++) { + int sumi = hsum_i32_8(accu[rb]); + s[row + rb] = (float)sumi; + } + } #elif defined(__ARM_NEON) - int32x4_t accu_0 = vdupq_n_s32(0); - int32x4_t accu_1 = vdupq_n_s32(0); - int32x4_t accu_2 = vdupq_n_s32(0); - int32x4_t accu_3 = vdupq_n_s32(0); +#endif +} + +void ggml_vec_dot_i2_i8_s_1xN(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if defined(__AVX2__) + const uint8_t * x = (uint8_t *)vx; + const int8_t * y = (int8_t *)vy; + + const int nb = n / QK_I2_S; + const int group32_num = nb / 32; + const int la_num = nb % 32; + const int groupla_num = nb % 32 != 0 ? 
1 : 0; + + const __m256i mask = _mm256_set1_epi8(0x03); + const __m256i one16 = _mm256_set1_epi16(1); + + // 处理多行,nrc表示要处理的行数 + for (int row = 0; row < nrc; row+=PARALLEL_SIZE) { + //__m256i accu = _mm256_setzero_si256(); + __m256i accu[PARALLEL_SIZE]; + const uint8_t * x_row[PARALLEL_SIZE]; + for(int rb = 0; rb < PARALLEL_SIZE; rb++) { + accu[rb] = _mm256_setzero_si256(); + x_row[rb] = x + (row + rb) * bx / 4; + } + // 计算当前行的x指针偏移 + + for (int i = 0; i < group32_num; i++) { + const uint8_t * px[PARALLEL_SIZE]; + __m256i accu32[PARALLEL_SIZE]; + for(int rb = 0; rb < PARALLEL_SIZE; rb++) { + px[rb] = x_row[rb] + i * 1024; // 32 * 32 + accu32[rb] = _mm256_setzero_si256(); + } + const int8_t *py = y + i * 4096; // 32 * 128 + + for (int j = 0; j < 32; j++) { + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py)); + __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32)); + __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64)); + __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96)); + for (int rb = 0; rb < PARALLEL_SIZE; rb++) + { + __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px[rb])); + __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); + __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); + __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); + + // each 32 index + xq8_3 = _mm256_and_si256(xq8_3, mask); + xq8_2 = _mm256_and_si256(xq8_2, mask); + xq8_1 = _mm256_and_si256(xq8_1, mask); + xq8_0 = _mm256_and_si256(xq8_0, mask); + + xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); + xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); + xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); + xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); + + accu32[rb] = _mm256_add_epi16(accu32[rb], _mm256_add_epi16(xq8_0, xq8_1)); + accu32[rb] = _mm256_add_epi16(accu32[rb], _mm256_add_epi16(xq8_2, xq8_3)); + + px[rb] += 32; + } + py += 128; + } + for(int rb = 0; rb < PARALLEL_SIZE; rb++) { + accu[rb] = _mm256_add_epi32(_mm256_madd_epi16(accu32[rb], one16), accu[rb]); + } + } + + for (int i = 0; i < groupla_num; i++) { + const int8_t *py = y + group32_num * 4096; // 32 * 128 + const uint8_t * px[PARALLEL_SIZE]; + __m256i accula[PARALLEL_SIZE]; + for(int rb = 0; rb < PARALLEL_SIZE; rb++) { + px[rb] = x_row[rb] + group32_num * 1024; // 32 * 32 + accula[rb] = _mm256_setzero_si256(); + } + + for (int j = 0; j < la_num; j++) { + // each 32 index + __m256i yq8_0 = _mm256_loadu_si256((const __m256i*)(py)); + __m256i yq8_1 = _mm256_loadu_si256((const __m256i*)(py + 32)); + __m256i yq8_2 = _mm256_loadu_si256((const __m256i*)(py + 64)); + __m256i yq8_3 = _mm256_loadu_si256((const __m256i*)(py + 96)); + + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + // 128 index + __m256i xq8_3 = _mm256_loadu_si256((const __m256i*)(px[rb])); + __m256i xq8_2 = _mm256_srli_epi16(xq8_3, 2); + __m256i xq8_1 = _mm256_srli_epi16(xq8_3, 4); + __m256i xq8_0 = _mm256_srli_epi16(xq8_3, 6); + + // each 32 index + xq8_3 = _mm256_and_si256(xq8_3, mask); + xq8_2 = _mm256_and_si256(xq8_2, mask); + xq8_1 = _mm256_and_si256(xq8_1, mask); + xq8_0 = _mm256_and_si256(xq8_0, mask); + + + + xq8_0 = _mm256_maddubs_epi16(xq8_0, yq8_0); + xq8_1 = _mm256_maddubs_epi16(xq8_1, yq8_1); + xq8_2 = _mm256_maddubs_epi16(xq8_2, yq8_2); + xq8_3 = _mm256_maddubs_epi16(xq8_3, yq8_3); + + accula[rb] = _mm256_add_epi16(accula[rb], _mm256_add_epi16(xq8_0, xq8_1)); + accula[rb] = _mm256_add_epi16(accula[rb], _mm256_add_epi16(xq8_2, xq8_3)); + + px[rb] += 32; + } + py += 128; + } + for(int rb = 0; rb < PARALLEL_SIZE; rb++) { + accu[rb] = _mm256_add_epi32(accu[rb], 
_mm256_madd_epi16(accula[rb], one16)); + } + } + + for(int rb = 0; rb < PARALLEL_SIZE; rb++) { + int sumi = hsum_i32_8(accu[rb]); + s[row + rb] = (float)sumi; + } + } +#elif defined(__ARM_NEON) + const uint8_t * x = (uint8_t *)vx; + const int8_t * y = (int8_t *)vy; + + const int nb = n / QK_I2_S; + const int group32_num = nb / 32; + const int la_num = nb % 32; + const int groupla_num = nb % 32 != 0 ? 1 : 0; + const uint8x16_t mask = vdupq_n_u8(3); - for (int i=0; i < group32_num; i++) { + // 处理多行,nrc表示要处理的行数 + for (int row = 0; row < nrc; row += PARALLEL_SIZE) { + int32x4_t accu[PARALLEL_SIZE]; + const uint8_t * x_row[PARALLEL_SIZE]; + + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + accu[rb] = vdupq_n_s32(0); + x_row[rb] = x + (row + rb) * bx / 4; + } + + for (int i = 0; i < group32_num; i++) { #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accu32_0 = vdupq_n_s16(0); - int16x8_t accu32_1 = vdupq_n_s16(0); - int16x8_t accu32_2 = vdupq_n_s16(0); - int16x8_t accu32_3 = vdupq_n_s16(0); + int16x8_t accu32[PARALLEL_SIZE]; + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + accu32[rb] = vdupq_n_s16(0); + } #endif + const uint8_t * px[PARALLEL_SIZE]; + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + px[rb] = x_row[rb] + i * 32 * 16; + } + + for (int j = 0; j < 32; j++) { + // 加载 y 数据(对所有行共享) + const int8x16_t yq8_0 = vld1q_s8(y + i * 32 * 64 + j * 64 + 0); + const int8x16_t yq8_1 = vld1q_s8(y + i * 32 * 64 + j * 64 + 16); + const int8x16_t yq8_2 = vld1q_s8(y + i * 32 * 64 + j * 64 + 32); + const int8x16_t yq8_3 = vld1q_s8(y + i * 32 * 64 + j * 64 + 48); - for (int j=0; j < 32; j++) { - uint8x16_t xq8_6 = vld1q_u8(x + i * 32 * 32 + j * 32); - uint8x16_t xq8_7 = vld1q_u8(x + i * 32 * 32 + j * 32 + 16); - uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); - uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); - uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); - uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); - uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); - uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); - - int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); - int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); - int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); - int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); - int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); - int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); - int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); - int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); - - const int8x16_t yq8_0 = vld1q_s8(y + i * 128 * 32 + j * 128 + 0); - const int8x16_t yq8_1 = vld1q_s8(y + i * 128 * 32 + j * 128 + 16); - const int8x16_t yq8_2 = vld1q_s8(y + i * 128 * 32 + j * 128 + 32); - const int8x16_t yq8_3 = vld1q_s8(y + i * 128 * 32 + j * 128 + 48); - const int8x16_t yq8_4 = vld1q_s8(y + i * 128 * 32 + j * 128 + 64); - const int8x16_t yq8_5 = vld1q_s8(y + i * 128 * 32 + j * 128 + 80); - const int8x16_t yq8_6 = vld1q_s8(y + i * 128 * 32 + j * 128 + 96); - const int8x16_t yq8_7 = vld1q_s8(y + i * 128 * 32 + j * 128 + 112); + // 处理每一行 + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6); + + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); #if 
defined(__ARM_FEATURE_DOTPROD) - accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); - accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); - accu_2 = vdotq_s32(accu_2, q8_2, yq8_2); - accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); - accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); - accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); - accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); - accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); + accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0); + accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1); + accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2); + accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3); #else - accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); - accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); - accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); - accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); - accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); - accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); - accu32_0 = vmlal_s8(accu32_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); - accu32_1 = vmlal_s8(accu32_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); - accu32_2 = vmlal_s8(accu32_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); - accu32_3 = vmlal_s8(accu32_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); + accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accu32[rb] = vmlal_s8(accu32[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accu32[rb] = vmlal_s8(accu32[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0)); + +#endif + px[rb] += 16; + } + } + +#if defined(__ARM_FEATURE_DOTPROD) + +#else + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accu32[rb]))); + accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accu32[rb])); + } #endif } + for (int i = 0; i < groupla_num; i++) { #if defined(__ARM_FEATURE_DOTPROD) #else - accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accu32_0))); - accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accu32_0)); - accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accu32_1))); - accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accu32_1)); - accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accu32_2))); - accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accu32_2)); - accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accu32_3))); - accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accu32_3)); + int16x8_t accula[PARALLEL_SIZE]; + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + accula[rb] = vdupq_n_s16(0); + } #endif + const uint8_t * px[PARALLEL_SIZE]; + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + px[rb] = x_row[rb] + group32_num * 32 * 16; + } + + for (int j = 0; j < la_num; j++) { + // 加载 y 数据(对所有行共享) + const int8x16_t yq8_0 = vld1q_s8(y + 
group32_num * 32 * 64 + j * 64 + 0); + const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 16); + const int8x16_t yq8_2 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 32); + const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 32 * 64 + j * 64 + 48); + + // 处理每一行 + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + uint8x16_t xq8_3 = vld1q_u8(px[rb] + 0); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6); + + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); + +#if defined(__ARM_FEATURE_DOTPROD) + accu[rb] = vdotq_s32(accu[rb], q8_0, yq8_0); + accu[rb] = vdotq_s32(accu[rb], q8_1, yq8_1); + accu[rb] = vdotq_s32(accu[rb], q8_2, yq8_2); + accu[rb] = vdotq_s32(accu[rb], q8_3, yq8_3); +#else + accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_3), vget_high_s8(yq8_3)); + accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accula[rb] = vmlal_s8(accula[rb], vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accula[rb] = vmlal_s8(accula[rb], vget_high_s8(q8_0), vget_high_s8(yq8_0)); + +#endif + px[rb] += 16; + } + } + +#if defined(__ARM_FEATURE_DOTPROD) + +#else + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + accu[rb] = vaddq_s32(accu[rb], vmovl_s16(vget_low_s16(accula[rb]))); + accu[rb] = vaddq_s32(accu[rb], vmovl_high_s16(accula[rb])); + } +#endif + } + + // 合并结果并写回 + for (int rb = 0; rb < PARALLEL_SIZE; rb++) { + int sumi = vaddlvq_s32(accu[rb]); + s[row + rb] = (float)sumi; + } } +#endif +} + +void ggml_vec_dot_i2_i8_s_Nx1(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { +#if defined(__AVX2__) + const uint8_t * x = (uint8_t *)vx; + const int8_t * y = (int8_t *)vy; + + const int nb = n / QK_I2_S; + const int group32_num = nb / 32; + const int la_num = nb % 32; + const int groupla_num = nb % 32 != 0 ? 
1 : 0; + + __m256i mask = _mm256_set1_epi8(0x03); + __m256i one16 = _mm256_set1_epi16(1); + + for (int col = 0; col < nrc; col += PARALLEL_SIZE) { + __m256i accu[PARALLEL_SIZE]; + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu[iy] = _mm256_setzero_si256(); + } + + int8_t * y_col = y + col * by; + + for (int i = 0; i < group32_num; i++) { + const uint8_t *px = x + i * 1024; + const int8_t *py = y_col + i * 4096; + __m256i accu32[PARALLEL_SIZE]; + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu32[iy] = _mm256_setzero_si256(); + } + + for (int j = 0; j < 32; j++) { + + __m256i xq8 = _mm256_loadu_si256((const __m256i*)(px)); + __m256i xq8_3 = _mm256_and_si256(xq8, mask); + __m256i xq8_2 = _mm256_and_si256(_mm256_srli_epi16(xq8, 2), mask); + __m256i xq8_1 = _mm256_and_si256(_mm256_srli_epi16(xq8, 4), mask); + __m256i xq8_0 = _mm256_and_si256(_mm256_srli_epi16(xq8, 6), mask); + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) + { + accu32[iy] = _mm256_add_epi16(accu32[iy], _mm256_add_epi16( + _mm256_add_epi16(_mm256_maddubs_epi16(xq8_0, _mm256_loadu_si256((const __m256i*)(py + 0 * 32 + iy * by))), + _mm256_maddubs_epi16(xq8_1, _mm256_loadu_si256((const __m256i*)(py + 1 * 32 + iy * by)))), + _mm256_add_epi16(_mm256_maddubs_epi16(xq8_2, _mm256_loadu_si256((const __m256i*)(py + 2 * 32 + iy * by))), + _mm256_maddubs_epi16(xq8_3, _mm256_loadu_si256((const __m256i*)(py + 3 * 32 + iy * by)))))); + } + + px += 32; + py += 128; + } + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu[iy] = _mm256_add_epi32(_mm256_madd_epi16(accu32[iy], one16), accu[iy]); + } + } + + for (int i = 0; i < groupla_num; i++) { + const uint8_t *px = x + group32_num * 1024; + const int8_t *py = y_col + group32_num * 4096; + __m256i accula[PARALLEL_SIZE]; + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accula[iy] = _mm256_setzero_si256(); + } + + for (int j = 0; j < la_num; j++) { + + __m256i xq8 = _mm256_loadu_si256((const __m256i*)(px)); + __m256i xq8_3 = _mm256_and_si256(xq8, mask); + __m256i xq8_2 = _mm256_and_si256(_mm256_srli_epi16(xq8, 2), mask); + __m256i xq8_1 = _mm256_and_si256(_mm256_srli_epi16(xq8, 4), mask); + __m256i xq8_0 = _mm256_and_si256(_mm256_srli_epi16(xq8, 6), mask); + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) + { + accula[iy] = _mm256_add_epi16(accula[iy], _mm256_add_epi16( + _mm256_add_epi16(_mm256_maddubs_epi16(xq8_0, _mm256_loadu_si256((const __m256i*)(py + 0 * 32 + iy * by))), + _mm256_maddubs_epi16(xq8_1, _mm256_loadu_si256((const __m256i*)(py + 1 * 32 + iy * by)))), + _mm256_add_epi16(_mm256_maddubs_epi16(xq8_2, _mm256_loadu_si256((const __m256i*)(py + 2 * 32 + iy * by))), + _mm256_maddubs_epi16(xq8_3, _mm256_loadu_si256((const __m256i*)(py + 3 * 32 + iy * by)))))); + } + + px += 32; + py += 128; + } + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu[iy] = _mm256_add_epi32(_mm256_madd_epi16(accula[iy], one16), accu[iy]); + } + } + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + int sumi = hsum_i32_8(accu[iy]); + s[(col + iy) * bs] = (float)sumi; + } + } +#elif defined(__ARM_NEON) + const uint8_t * x = (uint8_t *)vx; + const int8_t * y = (int8_t *)vy; + + const int nb = n / QK_I2_S; + const int group32_num = nb / 32; + const int la_num = nb % 32; + const int groupla_num = nb % 32 != 0 ? 
1 : 0; + + const uint8x16_t mask = vdupq_n_u8(3); + + for (int col = 0; col < nrc; col += PARALLEL_SIZE) { + int32x4_t accu[PARALLEL_SIZE]; + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu[iy] = vdupq_n_s32(0); + } + + const int8_t * y_col = y + col * by; + + for (int i = 0; i < group32_num; i++) { + const uint8_t *px = x + i * 512; // i * 32 * 16 + const int8_t *py = y_col + i * 2048; // i * 32 * 64 - for (int i = 0; i < groupla_num; i++){ #if defined(__ARM_FEATURE_DOTPROD) #else - int16x8_t accula_0 = vdupq_n_s16(0); - int16x8_t accula_1 = vdupq_n_s16(0); - int16x8_t accula_2 = vdupq_n_s16(0); - int16x8_t accula_3 = vdupq_n_s16(0); + int16x8_t accu32[PARALLEL_SIZE]; + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu32[iy] = vdupq_n_s16(0); + } +#endif + for (int j = 0; j < 32; j++) { + // 加载并解包 x 数据(对所有列共享) + uint8x16_t xq8_3 = vld1q_u8(px + 0); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6); + + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + + // 处理每一列 + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by); + const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by); + const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by); + const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by); + +#if defined(__ARM_FEATURE_DOTPROD) + accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0); + accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1); + accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2); + accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3); +#else + accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accu32[iy] = vmlal_s8(accu32[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accu32[iy] = vmlal_s8(accu32[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3)); #endif - for (int j = 0; j < la_num; j++) { - uint8x16_t xq8_6 = vld1q_u8(x + group32_num * 32 * 32 + j * 32); - uint8x16_t xq8_7 = vld1q_u8(x + group32_num * 32 * 32 + j * 32 + 16); - uint8x16_t xq8_4 = vshrq_n_u8(xq8_6, 2); - uint8x16_t xq8_5 = vshrq_n_u8(xq8_7, 2); - uint8x16_t xq8_2 = vshrq_n_u8(xq8_6, 4); - uint8x16_t xq8_3 = vshrq_n_u8(xq8_7, 4); - uint8x16_t xq8_0 = vshrq_n_u8(xq8_6, 6); - uint8x16_t xq8_1 = vshrq_n_u8(xq8_7, 6); - - int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); - int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); - int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); - int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); - int8x16_t q8_4 = vreinterpretq_s8_u8(vandq_u8(xq8_4, mask)); - int8x16_t q8_5 = vreinterpretq_s8_u8(vandq_u8(xq8_5, mask)); - int8x16_t q8_6 = vreinterpretq_s8_u8(vandq_u8(xq8_6, mask)); - int8x16_t q8_7 = vreinterpretq_s8_u8(vandq_u8(xq8_7, mask)); - - const int8x16_t yq8_0 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 0); - const int8x16_t yq8_1 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 16); - const int8x16_t yq8_2 = vld1q_s8(y + 
group32_num * 128 * 32 + j * 128 + 32); - const int8x16_t yq8_3 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 48); - const int8x16_t yq8_4 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 64); - const int8x16_t yq8_5 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 80); - const int8x16_t yq8_6 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 96); - const int8x16_t yq8_7 = vld1q_s8(y + group32_num * 128 * 32 + j * 128 + 112); + } + + px += 16; + py += 64; + } #if defined(__ARM_FEATURE_DOTPROD) - accu_0 = vdotq_s32(accu_0, q8_0, yq8_0); - accu_1 = vdotq_s32(accu_1, q8_1, yq8_1); - accu_2 = vdotq_s32(accu_2, q8_2, yq8_2); - accu_3 = vdotq_s32(accu_3, q8_3, yq8_3); - accu_0 = vdotq_s32(accu_0, q8_4, yq8_4); - accu_1 = vdotq_s32(accu_1, q8_5, yq8_5); - accu_2 = vdotq_s32(accu_2, q8_6, yq8_6); - accu_3 = vdotq_s32(accu_3, q8_7, yq8_7); + #else - accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_0), vget_low_s8(yq8_0)); - accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_0), vget_high_s8(yq8_0)); - accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_1), vget_low_s8(yq8_1)); - accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_1), vget_high_s8(yq8_1)); - accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_2), vget_low_s8(yq8_2)); - accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_2), vget_high_s8(yq8_2)); - accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_3), vget_low_s8(yq8_3)); - accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_3), vget_high_s8(yq8_3)); - accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_4), vget_low_s8(yq8_4)); - accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_4), vget_high_s8(yq8_4)); - accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_5), vget_low_s8(yq8_5)); - accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_5), vget_high_s8(yq8_5)); - accula_0 = vmlal_s8(accula_0, vget_low_s8(q8_6), vget_low_s8(yq8_6)); - accula_1 = vmlal_s8(accula_1, vget_high_s8(q8_6), vget_high_s8(yq8_6)); - accula_2 = vmlal_s8(accula_2, vget_low_s8(q8_7), vget_low_s8(yq8_7)); - accula_3 = vmlal_s8(accula_3, vget_high_s8(q8_7), vget_high_s8(yq8_7)); + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accu32[iy]), vmovl_s16(vget_low_s16(accu32[iy])))); + } #endif } + + for (int i = 0; i < groupla_num; i++) { + const uint8_t *px = x + group32_num * 512; + const int8_t *py = y_col + group32_num * 2048; + +#if defined(__ARM_FEATURE_DOTPROD) + +#else + int16x8_t accula[PARALLEL_SIZE]; + + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accula[iy] = vdupq_n_s16(0); + } +#endif + + for (int j = 0; j < la_num; j++) { + // 加载并解包 x 数据(对所有列共享) + uint8x16_t xq8_3 = vld1q_u8(px + 0); + uint8x16_t xq8_2 = vshrq_n_u8(xq8_3, 2); + uint8x16_t xq8_1 = vshrq_n_u8(xq8_3, 4); + uint8x16_t xq8_0 = vshrq_n_u8(xq8_3, 6); + + int8x16_t q8_0 = vreinterpretq_s8_u8(vandq_u8(xq8_0, mask)); + int8x16_t q8_1 = vreinterpretq_s8_u8(vandq_u8(xq8_1, mask)); + int8x16_t q8_2 = vreinterpretq_s8_u8(vandq_u8(xq8_2, mask)); + int8x16_t q8_3 = vreinterpretq_s8_u8(vandq_u8(xq8_3, mask)); + + // 处理每一列 + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + const int8x16_t yq8_0 = vld1q_s8(py + 0 * 16 + iy * by); + const int8x16_t yq8_1 = vld1q_s8(py + 1 * 16 + iy * by); + const int8x16_t yq8_2 = vld1q_s8(py + 2 * 16 + iy * by); + const int8x16_t yq8_3 = vld1q_s8(py + 3 * 16 + iy * by); + +#if defined(__ARM_FEATURE_DOTPROD) + accu[iy] = vdotq_s32(accu[iy], q8_0, yq8_0); + accu[iy] = vdotq_s32(accu[iy], q8_1, yq8_1); + accu[iy] = vdotq_s32(accu[iy], q8_2, yq8_2); + accu[iy] = vdotq_s32(accu[iy], q8_3, yq8_3); +#else + accula[iy] = 
vmlal_s8(accula[iy], vget_low_s8(q8_0), vget_low_s8(yq8_0)); + accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_0), vget_high_s8(yq8_0)); + accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_1), vget_low_s8(yq8_1)); + accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_1), vget_high_s8(yq8_1)); + accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_2), vget_low_s8(yq8_2)); + accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_2), vget_high_s8(yq8_2)); + accula[iy] = vmlal_s8(accula[iy], vget_low_s8(q8_3), vget_low_s8(yq8_3)); + accula[iy] = vmlal_s8(accula[iy], vget_high_s8(q8_3), vget_high_s8(yq8_3)); +#endif + } + + px += 16; + py += 64; + } + #if defined(__ARM_FEATURE_DOTPROD) #else - accu_0 = vaddq_s32(accu_0, vmovl_s16(vget_low_s16(accula_0))); - accu_0 = vaddq_s32(accu_0, vmovl_high_s16(accula_0)); - accu_1 = vaddq_s32(accu_1, vmovl_s16(vget_low_s16(accula_1))); - accu_1 = vaddq_s32(accu_1, vmovl_high_s16(accula_1)); - accu_2 = vaddq_s32(accu_2, vmovl_s16(vget_low_s16(accula_2))); - accu_2 = vaddq_s32(accu_2, vmovl_high_s16(accula_2)); - accu_3 = vaddq_s32(accu_3, vmovl_s16(vget_low_s16(accula_3))); - accu_3 = vaddq_s32(accu_3, vmovl_high_s16(accula_3)); + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + accu[iy] = vaddq_s32(accu[iy], vaddq_s32(vmovl_high_s16(accula[iy]), vmovl_s16(vget_low_s16(accula[iy])))); + } #endif + } + + // Combine the per-column results and write them back + for (int iy = 0; iy < PARALLEL_SIZE; iy++) { + int sumi = vaddlvq_s32(accu[iy]); + s[(col + iy) * bs] = (float)sumi; + } } - accu_0 = vaddq_s32(accu_0, accu_1); - accu_2 = vaddq_s32(accu_2, accu_3); - accu_0 = vaddq_s32(accu_0, accu_2); - int sumi = vaddlvq_s32(accu_0); - *s = (float)sumi; +#endif +} + +void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc) { + if (nrc % PARALLEL_SIZE == 0) + { +#if defined(ACT_PARALLEL) + ggml_vec_dot_i2_i8_s_Nx1(n, s, bs, vx, bx, vy, by, nrc); +#else + ggml_vec_dot_i2_i8_s_1xN(n, s, bs, vx, bx, vy, by, nrc); #endif + } + else + { + ggml_vec_dot_i2_i8_s_1x1(n, s, bs, vx, bx, vy, by, nrc); + } } \ No newline at end of file diff --git a/utils/convert-helper-bitnet.py b/utils/convert-helper-bitnet.py new file mode 100644 index 000000000..9ed8db013 --- /dev/null +++ b/utils/convert-helper-bitnet.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 + +import sys +import os +import shutil +import subprocess +from pathlib import Path + +def run_command(command_list, cwd=None, check=True): + print(f"Executing: {' '.join(map(str, command_list))}") + try: + process = subprocess.run(command_list, cwd=cwd, check=check, capture_output=False, text=True) + return process + except subprocess.CalledProcessError as e: + print(f"Error executing command: {' '.join(map(str, e.cmd))}") + print(f"Return code: {e.returncode}") + raise + +def main(): + if len(sys.argv) < 2: + script_name = Path(sys.argv[0]).name + print(f"Usage: python {script_name} <model_dir>") + sys.exit(1) + + model_dir_arg = sys.argv[1] + model_dir = Path(model_dir_arg).resolve() + + if not model_dir.is_dir(): + print(f"Error: Model directory '{model_dir}' not found or is not a directory.") + sys.exit(1) + + utils_dir = Path(__file__).parent.resolve() + project_root_dir = utils_dir.parent + + preprocess_script = utils_dir / "preprocess-huggingface-bitnet.py" + convert_script = utils_dir / "convert-ms-to-gguf-bitnet.py" + + llama_quantize_binary = project_root_dir / "build" / "bin" / "llama-quantize" + + input_file = model_dir / "model.safetensors" + input_backup_file = model_dir / "model.safetensors.backup" + 
preprocessed_output_file = model_dir / "model.safetensors" + + gguf_f32_output = model_dir / "ggml-model-f32-bitnet.gguf" + gguf_i2s_output = model_dir / "ggml-model-i2s-bitnet.gguf" + + if not preprocess_script.is_file(): + print(f"Error: Preprocess script not found at '{preprocess_script}'") + sys.exit(1) + if not convert_script.is_file(): + print(f"Error: Convert script not found at '{convert_script}'") + sys.exit(1) + if not llama_quantize_binary.is_file(): + print(f"Error: llama-quantize binary not found at '{llama_quantize_binary}'") + sys.exit(1) + + if not input_file.is_file(): + print(f"Error: Input safetensors file not found at '{input_file}'") + sys.exit(1) + + try: + print(f"Backing up '{input_file}' to '{input_backup_file}'") + if input_backup_file.exists(): + print(f"Warning: Removing existing backup file '{input_backup_file}'") + input_backup_file.unlink() + shutil.move(input_file, input_backup_file) + + print("Preprocessing huggingface checkpoint...") + cmd_preprocess = [ + sys.executable, + str(preprocess_script), + "--input", str(input_backup_file), + "--output", str(preprocessed_output_file) + ] + run_command(cmd_preprocess) + + print("Converting to GGUF (f32)...") + cmd_convert = [ + sys.executable, + str(convert_script), + str(model_dir), + "--vocab-type", "bpe", + "--outtype", "f32", + "--concurrency", "1", + "--outfile", str(gguf_f32_output) + ] + run_command(cmd_convert) + + print("Quantizing model to I2_S...") + cmd_quantize = [ + str(llama_quantize_binary), + str(gguf_f32_output), + str(gguf_i2s_output), + "I2_S", + "1" + ] + run_command(cmd_quantize) + + print("Convert successfully.") + + except Exception as e: + print(f"An error occurred: {e}") + finally: + print("Cleaning up intermediate files...") + if preprocessed_output_file.exists() and preprocessed_output_file != input_backup_file: + print(f"Removing preprocessed file: {preprocessed_output_file}") + try: + preprocessed_output_file.unlink() + except OSError as e: + print(f"Warning: Could not remove {preprocessed_output_file}: {e}") + + # if gguf_f32_output.exists(): + # print(f"Removing f32 GGUF: {gguf_f32_output}") + # try: + # gguf_f32_output.unlink() + # except OSError as e: + # print(f"Warning: Could not remove {gguf_f32_output}: {e}") + + if input_backup_file.exists(): + if not input_file.exists(): + print(f"Restoring original '{input_file}' from '{input_backup_file}'") + try: + shutil.move(input_backup_file, input_file) + except Exception as e: + print(f"Warning: Could not restore {input_file} from backup: {e}") + else: + print(f"Removing backup '{input_backup_file}' as original '{input_file}' should be present.") + try: + input_backup_file.unlink() + except OSError as e: + print(f"Warning: Could not remove backup {input_backup_file}: {e}") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/utils/convert-hf-to-gguf-bitnet.py b/utils/convert-hf-to-gguf-bitnet.py index f525f58f8..23e84384c 100644 --- a/utils/convert-hf-to-gguf-bitnet.py +++ b/utils/convert-hf-to-gguf-bitnet.py @@ -319,6 +319,9 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "8aeee3860c56296a157a1fe2fad249ec40aa59b1bb5709f4ade11c4e6fe652ed": # ref: https://huggingface.co/tiiuae/falcon-7b res = "falcon" + if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6": + # ref: https://huggingface.co/tiiuae/Falcon-E-3B-Instruct + res = "falcon_e" if chkhsh == "0876d13b50744004aa9aeae05e7b0647eac9d801b5ba4668afc01e709c15e19f": # ref: https://huggingface.co/BAAI/bge-small-en-v1.5 
res = "bert-bge" diff --git a/utils/convert-ms-to-gguf-bitnet.py b/utils/convert-ms-to-gguf-bitnet.py index 23a1a2c89..edf702788 100644 --- a/utils/convert-ms-to-gguf-bitnet.py +++ b/utils/convert-ms-to-gguf-bitnet.py @@ -12,14 +12,12 @@ import math import mmap import os -import pickle import re import signal import struct import sys import textwrap import time -import zipfile from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass @@ -945,7 +943,6 @@ def load() -> Tensor: import torch -@torch.compile def forward_t(x): dtype = x.dtype x = x.float() @@ -956,7 +953,8 @@ def forward_t(x): def weight_quant(weight): weight = torch.tensor(weight, dtype=torch.float32) weight = forward_t(weight) - weight = weight.numpy().astype(np.float32) + # Use tolist() then convert to numpy to avoid PyTorch-NumPy compatibility issues + weight = np.array(weight.tolist(), dtype=np.float32) return weight def part_lazy_q(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: @@ -1028,98 +1026,6 @@ def load() -> Tensor: return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors)) -# Functionality that simulates `torch.load` but where individual tensors are -# only loaded into memory on demand, not all at once. -# PyTorch can't do this natively as of time of writing: -# - https://github.com/pytorch/pytorch/issues/64327 -# This allows us to de-shard without multiplying RAM usage, and also -# conveniently drops the PyTorch dependency (though we still need numpy). - - -@dataclass -class LazyStorageKind: - data_type: DataType - - -@dataclass -class LazyStorage: - load: Callable[[int, int], NDArray] - kind: LazyStorageKind - description: str - - -class LazyUnpickler(pickle.Unpickler): - def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile): - super().__init__(fp) - self.data_base_path = data_base_path - self.zip_file = zip_file - - def persistent_load(self, pid: Any) -> Any: - assert pid[0] == 'storage' - assert isinstance(pid[1], LazyStorageKind) - data_type = pid[1].data_type - filename_stem = pid[2] - filename = f'{self.data_base_path}/{filename_stem}' - info = self.zip_file.getinfo(filename) - - def load(offset: int, elm_count: int) -> NDArray: - dtype = data_type.dtype - with self.zip_file.open(info) as fp: - fp.seek(offset * dtype.itemsize) - size = elm_count * dtype.itemsize - data = fp.read(size) - assert len(data) == size - return np.frombuffer(data, dtype) - description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' - return LazyStorage(load=load, kind=pid[1], description=description) - - @staticmethod - def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, - requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: - assert isinstance(storage, LazyStorage) - - def load() -> UnquantizedTensor: - elm_count = stride[0] * size[0] - return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) - description = f'pickled storage_offset={storage_offset} in {storage.description}' - return LazyTensor(load, list(size), storage.kind.data_type, description) - - @staticmethod - def rebuild_from_type_v2(func, new_type, args, state): - return func(*args) - - CLASSES = { - # getattr used here as a workaround for mypy not being smart enough to determine - # the staticmethods have a __func__ attribute. 
- ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), - ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), - ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), - ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), - ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), - ('torch', 'IntStorage'): LazyStorageKind(DT_I32), - ('torch', 'Tensor'): LazyTensor, - } - - def find_class(self, module: str, name: str) -> Any: - if not module.startswith('torch'): - return super().find_class(module, name) - return self.CLASSES[(module, name)] - - -def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: - zf = zipfile.ZipFile(outer_fp) - pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] - assert len(pickle_paths) == 1, pickle_paths - pickle_fp = zf.open(pickle_paths[0], 'r') - unpickler = LazyUnpickler(pickle_fp, - data_base_path=pickle_paths[0][:-4], - zip_file=zf) - model = unpickler.load() - if 'model' in model: model = model['model'] - as_dict = dict(model.items()) - return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) - - def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: header_size, = struct.unpack(' ModelPlus: fp = open(path, 'rb') first8 = fp.read(8) fp.seek(0) - if first8[:2] == b'PK': - # A zip file, i.e. PyTorch format - return lazy_load_torch_file(fp, path) - elif struct.unpack(' # tmp[f"layers.{i_l}.feed_forward.experts.w{w}.weight"] = pack_experts_lazy(experts) # tmp[f"rope.freqs"] = part_lazy_rope(1.0 / (torch.tensor(500000) ** (torch.arange(0, 128, 2).float().to("cpu") / 128))) # 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) - rope_ndarray = (1.0 / (torch.tensor(500000.0) ** (torch.arange(0, 128, 2).float() / 128))).numpy().astype(np.float32) + # Use pure NumPy instead of torch to avoid NumPy compatibility issues + rope_ndarray = (1.0 / (np.float32(500000.0) ** (np.arange(0, 128, 2, dtype=np.float32) / 128))).astype(np.float32) # print(rope_ndarray) @@ -1580,7 +1487,7 @@ def load() -> UnquantizedTensor: out: LazyModel = {} for name, lazy_tensor in model.items(): - tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias")) or (None, None) + tensor_type, name_new = tmap.get_type_and_name(name, try_suffixes = (".weight", ".bias", ".weight_scale")) or (None, None) if name_new is None: if skip_unknown: logger.info(f"Unexpected tensor name: {name} - skipping") @@ -1641,15 +1548,11 @@ def load_some_model(path: Path) -> ModelPlus: '''Load a model of any supported format.''' # Be extra-friendly and accept either a file or a directory: if path.is_dir(): - # Check if it's a set of safetensors files first - globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors", "model-int2.pth"] + # Check if it's a set of safetensors files + globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors"] files = [file for glob in globs for file in path.glob(glob)] if not files: - # Try the PyTorch patterns too, with lower priority - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] - files = [file for glob in globs for file in path.glob(glob)] - if not files: - raise FileNotFoundError(f"Can't find model in directory {path}") + raise FileNotFoundError(f"Can't find safetensors model in directory {path}") if len(files) > 1: raise ValueError(f"Found multiple models in {path}, not sure which to pick: 
{files}") path = files[0] @@ -1741,7 +1644,7 @@ def do_dump_model(model_plus: ModelPlus) -> None: def main(args_in: list[str] | None = None) -> None: output_choices = ["f32", "f16", "i2"] - if np.uint32(1) == np.uint32(1).newbyteorder("<"): + if sys.byteorder == "little": # We currently only support Q8_0 output on little endian systems. output_choices.append("q8_0") parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file") @@ -1849,4 +1752,4 @@ def main(args_in: list[str] | None = None) -> None: if __name__ == '__main__': - main() + main() \ No newline at end of file diff --git a/utils/convert.py b/utils/convert.py index 5938c42f2..f0f689724 100644 --- a/utils/convert.py +++ b/utils/convert.py @@ -12,14 +12,12 @@ import math import mmap import os -import pickle import re import signal import struct import sys import textwrap import time -import zipfile from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor from dataclasses import dataclass @@ -954,98 +952,6 @@ def load() -> Tensor: return LazyTensor(load, s, lazy_tensors[0].data_type, 'pack_experts ' + ' | '.join(lt.description for lt in lazy_tensors)) -# Functionality that simulates `torch.load` but where individual tensors are -# only loaded into memory on demand, not all at once. -# PyTorch can't do this natively as of time of writing: -# - https://github.com/pytorch/pytorch/issues/64327 -# This allows us to de-shard without multiplying RAM usage, and also -# conveniently drops the PyTorch dependency (though we still need numpy). - - -@dataclass -class LazyStorageKind: - data_type: DataType - - -@dataclass -class LazyStorage: - load: Callable[[int, int], NDArray] - kind: LazyStorageKind - description: str - - -class LazyUnpickler(pickle.Unpickler): - def __init__(self, fp: IO[bytes], data_base_path: str, zip_file: zipfile.ZipFile): - super().__init__(fp) - self.data_base_path = data_base_path - self.zip_file = zip_file - - def persistent_load(self, pid: Any) -> Any: - assert pid[0] == 'storage' - assert isinstance(pid[1], LazyStorageKind) - data_type = pid[1].data_type - filename_stem = pid[2] - filename = f'{self.data_base_path}/{filename_stem}' - info = self.zip_file.getinfo(filename) - - def load(offset: int, elm_count: int) -> NDArray: - dtype = data_type.dtype - with self.zip_file.open(info) as fp: - fp.seek(offset * dtype.itemsize) - size = elm_count * dtype.itemsize - data = fp.read(size) - assert len(data) == size - return np.frombuffer(data, dtype) - description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' - return LazyStorage(load=load, kind=pid[1], description=description) - - @staticmethod - def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, - requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: - assert isinstance(storage, LazyStorage) - - def load() -> UnquantizedTensor: - elm_count = stride[0] * size[0] - return UnquantizedTensor(storage.load(storage_offset, elm_count).reshape(size)) - description = f'pickled storage_offset={storage_offset} in {storage.description}' - return LazyTensor(load, list(size), storage.kind.data_type, description) - - @staticmethod - def rebuild_from_type_v2(func, new_type, args, state): - return func(*args) - - CLASSES = { - # getattr used here as a workaround for mypy not being smart enough to determine - # the staticmethods have a __func__ attribute. 
- ('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'), - ('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'), - ('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16), - ('torch', 'HalfStorage'): LazyStorageKind(DT_F16), - ('torch', 'FloatStorage'): LazyStorageKind(DT_F32), - ('torch', 'IntStorage'): LazyStorageKind(DT_I32), - ('torch', 'Tensor'): LazyTensor, - } - - def find_class(self, module: str, name: str) -> Any: - if not module.startswith('torch'): - return super().find_class(module, name) - return self.CLASSES[(module, name)] - - -def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: - zf = zipfile.ZipFile(outer_fp) - pickle_paths = [name for name in zf.namelist() if name.endswith('.pkl')] - assert len(pickle_paths) == 1, pickle_paths - pickle_fp = zf.open(pickle_paths[0], 'r') - unpickler = LazyUnpickler(pickle_fp, - data_base_path=pickle_paths[0][:-4], - zip_file=zf) - model = unpickler.load() - if 'model' in model: model = model['model'] - as_dict = dict(model.items()) - return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) - - def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: header_size, = struct.unpack(' ModelPlus: fp = open(path, 'rb') first8 = fp.read(8) fp.seek(0) - if first8[:2] == b'PK': - # A zip file, i.e. PyTorch format - return lazy_load_torch_file(fp, path) - elif struct.unpack(' ModelPlus: '''Load a model of any supported format.''' # Be extra-friendly and accept either a file or a directory: if path.is_dir(): - # Check if it's a set of safetensors files first - globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors", "model-int2.pth"] + # Check if it's a set of safetensors files + globs = ["model-00001-of-*.safetensors", "model.safetensors", "consolidated.safetensors"] files = [file for glob in globs for file in path.glob(glob)] if not files: - # Try the PyTorch patterns too, with lower priority - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] - files = [file for glob in globs for file in path.glob(glob)] - if not files: - raise FileNotFoundError(f"Can't find model in directory {path}") + raise FileNotFoundError(f"Can't find safetensors model in directory {path}") if len(files) > 1: raise ValueError(f"Found multiple models in {path}, not sure which to pick: {files}") path = files[0] @@ -1600,7 +1499,7 @@ def do_dump_model(model_plus: ModelPlus) -> None: def main(args_in: list[str] | None = None) -> None: output_choices = ["f32", "f16", "i2"] - if np.uint32(1) == np.uint32(1).newbyteorder("<"): + if sys.byteorder == "little": # We currently only support Q8_0 output on little endian systems. 
output_choices.append("q8_0") parser = argparse.ArgumentParser(description="Convert a LLaMA model to a GGML compatible file") diff --git a/utils/kernel_tuning.py b/utils/kernel_tuning.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/utils/preprocess-huggingface-bitnet.py b/utils/preprocess-huggingface-bitnet.py new file mode 100644 index 000000000..af75cd6df --- /dev/null +++ b/utils/preprocess-huggingface-bitnet.py @@ -0,0 +1,50 @@ +from safetensors import safe_open +from safetensors.torch import save_file +import torch + +def quant_weight_fp16(weight): + weight = weight.to(torch.float) + s = 1.0 / weight.abs().mean().clamp_(min=1e-5) + new_weight = (weight * s).round().clamp(-1, 1) / s + return new_weight + +def quant_model(input, output): + tensors = {} + + with safe_open(input, framework='pt') as f: + for name in f.keys(): + tensors[name] = f.get_tensor(name) + + keyword_list = [ + 'q_proj.weight', + 'k_proj.weight', + 'v_proj.weight', + 'o_proj.weight', + 'gate_proj.weight', + 'up_proj.weight', + 'down_proj.weight' + ] + + if any(keyword in name for keyword in keyword_list): + print(f'[INFO] Quantizing {name}') + tensors[name] = quant_weight_fp16(tensors[name]) + + print(f'[INFO] Saving to {output}\nThis may take a while.') + save_file(tensors, output) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Convert Safetensors back to Torch .pth checkpoint") + parser.add_argument( + "--input", type=str, required=True, + ) + parser.add_argument( + "--output", type=str, required=True, + ) + args = parser.parse_args() + + quant_model( + input=args.input, + output=args.output, + ) \ No newline at end of file diff --git a/utils/quantize_embeddings.py b/utils/quantize_embeddings.py new file mode 100644 index 000000000..90b802045 --- /dev/null +++ b/utils/quantize_embeddings.py @@ -0,0 +1,473 @@ +#!/usr/bin/env python3 +""" +Embedding Quantization Script +This script converts ggml-model-f32.gguf to multiple quantized versions +with different token embedding types. 
+""" + +import subprocess +import os +import argparse +import re +import csv +from pathlib import Path +from datetime import datetime + + +class EmbeddingQuantizer: + def __init__(self, input_model, output_dir, quantize_bin="../build/bin/llama-quantize", + bench_bin="../build/bin/llama-bench", stats_dir="../stats", csv_output=None): + self.input_model = Path(input_model) + self.output_dir = Path(output_dir) + self.quantize_bin = Path(quantize_bin) + self.bench_bin = Path(bench_bin) + self.stats_dir = Path(stats_dir) + self.csv_output = Path(csv_output) if csv_output else None + + # Verify input file exists + if not self.input_model.exists(): + raise FileNotFoundError(f"Input model not found: {self.input_model}") + + # Verify quantize tool exists + if not self.quantize_bin.exists(): + raise FileNotFoundError(f"Quantize binary not found: {self.quantize_bin}") + + # Verify bench tool exists + if not self.bench_bin.exists(): + raise FileNotFoundError(f"Benchmark binary not found: {self.bench_bin}") + + # Create output directories + self.output_dir.mkdir(parents=True, exist_ok=True) + self.stats_dir.mkdir(parents=True, exist_ok=True) + + self.results = [] + self.newly_created_files = set() # Track newly created files + + def quantize(self, embedding_type, output_suffix): + """ + Perform single quantization + + Args: + embedding_type: Token embedding type (uppercase format, e.g., Q6_K) + output_suffix: Output file suffix (lowercase format, e.g., q6_k) + + Returns: + bool: Whether successful + """ + output_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf" + + # Check if file already exists + file_already_existed = output_file.exists() + + if file_already_existed: + print(f"ℹ️ File already exists: {output_file}") + print(f" Skipping quantization, will use existing file for benchmark") + return True + + cmd = [ + str(self.quantize_bin), + "--token-embedding-type", embedding_type, + str(self.input_model), + str(output_file), + "I2_S", + "1", + "1" + ] + + print(f"\n{'='*80}") + print(f"🔄 Quantizing with embedding type: {embedding_type}") + print(f"📥 Input: {self.input_model}") + print(f"📤 Output: {output_file}") + print(f"💻 Command: {' '.join(cmd)}") + print(f"{'='*80}\n") + + start_time = datetime.now() + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + cwd=os.getcwd(), + timeout=600 # 10 minute timeout + ) + + end_time = datetime.now() + duration = (end_time - start_time).total_seconds() + + if result.returncode == 0: + # Get output file size + file_size_mb = output_file.stat().st_size / (1024 * 1024) + + print(f"✅ Success! 
Duration: {duration:.2f}s, Size: {file_size_mb:.2f} MB") + + # Record newly created file + if not file_already_existed: + self.newly_created_files.add(output_file) + + # Print part of output + if result.stdout: + print("\n📊 Quantization output:") + print(result.stdout[-500:] if len(result.stdout) > 500 else result.stdout) + + return True + else: + print(f"❌ Failed with return code {result.returncode}") + print(f"Error: {result.stderr}") + return False + + except subprocess.TimeoutExpired: + print(f"❌ Timeout (exceeded 10 minutes)") + return False + + except Exception as e: + print(f"❌ Exception: {e}") + return False + + def benchmark_model(self, output_suffix): + """ + Benchmark model + + Args: + output_suffix: Output file suffix (lowercase format, e.g., q6_k) + + Returns: + dict: Dictionary with benchmark results, or None if failed + """ + model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf" + + if not model_file.exists(): + print(f"❌ Model file not found for benchmarking: {model_file}") + return None + + cmd = [ + str(self.bench_bin), + "-m", str(model_file), + "-p", "128", + "-n", "0", + "-t", "1,2,4,8", + "-ngl", "0" + ] + + print(f"\n{'='*80}") + print(f"🏃 Running benchmark for: {output_suffix}") + print(f"💻 Command: {' '.join(cmd)}") + print(f"{'='*80}\n") + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + cwd=os.getcwd(), + timeout=300 # 5 minute timeout + ) + + if result.returncode == 0: + print("✅ Benchmark completed successfully") + print("\n📊 Benchmark output:") + print(result.stdout) + + # Parse the benchmark output + bench_results = self.parse_benchmark_output(result.stdout, output_suffix) + return bench_results + else: + print(f"❌ Benchmark failed with return code {result.returncode}") + print(f"Error: {result.stderr}") + return None + + except subprocess.TimeoutExpired: + print(f"❌ Benchmark timeout (exceeded 5 minutes)") + return None + + except Exception as e: + print(f"❌ Benchmark exception: {e}") + return None + + def parse_benchmark_output(self, output, output_suffix): + """ + Parse benchmark output to extract t/s data (mean±std) + + Args: + output: Benchmark command output + output_suffix: Output file suffix + + Returns: + dict: Dictionary with parsed results + """ + results = { + 'embedding_type': output_suffix, + 'threads_1': None, + 'threads_2': None, + 'threads_4': None, + 'threads_8': None, + } + + # Parse table data + # Find lines containing pp128 and t/s + lines = output.strip().split('\n') + + for line in lines: + # Skip header and separator lines + if '|' not in line or 'model' in line or '---' in line: + continue + + # Try to extract data + # Format similar to: | bitnet-25 2B I2_S - 2 bpw ternary | 1012.28 MiB | 2.74 B | CPU | 12 | pp128 | 405.73 ± 3.69 | + parts = [p.strip() for p in line.split('|')] + + if len(parts) >= 8 and 'pp128' in parts[6]: + threads_str = parts[5].strip() + throughput_str = parts[7].strip() + + # Extract thread count + try: + threads = int(threads_str) + except: + continue + + # Extract t/s data (format: "405.73 ± 3.69" or "405.73") + # Try to match "mean ± std" format + match_with_std = re.search(r'([\d.]+)\s*±\s*([\d.]+)', throughput_str) + if match_with_std: + mean = float(match_with_std.group(1)) + std = float(match_with_std.group(2)) + throughput = f"{mean:.2f}±{std:.2f}" + else: + # Only mean, no std + match = re.search(r'([\d.]+)', throughput_str) + if match: + throughput = f"{float(match.group(1)):.2f}" + else: + continue + + # Store result based on thread count + if threads == 1: + 
results['threads_1'] = throughput + elif threads == 2: + results['threads_2'] = throughput + elif threads == 4: + results['threads_4'] = throughput + elif threads == 8: + results['threads_8'] = throughput + + return results + + def cleanup_model(self, output_suffix): + """ + Cleanup model files (only delete newly created files) + + Args: + output_suffix: Output file suffix + """ + model_file = self.output_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf" + + if model_file in self.newly_created_files: + try: + model_file.unlink() + print(f"🗑️ Deleted newly created file: {model_file}") + self.newly_created_files.remove(model_file) + except Exception as e: + print(f"⚠️ Failed to delete {model_file}: {e}") + else: + print(f"ℹ️ Keeping existing file: {model_file}") + + def run_all_quantizations(self, types_to_quantize): + """ + Run all quantizations + + Args: + types_to_quantize: List of quantization types, tuples of (embedding_type, output_suffix) + """ + print(f"\n{'='*80}") + print(f"🚀 Starting Embedding Quantization and Benchmarking") + print(f"{'='*80}") + print(f"📥 Input model: {self.input_model}") + print(f"📤 Output directory: {self.output_dir}") + print(f"📊 Stats directory: {self.stats_dir}") + print(f"🔢 Total quantizations: {len(types_to_quantize)}") + print(f"{'='*80}\n") + + total_start = datetime.now() + + for i, (embedding_type, output_suffix) in enumerate(types_to_quantize, 1): + print(f"\n{'#'*80}") + print(f"[{i}/{len(types_to_quantize)}] Processing {output_suffix} ({embedding_type})") + print(f"{'#'*80}\n") + + # Quantize model + success = self.quantize(embedding_type, output_suffix) + + if not success: + print(f"⚠️ Skipping benchmark for {output_suffix} due to quantization failure") + continue + + # Run benchmark + bench_results = self.benchmark_model(output_suffix) + + if bench_results: + self.results.append(bench_results) + else: + print(f"⚠️ Benchmark failed for {output_suffix}") + + # Cleanup model files (only delete newly created files) + self.cleanup_model(output_suffix) + + print(f"\n{'#'*80}") + print(f"✅ Completed {output_suffix}") + print(f"{'#'*80}\n") + + total_end = datetime.now() + total_duration = (total_end - total_start).total_seconds() + + # Save results to CSV + self.save_results_to_csv() + + # Print summary + self.print_summary(total_duration) + + def save_results_to_csv(self): + """Save the benchmark results to a CSV file.""" + if not self.results: + print("⚠️ No results to save") + return + + # Use user-specified CSV path, otherwise use default path + if self.csv_output: + csv_file = self.csv_output + # Ensure parent directory exists + csv_file.parent.mkdir(parents=True, exist_ok=True) + else: + csv_file = self.stats_dir / f"embedding_benchmark.csv" + + print(f"\n💾 Saving results to: {csv_file}") + + try: + with open(csv_file, 'w', newline='') as f: + fieldnames = ['embedding_type', 'threads_1', 'threads_2', 'threads_4', 'threads_8'] + writer = csv.DictWriter(f, fieldnames=fieldnames) + + writer.writeheader() + for result in self.results: + writer.writerow(result) + + print(f"✅ Results saved successfully") + + # Also print table + print(f"\n📊 Benchmark Results:") + print(f"{'Type':<15} {'1 thread':<18} {'2 threads':<18} {'4 threads':<18} {'8 threads':<18}") + print("-" * 87) + for result in self.results: + t1 = result['threads_1'] if result['threads_1'] else "N/A" + t2 = result['threads_2'] if result['threads_2'] else "N/A" + t4 = result['threads_4'] if result['threads_4'] else "N/A" + t8 = result['threads_8'] if result['threads_8'] else "N/A" + print(f"{result['embedding_type']:<15} {t1:<18} 
{t2:<18} {t4:<18} {t8:<18}") + + except Exception as e: + print(f"❌ Failed to save results: {e}") + + def print_summary(self, total_duration): + """Print quantization summary""" + print(f"\n\n{'='*80}") + print(f"📊 QUANTIZATION AND BENCHMARK SUMMARY") + print(f"{'='*80}\n") + + successful = len(self.results) + total = len(self.results) + + print(f"✅ Completed: {successful} benchmarks") + print(f"⏱️ Total duration: {total_duration/60:.2f} minutes\n") + + if self.results: + if self.csv_output and self.csv_output.exists(): + print(f"📁 Results saved to: {self.csv_output}") + else: + csv_files = list(self.stats_dir.glob("embedding_benchmark*.csv")) + if csv_files: + latest_csv = max(csv_files, key=lambda p: p.stat().st_mtime) + print(f"📁 Results saved to: {latest_csv}") + + print(f"\n{'='*80}\n") + + +def main(): + parser = argparse.ArgumentParser(description='Quantize model embeddings to multiple formats') + parser.add_argument('--input', '-i', + default='../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf', + help='Input model path (default: ../models/BitNet-b1.58-2B-4T/ggml-model-f32.gguf)') + parser.add_argument('--output-dir', '-o', + default='../models/BitNet-b1.58-2B-4T', + help='Output directory (default: ../models/BitNet-b1.58-2B-4T)') + parser.add_argument('--quantize-bin', '-q', + default='../build/bin/llama-quantize', + help='Path to llama-quantize binary (default: ../build/bin/llama-quantize)') + parser.add_argument('--bench-bin', '-b', + default='../build/bin/llama-bench', + help='Path to llama-bench binary (default: ../build/bin/llama-bench)') + parser.add_argument('--stats-dir', + default='../stats', + help='Directory to save benchmark results (default: ../stats)') + parser.add_argument('--csv-output', '-c', + help='Custom path for CSV output file (e.g., stats/my_results.csv)') + parser.add_argument('--types', '-t', + nargs='+', + help='Specific types to quantize (e.g., f32 q6_k q4_0)') + parser.add_argument('--skip-existing', '-s', + action='store_true', + help='Skip quantization if output file already exists (will still benchmark existing files)') + + args = parser.parse_args() + + # Define all supported quantization types + # Format: (embedding_type for command line, output_suffix for filename) + all_types = [ + ('F32', 'f32'), + ('F16', 'f16'), + ('Q8_0', 'q8_0'), + ('Q6_K', 'q6_k'), + ('Q5_0', 'q5_0'), + ('Q4_0', 'q4_0'), + ('Q3_K', 'q3_k'), + ('TQ2_0', 'tq2_0'), + ] + + # If specific types are specified, filter the list + if args.types: + types_lower = [t.lower() for t in args.types] + types_to_quantize = [(et, os) for et, os in all_types if os.lower() in types_lower] + if not types_to_quantize: + print(f"❌ No valid types specified. 
Available types: {', '.join([os for _, os in all_types])}") + return + else: + types_to_quantize = all_types + + # No extra filtering is needed when --skip-existing is set: + # the quantize step detects existing files, skips re-quantization, and still benchmarks them + + # Create the quantizer and run it + try: + quantizer = EmbeddingQuantizer( + args.input, + args.output_dir, + args.quantize_bin, + args.bench_bin, + args.stats_dir, + args.csv_output + ) + quantizer.run_all_quantizations(types_to_quantize) + except FileNotFoundError as e: + print(f"❌ Error: {e}") + return 1 + except KeyboardInterrupt: + print("\n\n⚠️ Quantization interrupted by user") + return 1 + except Exception as e: + print(f"\n❌ Unexpected error: {e}") + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + exit(main() or 0) diff --git a/utils/test_gemm_kernel.sh b/utils/test_gemm_kernel.sh new file mode 100755 index 000000000..ae72c7266 --- /dev/null +++ b/utils/test_gemm_kernel.sh @@ -0,0 +1,573 @@ +#!/bin/bash +# Unified GEMM kernel benchmark script +# Builds, tests, and benchmarks the GEMM kernel with configurable output + +set -e + +# Default values +BUILD_DIR="../build" +ITERATIONS=1000 +OUTPUT_CSV="" +SKIP_BUILD=false +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +# Print usage +print_usage() { + cat << EOF +Usage: $0 [options] + +Options: + -o, --output Output CSV file path (default: ../stats/gemm_kernel_test_noparal.csv) + -i, --iterations Number of iterations per test (default: 1000) + -s, --skip-build Skip building the benchmark binary + -h, --help Show this help message + +Examples: + # Run with default settings + $0 + + # Specify custom output file + $0 -o /path/to/my_results.csv + + # Quick test with fewer iterations + $0 -i 100 -o quick_test.csv + + # Skip build if already compiled + $0 -s -o results.csv +EOF +} + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + -o|--output) + OUTPUT_CSV="$2" + shift 2 + ;; + -i|--iterations) + ITERATIONS="$2" + shift 2 + ;; + -s|--skip-build) + SKIP_BUILD=true + shift + ;; + -h|--help) + print_usage + exit 0 + ;; + *) + echo "Unknown option: $1" + print_usage + exit 1 + ;; + esac +done + +# Set default output CSV if not specified +if [ -z "$OUTPUT_CSV" ]; then + OUTPUT_CSV="${SCRIPT_DIR}/../stats/gemm_kernel_test_noparal.csv" +fi + +# Create output directory first +mkdir -p "$(dirname "$OUTPUT_CSV")" + +# Convert to absolute path +if [[ "$OUTPUT_CSV" = /* ]]; then + # Already absolute path + OUTPUT_CSV="$OUTPUT_CSV" +else + # Convert relative path to absolute + OUTPUT_CSV="$(cd "$(dirname "$OUTPUT_CSV")" && pwd)/$(basename "$OUTPUT_CSV")" +fi + +echo "==========================================" +echo "GEMM Kernel Benchmark Suite" +echo "==========================================" +echo "Configuration:" +echo " Iterations: $ITERATIONS" +echo " Output CSV: $OUTPUT_CSV" +echo " Skip build: $SKIP_BUILD" +echo "==========================================" +echo "" + +# Build the benchmark binary +if [ "$SKIP_BUILD" = false ]; then + echo "Step 1: Building GEMM kernel benchmark..." 
+ echo "------------------------------------------" + + CXX=${CXX:-g++} + + # Create build directory if it doesn't exist + mkdir -p "${SCRIPT_DIR}/${BUILD_DIR}" + + # Create temporary C++ source file + TEMP_CPP="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel_temp.cpp" + + cat > "${TEMP_CPP}" << 'EOF' +/** + * Standalone benchmark for ggml_gemm_i2_i8_s kernel + * + * This program tests the performance of the ggml_gemm_i2_i8_s kernel + * with configurable matrix sizes and iteration counts. + * + * Usage: ./test_gemm_kernel [options] + * -n : embedding dimension (must be divisible by 4, default: 2048) + * -r : number of rows in matrix Y (default: 32) + * -c : number of columns in matrix X (default: 128) + * -i : number of iterations (default: 1000) + * -w : number of warmup iterations (default: 10) + */ + +#include +#include +#include +#include +#include +#include +#include + +// Include necessary headers +#include "../include/gemm-config.h" + +// Function declarations (from ggml-quants.h) +extern "C" void ggml_vec_dot_i2_i8_s(int n, float * s, size_t bs, const void * vx, size_t bx, const void * vy, size_t by, int nrc); + +// GEMM kernel definition +void ggml_gemm_i2_i8_s(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { +#if defined(ACT_PARALLEL) + const int64_t row_block = ROW_BLOCK_SIZE; + const int64_t col_block = COL_BLOCK_SIZE; + + for (int64_t c0 = 0; c0 < nc; c0 += col_block) { + int64_t cur_c = (c0 + col_block <= nc) ? col_block : (nc - c0); + for (int64_t r0 = 0; r0 < nr; r0 += row_block) { + int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0); + const void * vy_r = (const uint8_t *)vy + r0 * n; + for (int64_t c = 0; c < cur_c; ++c) { + const int64_t col = c0 + c; + float * s_col = s + col; + const void * vx_col = (const uint8_t *)vx + col * n / 4; + ggml_vec_dot_i2_i8_s(n, s_col + r0 * bs, bs, vx_col, n, vy_r, n, cur_r); + } + } + } +#else + const int64_t row_block = ROW_BLOCK_SIZE; + const int64_t col_block = COL_BLOCK_SIZE; + + for (int64_t r0 = 0; r0 < nr; r0 += row_block) { + int64_t cur_r = (r0 + row_block <= nr) ? row_block : (nr - r0); + for (int64_t c0 = 0; c0 < nc; c0 += col_block) { + int64_t cur_c = (c0 + col_block <= nc) ? 
col_block : (nc - c0); + const void * vx_c = (const uint8_t *)vx + c0 * n / 4; + for (int64_t r = 0; r < cur_r; ++r) { + const int64_t row = r0 + r; + float * s_row = s + row * bs; + const void * vy_row = (const uint8_t *)vy + row * n; + ggml_vec_dot_i2_i8_s(n, s_row + c0, bs, vx_c, n, vy_row, n, cur_c); + } + } + } +#endif +} + +// Helper function to get current time in nanoseconds +double get_time_ns() { + struct timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + return ts.tv_sec * 1e9 + ts.tv_nsec; +} + +// Initialize matrix with random i2 values (2-bit quantized) +void init_matrix_i2(uint8_t* data, int n, int cols) { + // i2 format: 4 values per byte (2 bits each) + int total_bytes = n * cols / 4; + for (int i = 0; i < total_bytes; i++) { + data[i] = rand() & 0xFF; + } +} + +// Initialize matrix with random i8 values +void init_matrix_i8(int8_t* data, int n, int rows) { + int total_elements = n * rows; + for (int i = 0; i < total_elements; i++) { + data[i] = (int8_t)((rand() % 256) - 128); + } +} + +// Benchmark configuration +struct BenchmarkConfig { + int n; // embedding dimension (must be divisible by 4) + int nr; // number of rows in Y matrix + int nc; // number of columns in X matrix + int iterations; // number of benchmark iterations + int warmup; // number of warmup iterations +}; + +void print_config(const BenchmarkConfig& config) { + printf("=" "=%.78s\n", "==============================================================================="); + printf("Benchmark Configuration:\n"); + printf("=" "=%.78s\n", "==============================================================================="); + printf(" Embedding dimension (n) : %d\n", config.n); + printf(" Matrix Y rows (nr) : %d\n", config.nr); + printf(" Matrix X columns (nc) : %d\n", config.nc); + printf(" Iterations : %d\n", config.iterations); + printf(" Warmup iterations : %d\n", config.warmup); + printf("\nMatrix sizes:\n"); + printf(" X (i2): %d x %d (%.2f KB)\n", config.nc, config.n, + (config.nc * config.n / 4) / 1024.0); + printf(" Y (i8): %d x %d (%.2f KB)\n", config.nr, config.n, + (config.nr * config.n) / 1024.0); + printf(" S (f32): %d x %d (%.2f KB)\n", config.nr, config.nc, + (config.nr * config.nc * sizeof(float)) / 1024.0); + printf("\nGEMM Config:\n"); +#if defined(ACT_PARALLEL) + printf(" ACT_PARALLEL : ON\n"); +#else + printf(" ACT_PARALLEL : OFF\n"); +#endif + printf(" ROW_BLOCK_SIZE : %d\n", ROW_BLOCK_SIZE); + printf(" COL_BLOCK_SIZE : %d\n", COL_BLOCK_SIZE); + printf(" PARALLEL_SIZE : %d\n", PARALLEL_SIZE); + printf("=" "=%.78s\n\n", "==============================================================================="); +} + +void run_benchmark(const BenchmarkConfig& config) { + // Allocate matrices + printf("Allocating matrices...\n"); + + // X matrix (i2 format): nc x n, but stored as nc x (n/4) bytes + // Align to 64 bytes for AVX-512, which is backward compatible with AVX2 (32 bytes) + size_t x_size = config.nc * config.n / 4; + size_t x_size_aligned = ((x_size + 63) / 64) * 64; + uint8_t* X = (uint8_t*)aligned_alloc(64, x_size_aligned); + + // Y matrix (i8 format): nr x n + size_t y_size = config.nr * config.n; + size_t y_size_aligned = ((y_size + 63) / 64) * 64; + int8_t* Y = (int8_t*)aligned_alloc(64, y_size_aligned); + + // Result matrix (float32): nr x nc + size_t s_size = config.nr * config.nc * sizeof(float); + size_t s_size_aligned = ((s_size + 63) / 64) * 64; + float* S = (float*)aligned_alloc(64, s_size_aligned); + + if (!X || !Y || !S) { + fprintf(stderr, "Failed to allocate memory\n"); + 
exit(1); + } + + // Initialize matrices with random data + printf("Initializing matrices with random data...\n"); + srand(time(NULL)); + init_matrix_i2(X, config.n, config.nc); + init_matrix_i8(Y, config.n, config.nr); + memset(S, 0, config.nr * config.nc * sizeof(float)); + + // Warmup + printf("Running %d warmup iterations...\n", config.warmup); + for (int i = 0; i < config.warmup; i++) { + ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc); + } + + // Benchmark + printf("Running %d benchmark iterations...\n", config.iterations); + double total_time = 0.0; + double min_time = 1e20; + double max_time = 0.0; + + for (int i = 0; i < config.iterations; i++) { + double start = get_time_ns(); + ggml_gemm_i2_i8_s(config.n, S, config.nc, X, Y, config.nr, config.nc); + double end = get_time_ns(); + + double elapsed = end - start; + total_time += elapsed; + if (elapsed < min_time) min_time = elapsed; + if (elapsed > max_time) max_time = elapsed; + + if ((i + 1) % 100 == 0) { + printf(" Progress: %d/%d iterations\n", i + 1, config.iterations); + } + } + + // Calculate statistics + double avg_time_ns = total_time / config.iterations; + double avg_time_ms = avg_time_ns / 1e6; + double min_time_ms = min_time / 1e6; + double max_time_ms = max_time / 1e6; + + // Calculate GFLOPS + // For GEMM: nr x nc x n multiply-adds = 2 * nr * nc * n FLOPs + double flops = 2.0 * config.nr * config.nc * config.n; + double gflops = (flops / avg_time_ns); + + // Calculate throughput (tokens/s assuming each column is a token) + double throughput = (config.nc * 1e9) / avg_time_ns; + + // Print results + printf("\n"); + printf("=" "=%.78s\n", "==============================================================================="); + printf("Benchmark Results:\n"); + printf("=" "=%.78s\n", "==============================================================================="); + printf(" Average time : %.3f ms\n", avg_time_ms); + printf(" Min time : %.3f ms\n", min_time_ms); + printf(" Max time : %.3f ms\n", max_time_ms); + printf(" Std dev : %.3f ms\n", sqrt((max_time_ms - min_time_ms) * (max_time_ms - min_time_ms) / 12)); + printf("\nPerformance:\n"); + printf(" GFLOPS : %.2f\n", gflops); + printf(" Throughput : %.2f tokens/s\n", throughput); + printf(" Latency/token : %.3f us\n", (avg_time_ms * 1000) / config.nc); + printf("=" "=%.78s\n", "==============================================================================="); + + // Cleanup + free(X); + free(Y); + free(S); +} + +void print_usage(const char* program) { + printf("Usage: %s [options]\n", program); + printf("Options:\n"); + printf(" -n Embedding dimension (must be divisible by 4, default: 2048)\n"); + printf(" -r Number of rows in matrix Y (default: 32)\n"); + printf(" -c Number of columns in matrix X (default: 128)\n"); + printf(" -i Number of iterations (default: 1000)\n"); + printf(" -w Number of warmup iterations (default: 10)\n"); + printf(" -h Show this help message\n"); +} + +int main(int argc, char** argv) { + BenchmarkConfig config = { + .n = 2048, + .nr = 32, + .nc = 128, + .iterations = 1000, + .warmup = 10 + }; + + // Parse command line arguments + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "-n") == 0 && i + 1 < argc) { + config.n = atoi(argv[++i]); + } else if (strcmp(argv[i], "-r") == 0 && i + 1 < argc) { + config.nr = atoi(argv[++i]); + } else if (strcmp(argv[i], "-c") == 0 && i + 1 < argc) { + config.nc = atoi(argv[++i]); + } else if (strcmp(argv[i], "-i") == 0 && i + 1 < argc) { + config.iterations = atoi(argv[++i]); + } else 
if (strcmp(argv[i], "-w") == 0 && i + 1 < argc) { + config.warmup = atoi(argv[++i]); + } else if (strcmp(argv[i], "-h") == 0) { + print_usage(argv[0]); + return 0; + } else { + fprintf(stderr, "Unknown option: %s\n", argv[i]); + print_usage(argv[0]); + return 1; + } + } + + // Validate configuration + if (config.n % 4 != 0) { + fprintf(stderr, "Error: Embedding dimension (-n) must be divisible by 4\n"); + return 1; + } + + if (config.n <= 0 || config.nr <= 0 || config.nc <= 0 || config.iterations <= 0) { + fprintf(stderr, "Error: All size parameters must be positive\n"); + return 1; + } + + // Run benchmark + print_config(config); + run_benchmark(config); + + return 0; +} +EOF + + # Compiler flags + CXXFLAGS="-O3 -march=native -mtune=native -std=c++17 -fopenmp" + CXXFLAGS+=" -I${SCRIPT_DIR}/.. -I${SCRIPT_DIR}/../include" + CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/include" + CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/ggml/src" + CXXFLAGS+=" -I${SCRIPT_DIR}/../3rdparty/llama.cpp/include" + CXXFLAGS+=" -DNDEBUG -ffast-math" + + # Link flags + LDFLAGS="-lm -lpthread" + + # Link with pre-built libraries + GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src" + GGML_SO="${GGML_LIB_DIR}/libggml.so" + + if [ ! -f "${GGML_SO}" ]; then + echo "❌ Error: Cannot find libggml.so at ${GGML_SO}" + echo "Please build the project first with: cmake --build build" + rm -f "${TEMP_CPP}" + exit 1 + fi + + LDFLAGS+=" -L${GGML_LIB_DIR} -lggml -Wl,-rpath,${GGML_LIB_DIR}" + + # Output binary + BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel" + + echo "Compiler: ${CXX}" + echo "Building from embedded source..." + echo "" + + # Build + ${CXX} ${CXXFLAGS} "${TEMP_CPP}" -o ${BENCHMARK_BIN} ${LDFLAGS} + + if [ $? -eq 0 ]; then + echo "✅ Build successful!" + rm -f "${TEMP_CPP}" + echo "" + else + echo "❌ Build failed!" + rm -f "${TEMP_CPP}" + exit 1 + fi +else + echo "Step 1: Skipping build (using existing binary)" + echo "------------------------------------------" + BENCHMARK_BIN="${SCRIPT_DIR}/${BUILD_DIR}/test_gemm_kernel" + + if [ ! -f "${BENCHMARK_BIN}" ]; then + echo "❌ Error: Benchmark binary not found at ${BENCHMARK_BIN}" + echo "Please run without -s to build it first." 
+ exit 1 + fi + echo "✅ Found existing binary" + echo "" +fi + +# Set LD_LIBRARY_PATH to include the GGML library directory +GGML_LIB_DIR="${SCRIPT_DIR}/../build/3rdparty/llama.cpp/ggml/src" +export LD_LIBRARY_PATH="${GGML_LIB_DIR}:${LD_LIBRARY_PATH}" + +echo "Step 2: Running benchmark tests" +echo "------------------------------------------" +echo "Library path: ${GGML_LIB_DIR}" +echo "" + +# Write CSV header +echo "test_name,n,nr,nc,time_ms,gflops,throughput_tokens_per_sec" > "$OUTPUT_CSV" +echo "Results will be saved to: $OUTPUT_CSV" +echo "" + +# Function to extract metrics and append to CSV +extract_and_save() { + local test_name="$1" + local output="$2" + + # Extract values using grep and awk + local n=$(echo "$output" | grep "Embedding dimension" | awk '{print $5}') + local nr=$(echo "$output" | grep "Matrix Y rows" | awk '{print $6}') + local nc=$(echo "$output" | grep "Matrix X columns" | awk '{print $6}') + local avg_time=$(echo "$output" | grep "Average time" | awk '{print $4}') + local min_time=$(echo "$output" | grep "Min time" | awk '{print $4}') + local max_time=$(echo "$output" | grep "Max time" | awk '{print $4}') + local gflops=$(echo "$output" | grep "GFLOPS" | awk '{print $3}') + local throughput=$(echo "$output" | grep "Throughput" | awk '{print $3}') + + # Check if values were extracted successfully + if [ -z "$avg_time" ] || [ -z "$min_time" ] || [ -z "$max_time" ]; then + echo "Warning: Failed to extract timing data for ${test_name}" + echo "${test_name},${n},${nr},${nc},N/A,N/A,N/A" >> "$OUTPUT_CSV" + return + fi + + # Calculate standard deviation estimate from range + # Using awk with proper variable passing + local std_time=$(awk -v min="$min_time" -v max="$max_time" 'BEGIN {printf "%.4f", (max - min) / 4}') + + # Format as mean±std + local time_formatted="${avg_time}±${std_time}" + + # Append to CSV + echo "${test_name},${n},${nr},${nc},${time_formatted},${gflops},${throughput}" >> "$OUTPUT_CSV" +} + +# Run benchmark tests +echo "==========================================" +echo "BitNet-2B Typical Shapes Performance Test" +echo "==========================================" +echo "" + +echo "Test 1: Single Token Generation (Attention QKV projection)" +echo " Scenario: Generating 1 token at a time" +echo " Shape: n=2048, r=1, c=2048" +OUTPUT=$($BENCHMARK_BIN -n 2048 -r 1 -c 2048 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "single_token_gen" "$OUTPUT" +echo "" + +echo "Test 2: Small Batch Prompt Processing (Attention QKV projection)" +echo " Scenario: Processing prompt with 128 tokens, batch size 1" +echo " Shape: n=2048, r=128, c=2048" +OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 2048 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "small_batch_prompt" "$OUTPUT" +echo "" + +echo "Test 3: Medium Batch Prompt Processing (Attention QKV projection)" +echo " Scenario: Processing prompt with 256 tokens or batch of 256" +echo " Shape: n=2048, r=256, c=2048" +OUTPUT=$($BENCHMARK_BIN -n 2048 -r 256 -c 2048 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "medium_batch_prompt" "$OUTPUT" +echo "" + +echo "Test 4: Large Batch Processing (Attention QKV projection)" +echo " Scenario: Processing 512 tokens or batch of 512" +echo " Shape: n=2048, r=512, c=2048" +OUTPUT=$($BENCHMARK_BIN -n 2048 -r 512 -c 2048 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "large_batch_prompt" "$OUTPUT" +echo "" + +echo "Test 5: FFN Up-projection (Small batch)" +echo " Scenario: Feed-forward network expansion, 128 tokens" +echo " Shape: n=2048, r=128, c=8192" 
+OUTPUT=$($BENCHMARK_BIN -n 2048 -r 128 -c 8192 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "ffn_up_projection" "$OUTPUT" +echo "" + +echo "Test 6: FFN Down-projection (Small batch)" +echo " Scenario: Feed-forward network reduction, 128 tokens" +echo " Shape: n=8192, r=128, c=2048" +OUTPUT=$($BENCHMARK_BIN -n 8192 -r 128 -c 2048 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "ffn_down_projection" "$OUTPUT" +echo "" + +echo "Test 7: Long Context Processing" +echo " Scenario: Processing very long context (2048 tokens)" +echo " Shape: n=2048, r=2048, c=2048" +OUTPUT=$($BENCHMARK_BIN -n 2048 -r 2048 -c 2048 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "long_context" "$OUTPUT" +echo "" + +echo "Test 8: Batched Token Generation" +echo " Scenario: Generating tokens for 32 sequences simultaneously" +echo " Shape: n=2048, r=32, c=2048" +OUTPUT=$($BENCHMARK_BIN -n 2048 -r 32 -c 2048 -i $ITERATIONS 2>&1) +echo "$OUTPUT" +extract_and_save "batched_token_gen" "$OUTPUT" +echo "" + +echo "==========================================" +echo "All tests completed successfully!" +echo "==========================================" +echo "Results saved to: $OUTPUT_CSV" +echo "" +echo "Summary:" +wc -l "$OUTPUT_CSV" | awk '{print " Total records:", $1 - 1}' +echo " Output file: $OUTPUT_CSV" +echo "==========================================" diff --git a/utils/test_perplexity.py b/utils/test_perplexity.py new file mode 100644 index 000000000..f2d9788c0 --- /dev/null +++ b/utils/test_perplexity.py @@ -0,0 +1,608 @@ +#!/usr/bin/env python3 +""" +Perplexity Test Script +Tests GGUF model perplexity on multiple datasets using llama-perplexity. +""" + +import os +import subprocess +import time +import csv +import re +from datetime import datetime +from pathlib import Path +import argparse +import tempfile +import shutil +import statistics + + +class PerplexityTester: + def __init__(self, model_path, llama_perplexity_bin="../build/bin/llama-perplexity", + data_dir="../data", output_dir="perplexity_results", quick_mode=False, + quantize_bin="../build/bin/llama-quantize", test_embeddings=False, csv_output=None): + self.model_path = Path(model_path) + self.llama_perplexity_bin = Path(llama_perplexity_bin) + self.quantize_bin = Path(quantize_bin) + self.data_dir = Path(data_dir) + self.output_dir = Path(output_dir) + self.quick_mode = quick_mode + self.test_embeddings = test_embeddings + self.csv_output = Path(csv_output) if csv_output else None + self.results = [] + self.created_models = set() # Track newly created model files + self.temp_files = [] # Track temporary files for cleanup + + # Embedding types to test + self.embedding_types = [ + ('F32', 'f32'), + ('F16', 'f16'), + ('Q8_0', 'q8_0'), + ('Q6_K', 'q6_k'), + ('Q5_0', 'q5_0'), + ('Q4_0', 'q4_0'), + ('Q3_K', 'q3_k'), + ('TQ2_0', 'tq2_0'), + ] + + # Create output directory + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Verify llama-perplexity binary exists + if not self.llama_perplexity_bin.exists(): + raise FileNotFoundError(f"llama-perplexity binary not found: {self.llama_perplexity_bin}") + + # Verify quantize binary exists if testing embeddings + if self.test_embeddings and not self.quantize_bin.exists(): + raise FileNotFoundError(f"llama-quantize binary not found: {self.quantize_bin}") + + # Verify model file exists + if not self.model_path.exists(): + raise FileNotFoundError(f"Model file not found: {self.model_path}") + + def find_datasets(self): + """Find all test.txt files in dataset directories.""" + datasets = [] + + 
if not self.data_dir.exists(): + print(f"❌ Data directory not found: {self.data_dir}") + return datasets + + print(f"\n🔍 Searching for datasets in {self.data_dir}...") + + # Look for test.txt files in subdirectories + for dataset_dir in sorted(self.data_dir.iterdir()): + if dataset_dir.is_dir(): + test_file = dataset_dir / "test.txt" + if test_file.exists(): + size_mb = test_file.stat().st_size / (1024 * 1024) + datasets.append({ + 'name': dataset_dir.name, + 'path': test_file, + 'size': test_file.stat().st_size, + 'size_mb': size_mb + }) + print(f" ✅ {dataset_dir.name:<20} ({size_mb:.2f} MB)") + else: + print(f" ⚠️ {dataset_dir.name:<20} (no test.txt found)") + + return datasets + + def create_quick_dataset(self, dataset_path, num_chars=4096): + """Create a temporary dataset with only the first N characters for quick testing.""" + temp_file = tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.txt', encoding='utf-8') + self.temp_files.append(temp_file.name) + + try: + with open(dataset_path, 'r', encoding='utf-8', errors='ignore') as f: + content = f.read(num_chars) + temp_file.write(content) + temp_file.close() + return Path(temp_file.name) + except Exception as e: + print(f"⚠️ Failed to create quick dataset: {e}") + temp_file.close() + return dataset_path + + def cleanup_temp_files(self): + """Clean up temporary files.""" + for temp_file in self.temp_files: + try: + os.unlink(temp_file) + except: + pass + self.temp_files = [] + + def run_perplexity_test(self, dataset_name, dataset_path, threads=16, ctx_size=512, model_override=None): + """Run perplexity test on a single dataset.""" + test_model = model_override if model_override else self.model_path + + print(f"\n{'='*80}") + print(f"📊 Testing on dataset: {dataset_name}") + print(f" File: {dataset_path}") + print(f" Model: {test_model.name}") + print(f"{'='*80}") + + cmd = [ + str(self.llama_perplexity_bin), + "-m", str(test_model), + "-f", str(dataset_path), + "-t", str(threads), + "-c", str(ctx_size), + "-ngl", "0" # CPU only + ] + + print(f"💻 Command: {' '.join(cmd)}") + print(f"⏱️ Starting test...\n") + + start_time = time.time() + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + timeout=3600, # 1 hour timeout + cwd=os.getcwd() + ) + + elapsed_time = time.time() - start_time + + if result.returncode == 0: + # Parse perplexity from output (check both stdout and stderr) + combined_output = result.stdout + "\n" + result.stderr + ppl = self.parse_perplexity(combined_output) + + if ppl is not None: + print(f"\n✅ Perplexity: {ppl}") + print(f"⏱️ Time: {elapsed_time:.2f}s ({elapsed_time/60:.2f} min)") + status = "success" + else: + print(f"\n⚠️ Test completed but could not parse perplexity") + print(f"Last 500 chars of stdout:") + print(result.stdout[-500:]) + print(f"Last 500 chars of stderr:") + print(result.stderr[-500:]) + status = "parse_error" + ppl = None + else: + print(f"\n❌ Test failed with return code {result.returncode}") + print(f"Error: {result.stderr[:500]}") + status = "failed" + ppl = None + elapsed_time = time.time() - start_time + + return { + 'dataset': dataset_name, + 'perplexity': ppl, + 'time': elapsed_time, + 'status': status, + 'stdout': result.stdout, + 'stderr': result.stderr + } + + except subprocess.TimeoutExpired: + elapsed_time = time.time() - start_time + print(f"\n❌ Timeout after {elapsed_time:.2f}s") + return { + 'dataset': dataset_name, + 'perplexity': None, + 'time': elapsed_time, + 'status': 'timeout', + 'stdout': '', + 'stderr': 'Test exceeded 1 hour timeout' + } 
+ except Exception as e: + elapsed_time = time.time() - start_time + print(f"\n❌ Error: {e}") + return { + 'dataset': dataset_name, + 'perplexity': None, + 'time': elapsed_time, + 'status': 'error', + 'stdout': '', + 'stderr': str(e) + } + + def parse_perplexity(self, output): + """Parse perplexity value (mean±std format) from llama-perplexity output.""" + # First try to match "PPL = mean +/- std" format + pattern_with_std = r'PPL\s*=\s*(\d+\.?\d*)\s*\+/-\s*(\d+\.?\d*)' + match = re.search(pattern_with_std, output, re.IGNORECASE | re.MULTILINE) + if match: + try: + mean = float(match.group(1)) + std = float(match.group(2)) + return f"{mean:.4f}±{std:.4f}" + except ValueError: + pass + + # Fallback to patterns without std + patterns = [ + r'Final estimate:\s*PPL\s*=\s*(\d+\.?\d*)', + r'Final perplexity:\s*(\d+\.?\d*)', + r'PPL\s*=\s*(\d+\.?\d*)', + r'PPL:\s*(\d+\.?\d*)', + r'perplexity:\s*(\d+\.?\d*)', + r'ppl\s*=\s*(\d+\.?\d*)', + r'Perplexity:\s*(\d+\.?\d*)', + ] + + for pattern in patterns: + match = re.search(pattern, output, re.IGNORECASE | re.MULTILINE) + if match: + try: + return f"{float(match.group(1)):.4f}" + except ValueError: + continue + + return None + + def quantize_embedding(self, embedding_type, output_suffix): + """ + Quantize model with specific embedding type. + + Args: + embedding_type: Token embedding type (uppercase, e.g., 'Q6_K') + output_suffix: Output file suffix (lowercase, e.g., 'q6_k') + + Returns: + Path to quantized model or None if failed + """ + # Construct output path + model_dir = self.model_path.parent + output_path = model_dir / f"ggml-model-i2_s-embed-{output_suffix}.gguf" + + # Check if file already exists + file_existed = output_path.exists() + + if file_existed: + print(f"ℹ️ Model already exists: {output_path.name}") + return output_path + + cmd = [ + str(self.quantize_bin), + "--token-embedding-type", embedding_type, + str(self.model_path), + str(output_path), + "I2_S", + "1", + "1" + ] + + print(f"\n{'='*80}") + print(f"🔄 Quantizing with embedding type: {embedding_type}") + print(f"📥 Input: {self.model_path.name}") + print(f"📤 Output: {output_path.name}") + print(f"💻 Command: {' '.join(cmd)}") + print(f"{'='*80}\n") + + start_time = time.time() + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + cwd=os.getcwd(), + timeout=600 # 10 minutes timeout + ) + + duration = time.time() - start_time + + if result.returncode == 0: + file_size_mb = output_path.stat().st_size / (1024 * 1024) + print(f"✅ Quantization successful!") + print(f" Duration: {duration:.2f}s") + print(f" Size: {file_size_mb:.2f} MB") + + # Mark as newly created + self.created_models.add(output_path) + return output_path + else: + print(f"❌ Quantization failed with return code {result.returncode}") + print(f"Error: {result.stderr[:500]}") + return None + + except subprocess.TimeoutExpired: + print(f"❌ Quantization timeout (exceeded 10 minutes)") + return None + except Exception as e: + print(f"❌ Quantization error: {e}") + return None + + def cleanup_model(self, model_path): + """Delete model file if it was created during this session.""" + if model_path in self.created_models: + try: + model_path.unlink() + print(f"🗑️ Deleted: {model_path.name}") + self.created_models.remove(model_path) + except Exception as e: + print(f"⚠️ Failed to delete {model_path.name}: {e}") + else: + print(f"ℹ️ Keeping existing file: {model_path.name}") + + def run_all_tests(self, threads=16, ctx_size=512): + """Run perplexity tests on all datasets.""" + datasets = 
self.find_datasets() + + if not datasets: + print(f"\n❌ No datasets found in {self.data_dir}") + print(f" Make sure each dataset directory has a test.txt file") + return + + # Quick mode: test all datasets but only first 4096 chars with smaller context + if self.quick_mode: + ctx_size = min(ctx_size, 128) # Use smaller context in quick mode + print(f"\n⚡ QUICK TEST MODE ENABLED") + print(f" - Testing all datasets with first 4096 characters only") + print(f" - Using reduced context size: {ctx_size}") + + # Determine models to test + if self.test_embeddings: + print(f"\n{'='*80}") + print(f"🧪 EMBEDDING QUANTIZATION TEST MODE") + print(f"{'='*80}") + print(f"📦 Base model: {self.model_path.name}") + print(f"🔢 Embedding types to test: {len(self.embedding_types)}") + print(f"📊 Datasets: {len(datasets)}") + print(f"🧵 Threads: {threads}") + print(f"📏 Context size: {ctx_size}") + print(f"{'='*80}") + + total_start = time.time() + + # Test each embedding type + for i, (embedding_type, output_suffix) in enumerate(self.embedding_types, 1): + print(f"\n\n{'#'*80}") + print(f"[{i}/{len(self.embedding_types)}] Testing embedding type: {output_suffix} ({embedding_type})") + print(f"{'#'*80}") + + # Quantize model + quantized_model = self.quantize_embedding(embedding_type, output_suffix) + + if quantized_model is None: + print(f"⚠️ Skipping tests for {output_suffix} due to quantization failure") + continue + + # Test on all datasets + for j, dataset in enumerate(datasets, 1): + print(f"\n[{j}/{len(datasets)}] Testing {dataset['name']} with {output_suffix}...") + + # Use quick dataset if in quick mode + test_path = dataset['path'] + if self.quick_mode: + test_path = self.create_quick_dataset(dataset['path']) + + result = self.run_perplexity_test( + f"{dataset['name']}_embed-{output_suffix}", + test_path, + threads, + ctx_size, + model_override=quantized_model + ) + self.results.append(result) + + # Cleanup model after testing + print(f"\n🧹 Cleaning up {output_suffix} model...") + self.cleanup_model(quantized_model) + + print(f"\n{'#'*80}") + print(f"✅ Completed {output_suffix}") + print(f"{'#'*80}") + + total_time = time.time() - total_start + + else: + # Regular single model test + print(f"\n{'='*80}") + print(f"🚀 PERPLEXITY TEST SESSION{' (QUICK MODE)' if self.quick_mode else ''}") + print(f"{'='*80}") + print(f"📦 Model: {self.model_path.name}") + print(f"📁 Model path: {self.model_path}") + print(f"📊 Datasets {'to test' if self.quick_mode else 'found'}: {len(datasets)}") + print(f"🧵 Threads: {threads}") + print(f"📏 Context size: {ctx_size}") + print(f"{'='*80}") + + total_start = time.time() + + # Run tests + for i, dataset in enumerate(datasets, 1): + print(f"\n\n[{i}/{len(datasets)}] Processing {dataset['name']}...") + + # Use quick dataset if in quick mode + test_path = dataset['path'] + if self.quick_mode: + test_path = self.create_quick_dataset(dataset['path']) + + result = self.run_perplexity_test( + dataset['name'], + test_path, + threads, + ctx_size + ) + self.results.append(result) + + total_time = time.time() - total_start + + # Clean up temporary files + if self.quick_mode: + print(f"\n🧹 Cleaning up temporary files...") + self.cleanup_temp_files() + + # Save results + self.save_results() + + # Print summary + self.print_summary(total_time) + + def save_results(self): + """Save results to CSV file.""" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + model_name = self.model_path.stem + + # Use custom CSV path if provided + if self.csv_output: + csv_file = self.csv_output + # Create parent 
directory if needed + csv_file.parent.mkdir(parents=True, exist_ok=True) + else: + csv_file = self.output_dir / f"ppl_{model_name}_{timestamp}.csv" + + print(f"\n💾 Saving results...") + + with open(csv_file, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=['dataset', 'perplexity', 'time_seconds', 'status']) + writer.writeheader() + for result in self.results: + writer.writerow({ + 'dataset': result['dataset'], + 'perplexity': result['perplexity'] if result['perplexity'] is not None else 'N/A', + 'time_seconds': f"{result['time']:.2f}", + 'status': result['status'] + }) + + print(f" ✅ CSV saved: {csv_file}") + + # Save detailed log + log_file = self.output_dir / f"ppl_{model_name}_{timestamp}.log" + with open(log_file, 'w') as f: + f.write(f"Perplexity Test Results\n") + f.write(f"{'='*80}\n") + f.write(f"Model: {self.model_path}\n") + f.write(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n") + f.write(f"{'='*80}\n\n") + + for result in self.results: + f.write(f"\n{'='*80}\n") + f.write(f"Dataset: {result['dataset']}\n") + f.write(f"Perplexity: {result['perplexity']}\n") + f.write(f"Time: {result['time']:.2f}s\n") + f.write(f"Status: {result['status']}\n") + f.write(f"\nOutput:\n{result['stdout']}\n") + if result['stderr']: + f.write(f"\nErrors:\n{result['stderr']}\n") + + print(f" ✅ Log saved: {log_file}") + + def print_summary(self, total_time): + """Print summary of all tests.""" + print(f"\n\n{'='*80}") + print(f"📊 TEST SUMMARY") + print(f"{'='*80}\n") + + # Sort results by perplexity (lower is better) + successful = [r for r in self.results if r['perplexity'] is not None] + failed = [r for r in self.results if r['perplexity'] is None] + + if successful: + # Extract numeric value from "mean±std" format for sorting + def get_ppl_value(result): + ppl = result['perplexity'] + if isinstance(ppl, str) and '±' in ppl: + return float(ppl.split('±')[0]) + elif isinstance(ppl, str): + try: + return float(ppl) + except ValueError: + return float('inf') + return ppl + + successful_sorted = sorted(successful, key=get_ppl_value) + + print(f"{'Dataset':<20} {'Perplexity':>20} {'Time (s)':>12} {'Status':<15}") + print(f"{'-'*80}") + + for result in successful_sorted: + ppl_str = str(result['perplexity']) if result['perplexity'] is not None else 'N/A' + print(f"{result['dataset']:<20} {ppl_str:>20} " + f"{result['time']:>12.2f} {result['status']:<15}") + + best_ppl = str(successful_sorted[0]['perplexity']) + print(f"\n🏆 Best result: {successful_sorted[0]['dataset']} " + f"(PPL: {best_ppl})") + + if failed: + print(f"\n❌ Failed tests ({len(failed)}):") + for result in failed: + print(f" - {result['dataset']}: {result['status']}") + + print(f"\n{'='*80}") + print(f"✅ Completed: {len(successful)}/{len(self.results)}") + print(f"⏱️ Total time: {total_time:.2f}s ({total_time/60:.2f} min)") + print(f"📁 Results saved in: {self.output_dir}") + print(f"{'='*80}\n") + + +def main(): + parser = argparse.ArgumentParser(description='Test model perplexity on multiple datasets') + parser.add_argument('--model', '-m', + required=True, + help='Path to GGUF model file') + parser.add_argument('--data-dir', '-d', + default='data', + help='Directory containing dataset folders (default: data)') + parser.add_argument('--threads', '-t', + type=int, + default=16, + help='Number of threads (default: 16)') + parser.add_argument('--ctx-size', '-c', + type=int, + default=512, + help='Context size (default: 512)') + parser.add_argument('--output-dir', '-o', + default='perplexity_results', + help='Output 
directory for results (default: perplexity_results)')
+    parser.add_argument('--llama-perplexity',
+                        default='./build/bin/llama-perplexity',
+                        help='Path to llama-perplexity binary (default: ./build/bin/llama-perplexity)')
+    parser.add_argument('--quick', '-q',
+                        action='store_true',
+                        help='Quick test mode: test all datasets with first 4096 characters and reduced context size (128)')
+    parser.add_argument('--test-embeddings', '-e',
+                        action='store_true',
+                        help='Test different embedding quantization types (f32, f16, q8_0, q6_k, q5_0, q4_0, q3_k, tq2_0)')
+    parser.add_argument('--csv-output',
+                        help='Custom path for CSV output file (e.g., results/my_ppl_results.csv)')
+    parser.add_argument('--quantize-bin',
+                        default='./build/bin/llama-quantize',
+                        help='Path to llama-quantize binary (default: ./build/bin/llama-quantize)')
+
+    args = parser.parse_args()
+
+    try:
+        tester = PerplexityTester(
+            model_path=args.model,
+            llama_perplexity_bin=args.llama_perplexity,
+            data_dir=args.data_dir,
+            output_dir=args.output_dir,
+            quick_mode=args.quick,
+            quantize_bin=args.quantize_bin,
+            test_embeddings=args.test_embeddings,
+            csv_output=args.csv_output
+        )
+
+        tester.run_all_tests(
+            threads=args.threads,
+            ctx_size=args.ctx_size
+        )
+
+    except FileNotFoundError as e:
+        print(f"❌ Error: {e}")
+        return 1
+    except KeyboardInterrupt:
+        print("\n\n⚠️ Test interrupted by user")
+        return 1
+    except Exception as e:
+        print(f"\n❌ Unexpected error: {e}")
+        import traceback
+        traceback.print_exc()
+        return 1
+
+    return 0
+
+
+if __name__ == "__main__":
+    exit(main())
diff --git a/utils/test_power.sh b/utils/test_power.sh
new file mode 100755
index 000000000..79a1a6845
--- /dev/null
+++ b/utils/test_power.sh
@@ -0,0 +1,151 @@
+#!/bin/bash
+# Monitor power consumption for llama-bench with different thread configurations
+# Usage: ./test_power.sh <model_path> <output_csv> <pp_threads> <tg_threads>
+# Example: ./test_power.sh models/model.gguf results.csv "1,2,4,8" "1,2,4,8"
+
+set -e
+
+# Parse arguments
+if [ $# -ne 4 ]; then
+    echo "Usage: $0 <model_path> <output_csv> <pp_threads> <tg_threads>"
+    echo "Example: $0 models/model.gguf results.csv \"1,2,4,8\" \"1,2,4,8\""
+    exit 1
+fi
+
+MODEL_PATH="$1"
+OUTPUT_CSV="$2"
+PP_THREADS="$3"
+TG_THREADS="$4"
+
+TEMP_LOG="/tmp/power_monitor_$$.log"
+PID_FILE="/tmp/monitor_$$.pid"
+BENCH_OUTPUT="/tmp/bench_output_$$.txt"
+
+# Validate model exists
+if [ !
-f "$MODEL_PATH" ]; then + echo "Error: Model file not found: $MODEL_PATH" + exit 1 +fi + +# Create output directory if needed +mkdir -p "$(dirname "$OUTPUT_CSV")" + +# Function to monitor CPU stats +monitor_cpu() { + local log_file="$1" + echo "Timestamp,CPU_Usage(%),Avg_Freq(MHz)" > "$log_file" + while [ -f "$PID_FILE" ]; do + cpu_usage=$(top -bn1 | grep "Cpu(s)" | awk '{print 100-$8}') + avg_freq=$(grep "cpu MHz" /proc/cpuinfo | awk '{sum+=$4; count++} END {printf "%.0f", sum/count}') + timestamp=$(date +%s.%N) + echo "$timestamp,$cpu_usage,$avg_freq" >> "$log_file" + sleep 0.5 + done +} + +# Function to calculate average power +calculate_power() { + local log_file="$1" + awk -F',' 'NR>1 {sum_cpu+=$2; count++} END { + if (count > 0) { + avg_cpu = sum_cpu/count + est_power = avg_cpu * 200 / 100 + printf "%.2f", est_power + } else { + print "0" + } + }' "$log_file" +} + +# Function to extract throughput from llama-bench output +extract_throughput() { + local bench_output="$1" + local workload="$2" + grep "$workload" "$bench_output" | awk '{ + # Extract mean from "mean ± std" format + for (i=1; i<=NF; i++) { + if ($(i+1) == "±") { + printf "%.2f", $i + exit + } + } + }' +} + +# Function to run single benchmark +run_benchmark() { + local workload="$1" # "pp" or "tg" + local threads="$2" + local n_flag="" + + if [ "$workload" = "pp" ]; then + n_flag="-n 0" + workload_name="pp128" + else + n_flag="-n 128" + workload_name="tg128" + fi + + # Output progress to stderr (won't be captured in CSV) + echo "Testing $workload_name with $threads threads..." >&2 + + # Start monitoring + touch "$PID_FILE" + monitor_cpu "$TEMP_LOG" & + local monitor_pid=$! + + # Run benchmark + ./build/bin/llama-bench -m "$MODEL_PATH" -p 128 $n_flag -t "$threads" -ngl 0 > "$BENCH_OUTPUT" 2>&1 + + # Stop monitoring + rm -f "$PID_FILE" + wait $monitor_pid 2>/dev/null || true + + # Extract results + local throughput=$(extract_throughput "$BENCH_OUTPUT" "$workload_name") + local power=$(calculate_power "$TEMP_LOG") + + if [ -z "$throughput" ] || [ "$throughput" = "0" ]; then + echo "Warning: Failed to extract throughput for $workload_name, threads=$threads" >&2 + throughput="0" + fi + + # Calculate J/t (Joules per token) + local j_per_token=$(awk -v p="$power" -v t="$throughput" 'BEGIN { + if (t > 0) printf "%.4f", p/t; else print "0" + }') + + # Output progress to stderr + echo " Throughput: $throughput t/s, Power: $power W, Energy: $j_per_token J/t" >&2 + + # Only output CSV line to stdout (this will be captured) + echo "$workload_name,$threads,$throughput,$power,$j_per_token" +} + +# Initialize CSV +echo "Workload,Threads,Throughput(t/s),Power(W),Energy(J/t)" > "$OUTPUT_CSV" + +# Test PP workloads +IFS=',' read -ra PP_ARRAY <<< "$PP_THREADS" +for threads in "${PP_ARRAY[@]}"; do + threads=$(echo "$threads" | xargs) # trim whitespace + result=$(run_benchmark "pp" "$threads") + echo "$result" >> "$OUTPUT_CSV" +done + +# Test TG workloads +IFS=',' read -ra TG_ARRAY <<< "$TG_THREADS" +for threads in "${TG_ARRAY[@]}"; do + threads=$(echo "$threads" | xargs) # trim whitespace + result=$(run_benchmark "tg" "$threads") + echo "$result" >> "$OUTPUT_CSV" +done + +# Cleanup +rm -f "$TEMP_LOG" "$BENCH_OUTPUT" "$PID_FILE" + +echo "" +echo "=== Benchmark Complete ===" +echo "Results saved to: $OUTPUT_CSV" +echo "" +cat "$OUTPUT_CSV" diff --git a/utils/tune_gemm_config.py b/utils/tune_gemm_config.py new file mode 100644 index 000000000..e537cd832 --- /dev/null +++ b/utils/tune_gemm_config.py @@ -0,0 +1,362 @@ +#!/usr/bin/env python3 +""" 
+GEMM Configuration Tuning Script
+This script automatically tunes ROW_BLOCK_SIZE, COL_BLOCK_SIZE, and PARALLEL_SIZE
+to find the optimal configuration for maximum throughput (t/s).
+"""
+
+import subprocess
+import os
+import re
+import csv
+import shutil
+from datetime import datetime
+from pathlib import Path
+import argparse
+
+
+class GemmTuner:
+    def __init__(self, config_path, model_path, threads=16):
+        self.config_path = Path(config_path)
+        self.model_path = model_path
+        self.threads = threads
+        self.backup_path = self.config_path.parent / f"gemm-config.h.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
+        self.build_dir = Path("../build")
+        self.results = []
+
+    def backup_config(self):
+        """Backup current configuration file"""
+        print(f"📦 Backing up current config to {self.backup_path}")
+        shutil.copy2(self.config_path, self.backup_path)
+
+    def restore_config(self):
+        """Restore original configuration file"""
+        print(f"♻️ Restoring original config from {self.backup_path}")
+        shutil.copy2(self.backup_path, self.config_path)
+
+    def generate_config(self, act_parallel, row_block_size, col_block_size, parallel_size):
+        """Generate new configuration file with simplified format"""
+        content = ""
+
+        # Simplified configuration format
+        if act_parallel:
+            content += "#define ACT_PARALLEL\n"
+
+        content += f"#define ROW_BLOCK_SIZE {row_block_size}\n"
+        content += f"#define COL_BLOCK_SIZE {col_block_size}\n"
+        content += f"#define PARALLEL_SIZE {parallel_size}\n"
+
+        with open(self.config_path, 'w') as f:
+            f.write(content)
+
+    def rebuild_project(self):
+        """Rebuild project"""
+        print("🔨 Rebuilding project...")
+        result = subprocess.run(
+            ["cmake", "--build", str(self.build_dir), "--target", "llama-bench"],
+            capture_output=True,
+            text=True,
+            cwd=os.getcwd()
+        )
+        if result.returncode != 0:
+            print(f"⚠️ Build warning/error: {result.stderr}")
+            return False
+        return True
+
+    def run_benchmark(self):
+        """Run benchmark test"""
+        cmd = [
+            f"{self.build_dir}/bin/llama-bench",
+            "-m", self.model_path,
+            "-p", "128",
+            "-n", "0",
+            "-t", str(self.threads),
+            "-ngl", "0"
+        ]
+
+        print(f"⚡ Running benchmark: {' '.join(cmd)}")
+
+        result = subprocess.run(
+            cmd,
+            capture_output=True,
+            text=True,
+            cwd=os.getcwd(),
+            timeout=300  # 5-minute timeout
+        )
+
+        if result.returncode != 0:
+            print(f"❌ Benchmark failed: {result.stderr}")
+            return None
+
+        return result.stdout
+
+    def parse_throughput(self, output):
+        """Parse pp128 throughput from output"""
+        # Match a pp128 result row such as: | pp128 | 501.06 ± 11.37 |
+        pp_pattern = r'\|\s+pp128\s+\|\s+([\d.]+)\s+±\s+([\d.]+)\s+\|'
+        pp_match = re.search(pp_pattern, output)
+
+        if pp_match:
+            pp_throughput = float(pp_match.group(1))
+            pp_std_dev = float(pp_match.group(2))
+
+            return {
+                'pp_throughput': pp_throughput,
+                'pp_std_dev': pp_std_dev
+            }
+
+        return None
+
+    def test_configuration(self, act_parallel, row_block_size, col_block_size, parallel_size):
+        """Test single configuration"""
+        config_name = f"ACT_{'ON' if act_parallel else 'OFF'}_R{row_block_size}_C{col_block_size}_P{parallel_size}"
+        print(f"\n{'='*80}")
+        print(f"🧪 Testing configuration: {config_name}")
+        print(f"   ACT_PARALLEL: {act_parallel}")
+        print(f"   ROW_BLOCK_SIZE: {row_block_size}")
+        print(f"   COL_BLOCK_SIZE: {col_block_size}")
+        print(f"   PARALLEL_SIZE: {parallel_size}")
+        print(f"{'='*80}")
+
+        # Generate configuration
+        self.generate_config(act_parallel, row_block_size, col_block_size, parallel_size)
+
+        # Rebuild project
+        if not self.rebuild_project():
+            print("⚠️ Build failed, skipping this
configuration") + return None + + # Run benchmark test + output = self.run_benchmark() + if output is None: + return None + + # Parse results + metrics = self.parse_throughput(output) + + if metrics is not None: + result = { + 'act_parallel': act_parallel, + 'row_block_size': row_block_size, + 'col_block_size': col_block_size, + 'parallel_size': parallel_size, + 'config_name': config_name, + **metrics + } + self.results.append(result) + print(f"✅ PP128: {metrics['pp_throughput']:.2f} ± {metrics['pp_std_dev']:.2f} t/s") + return result + else: + print("❌ Failed to parse throughput") + return None + + def save_results(self, csv_path): + """Save results to CSV file""" + print(f"\n💾 Saving results to {csv_path}") + + with open(csv_path, 'w', newline='') as f: + writer = csv.DictWriter(f, fieldnames=[ + 'config_name', 'act_parallel', 'row_block_size', + 'col_block_size', 'parallel_size', + 'pp_throughput', 'pp_std_dev' + ]) + writer.writeheader() + writer.writerows(self.results) + + def find_best_config(self): + """Find the best configuration with highest throughput""" + if not self.results: + print("❌ No valid results found") + return None + + best = max(self.results, key=lambda x: x['pp_throughput']) + return best + + def run_tuning(self, configurations, output_csv=None): + """Run test for all configurations""" + print(f"\n🚀 Starting tuning process with {len(configurations)} configurations") + print(f"📊 Model: {self.model_path}") + print(f"🧵 Threads: {self.threads}\n") + + # Backup configuration + self.backup_config() + + try: + # Test all configurations + for i, config in enumerate(configurations, 1): + print(f"\n[{i}/{len(configurations)}]") + self.test_configuration(**config) + + # Save results + if output_csv is None: + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + csv_path = f"../stats/tuning_results_{timestamp}.csv" + else: + csv_path = output_csv + + # Ensure stats directory exists + os.makedirs(os.path.dirname(csv_path), exist_ok=True) + self.save_results(csv_path) + + # Find best configuration + best = self.find_best_config() + if best: + print(f"\n{'='*80}") + print(f"🏆 BEST CONFIGURATION FOUND!") + print(f"{'='*80}") + print(f"Configuration: {best['config_name']}") + print(f"ACT_PARALLEL: {best['act_parallel']}") + print(f"ROW_BLOCK_SIZE: {best['row_block_size']}") + print(f"COL_BLOCK_SIZE: {best['col_block_size']}") + print(f"PARALLEL_SIZE: {best['parallel_size']}") + print(f"PP128 Throughput: {best['pp_throughput']:.2f} ± {best['pp_std_dev']:.2f} t/s") + print(f"{'='*80}\n") + + # Show the configuration that will be written + print("Configuration to be written to gemm-config.h:") + print("-" * 80) + if best['act_parallel']: + print("#define ACT_PARALLEL") + print(f"#define ROW_BLOCK_SIZE {best['row_block_size']}") + print(f"#define COL_BLOCK_SIZE {best['col_block_size']}") + print(f"#define PARALLEL_SIZE {best['parallel_size']}") + print("-" * 80) + + # Apply best configuration + apply = input("\nDo you want to apply this configuration to gemm-config.h? 
(y/n): ").strip().lower()
+                if apply == 'y':
+                    self.generate_config(
+                        best['act_parallel'],
+                        best['row_block_size'],
+                        best['col_block_size'],
+                        best['parallel_size']
+                    )
+                    self.rebuild_project()
+                    print("✅ Best configuration applied and project rebuilt!")
+                else:
+                    self.restore_config()
+                    print("✅ Original configuration restored")
+
+            # Clean up backup file
+            if self.backup_path.exists():
+                self.backup_path.unlink()
+                print(f"🗑️ Removed backup file: {self.backup_path}")
+
+        except KeyboardInterrupt:
+            print("\n⚠️ Tuning interrupted by user")
+            self.restore_config()
+            # Clean up backup file
+            if self.backup_path.exists():
+                self.backup_path.unlink()
+                print(f"🗑️ Removed backup file: {self.backup_path}")
+        except Exception as e:
+            print(f"\n❌ Error during tuning: {e}")
+            self.restore_config()
+            # Clean up backup file
+            if self.backup_path.exists():
+                self.backup_path.unlink()
+                print(f"🗑️ Removed backup file: {self.backup_path}")
+            raise
+
+
+def generate_configurations():
+    """Generate list of configurations to test"""
+    configurations = []
+
+    act_parallel_options = [True]
+
+    row_sizes = [2, 4, 8]  # [2, 4, 8, 16, 32]
+    col_sizes = [32, 64]  # [32, 64, 128, 256, 512, 1024]
+    parallelism_degree = [4]
+
+    for act_parallel in act_parallel_options:
+        for row in row_sizes:
+            for col in col_sizes:
+                for parallel in parallelism_degree:
+                    # Add filtering conditions
+                    if act_parallel:
+                        # When ACT_PARALLEL=True, only keep combinations with parallel <= row
+                        if parallel > row:
+                            continue
+                    else:
+                        # When ACT_PARALLEL=False, only keep combinations with parallel <= col
+                        if parallel > col:
+                            continue
+
+                    configurations.append({
+                        'act_parallel': act_parallel,
+                        'row_block_size': row,
+                        'col_block_size': col,
+                        'parallel_size': parallel
+                    })
+
+    return configurations
+
+
+def main():
+    parser = argparse.ArgumentParser(description='Tune GEMM configuration for optimal performance')
+    parser.add_argument('--config', default='../include/gemm-config.h',
+                        help='Path to gemm-config.h file')
+    parser.add_argument('--model', default='../models/BitNet-b1.58-2B-4T/ggml-model-i2_s-embed-q6_k.gguf',
+                        help='Path to model file')
+    parser.add_argument('--threads', type=int, default=8,
+                        help='Number of threads to use')
+    parser.add_argument('--quick', action='store_true',
+                        help='Quick test with fewer configurations')
+    parser.add_argument('--custom', action='store_true',
+                        help='Manually specify configurations to test')
+    parser.add_argument('--output', type=str, default=None,
+                        help='Output CSV file path (default: stats/tuning_results_<timestamp>.csv)')
+
+    args = parser.parse_args()
+
+    tuner = GemmTuner(args.config, args.model, args.threads)
+
+    if args.custom:
+        # Custom configuration mode
+        print("Custom configuration mode")
+        configurations = []
+        while True:
+            print("\nEnter configuration (or 'done' to finish):")
+            act_input = input("ACT_PARALLEL (y/n): ").strip().lower()
+            if act_input == 'done':
+                break
+            act = act_input == 'y'
+            row = int(input("ROW_BLOCK_SIZE: "))
+            col = int(input("COL_BLOCK_SIZE: "))
+            par = int(input("PARALLEL_SIZE: "))
+            configurations.append({
+                'act_parallel': act,
+                'row_block_size': row,
+                'col_block_size': col,
+                'parallel_size': par
+            })
+    elif args.quick:
+        # Quick test mode - test only a few key configurations
+        configurations = [
+            {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 128, 'parallel_size': 4},
+            {'act_parallel': True, 'row_block_size': 8, 'col_block_size': 128, 'parallel_size': 4},
+            {'act_parallel': True, 'row_block_size': 4, 'col_block_size': 64, 'parallel_size': 4},
+
{'act_parallel': False, 'row_block_size': 32, 'col_block_size': 4, 'parallel_size': 4}, + {'act_parallel': False, 'row_block_size': 16, 'col_block_size': 4, 'parallel_size': 4}, + ] + else: + # Full test mode + configurations = generate_configurations() + + print(f"\n{'='*80}") + print(f"GEMM Configuration Tuner") + print(f"{'='*80}") + print(f"Total configurations to test: {len(configurations)}") + print(f"Estimated time: ~{len(configurations) * 0.5:.1f} minutes (assuming 30s per test)") + print(f"{'='*80}\n") + + proceed = input("Proceed with tuning? (y/n): ").strip().lower() + if proceed != 'y': + print("Tuning cancelled") + return + + tuner.run_tuning(configurations, output_csv=args.output) + + +if __name__ == "__main__": + main()
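For reference, a minimal usage sketch for the two utilities introduced above. The model path, output locations, and working directory shown here are placeholder assumptions rather than files shipped by this change; the tuner is assumed to be invoked from `utils/` so that its `../build`, `../include`, and `../stats` defaults resolve.

```sh
# Power/throughput sweep: prompt processing and token generation, each at 1/2/4/8 threads
./utils/test_power.sh models/BitNet-b1.58-2B-4T/ggml-model-i2_s.gguf results/power.csv "1,2,4,8" "1,2,4,8"

# GEMM tiling search over the built-in quick configurations
cd utils
python tune_gemm_config.py --quick --threads 8 --output ../stats/tuning_quick.csv
```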