// SPDX-License-Identifier: GPL-2.0 #include #include "util/compress.h" #include "util/debug.h" int zstd_init(struct zstd_data *data, int level) { size_t ret; data->dstream = ZSTD_createDStream(); if (data->dstream == NULL) { pr_err("Couldn't create decompression stream.\n"); return -1; } ret = ZSTD_initDStream(data->dstream); if (ZSTD_isError(ret)) { pr_err("Failed to initialize decompression stream: %s\n", ZSTD_getErrorName(ret)); return -1; } if (!level) return 0; data->cstream = ZSTD_createCStream(); if (data->cstream == NULL) { pr_err("Couldn't create compression stream.\n"); return -1; } ret = ZSTD_initCStream(data->cstream, level); if (ZSTD_isError(ret)) { pr_err("Failed to initialize compression stream: %s\n", ZSTD_getErrorName(ret)); return -1; } return 0; } int zstd_fini(struct zstd_data *data) { if (data->dstream) { ZSTD_freeDStream(data->dstream); data->dstream = NULL; } if (data->cstream) { ZSTD_freeCStream(data->cstream); data->cstream = NULL; } return 0; } size_t zstd_compress_stream_to_records(struct zstd_data *data, void *dst, size_t dst_size, void *src, size_t src_size, size_t max_record_size, size_t process_header(void *record, size_t increment)) { size_t ret, size, compressed = 0; ZSTD_inBuffer input = { src, src_size, 0 }; ZSTD_outBuffer output; void *record; while (input.pos < input.size) { record = dst; size = process_header(record, 0); compressed += size; dst += size; dst_size -= size; output = (ZSTD_outBuffer){ dst, (dst_size > max_record_size) ? max_record_size : dst_size, 0 }; ret = ZSTD_compressStream(data->cstream, &output, &input); ZSTD_flushStream(data->cstream, &output); if (ZSTD_isError(ret)) { pr_err("failed to compress %ld bytes: %s\n", (long)src_size, ZSTD_getErrorName(ret)); memcpy(dst, src, src_size); return src_size; } size = output.pos; size = process_header(record, size); compressed += size; dst += size; dst_size -= size; } return compressed; } size_t zstd_decompress_stream(struct zstd_data *data, void *src, size_t src_size, void *dst, size_t dst_size) { size_t ret; ZSTD_inBuffer input = { src, src_size, 0 }; ZSTD_outBuffer output = { dst, dst_size, 0 }; while (input.pos < input.size) { ret = ZSTD_decompressStream(data->dstream, &output, &input); if (ZSTD_isError(ret)) { pr_err("failed to decompress (B): %zd -> %zd, dst_size %zd : %s\n", src_size, output.size, dst_size, ZSTD_getErrorName(ret)); break; } output.dst = dst + output.pos; output.size = dst_size - output.pos; } return output.pos; }