alpm_compress/compression/
encoder.rs

1//! Encoder for compression which supports multiple backends.
2
3use std::{fmt::Debug, fs::File, io::Write};
4
5use alpm_types::CompressionAlgorithmFileExtension;
6use bzip2::write::BzEncoder;
7use flate2::write::GzEncoder;
8use fluent_i18n::t;
9use liblzma::write::XzEncoder;
10use zstd::Encoder;
11
12use crate::{
13    Error,
14    compression::{CompressionSettings, ZstdThreads, level::ZstdCompressionLevel},
15};
16
17/// Creates and configures an [`Encoder`].
18///
19/// Uses a dedicated `compression_level` and amount of `threads` to construct and configure an
20/// encoder for zstd compression.
21/// The `settings` are merely used for additional context in cases of error.
22///
23/// # Errors
24///
25/// Returns an error if
26///
27/// - the encoder cannot be created using the `file` and `compression_level`,
28/// - the encoder cannot be configured to use checksums at the end of each frame,
29/// - the amount of physical CPU cores can not be turned into a `u32`,
30/// - or multithreading can not be enabled based on the provided `threads` settings.
31fn create_zstd_encoder(
32    file: File,
33    compression_level: &ZstdCompressionLevel,
34    threads: &ZstdThreads,
35    settings: &CompressionSettings,
36) -> Result<Encoder<'static, File>, Error> {
37    let mut encoder = Encoder::new(file, compression_level.into()).map_err(|source| {
38        Error::CreateZstandardEncoder {
39            context: t!("error-create-zstd-encoder-init"),
40            compression_settings: settings.clone(),
41            source,
42        }
43    })?;
44
45    // Include a context checksum at the end of each frame.
46    encoder
47        .include_checksum(true)
48        .map_err(|source| Error::CreateZstandardEncoder {
49            context: t!("error-create-zstd-encoder-set-checksum"),
50            compression_settings: settings.clone(),
51            source,
52        })?;
53
54    // Get amount of threads to use.
55    let threads = match threads {
56        // Use available physical CPU cores if the special value `0` is used.
57        // NOTE: For the zstd executable `0` means "use all available threads", while for the zstd
58        // crate this means "disable multithreading".
59        ZstdThreads(0) => {
60            u32::try_from(num_cpus::get_physical()).map_err(Error::IntegerConversion)?
61        }
62        ZstdThreads(threads) => *threads,
63    };
64
65    // Use multi-threading if it is available.
66    encoder
67        .multithread(threads)
68        .map_err(|source| Error::CreateZstandardEncoder {
69            context: t!("error-create-zstd-encoder-set-threads"),
70            compression_settings: settings.clone(),
71            source,
72        })?;
73
74    Ok(encoder)
75}
76
77/// Encoder for compression which supports multiple backends.
78///
79/// Wraps [`BzEncoder`], [`GzEncoder`], [`XzEncoder`] and [`Encoder`].
80/// Provides a unified [`Write`] implementation across all of them.
81pub enum CompressionEncoder<'a> {
82    /// The bzip2 compression encoder.
83    Bzip2(BzEncoder<File>),
84
85    /// The gzip compression encoder.
86    Gzip(GzEncoder<File>),
87
88    /// The xz compression encoder.
89    Xz(XzEncoder<File>),
90
91    /// The zstd compression encoder.
92    Zstd(Encoder<'a, File>),
93
94    /// No compression.
95    None(File),
96}
97
98impl CompressionEncoder<'_> {
99    /// Creates a new [`CompressionEncoder`].
100    ///
101    /// Uses a [`File`] to stream to and initializes a specific backend based on the provided
102    /// [`CompressionSettings`].
103    ///
104    /// # Errors
105    ///
106    /// Returns an error if creating the encoder for zstd compression fails.
107    /// All other encoder initializations are infallible.
108    pub fn new(file: File, settings: &CompressionSettings) -> Result<Self, Error> {
109        Ok(match settings {
110            CompressionSettings::Bzip2 { compression_level } => Self::Bzip2(BzEncoder::new(
111                file,
112                bzip2::Compression::new(compression_level.into()),
113            )),
114            CompressionSettings::Gzip { compression_level } => Self::Gzip(GzEncoder::new(
115                file,
116                flate2::Compression::new(compression_level.into()),
117            )),
118            CompressionSettings::Xz { compression_level } => {
119                Self::Xz(XzEncoder::new_parallel(file, compression_level.into()))
120            }
121            CompressionSettings::Zstd {
122                compression_level,
123                threads,
124            } => Self::Zstd(create_zstd_encoder(
125                file,
126                compression_level,
127                threads,
128                settings,
129            )?),
130            CompressionSettings::None => Self::None(file),
131        })
132    }
133
134    /// Finishes the compression stream.
135    ///
136    /// # Error
137    ///
138    /// Returns an error if the wrapped encoder fails.
139    pub fn finish(self) -> Result<File, Error> {
140        match self {
141            CompressionEncoder::Bzip2(encoder) => {
142                encoder.finish().map_err(|source| Error::FinishEncoder {
143                    compression_type: CompressionAlgorithmFileExtension::Bzip2,
144                    source,
145                })
146            }
147            CompressionEncoder::Gzip(encoder) => {
148                encoder.finish().map_err(|source| Error::FinishEncoder {
149                    compression_type: CompressionAlgorithmFileExtension::Gzip,
150                    source,
151                })
152            }
153            CompressionEncoder::Xz(encoder) => {
154                encoder.finish().map_err(|source| Error::FinishEncoder {
155                    compression_type: CompressionAlgorithmFileExtension::Xz,
156                    source,
157                })
158            }
159            CompressionEncoder::Zstd(encoder) => {
160                encoder.finish().map_err(|source| Error::FinishEncoder {
161                    compression_type: CompressionAlgorithmFileExtension::Zstd,
162                    source,
163                })
164            }
165            CompressionEncoder::None(file) => Ok(file),
166        }
167    }
168}
169
170impl Debug for CompressionEncoder<'_> {
171    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
172        write!(
173            f,
174            "CompressionEncoder({})",
175            match self {
176                CompressionEncoder::Bzip2(_) => "Bzip2",
177                CompressionEncoder::Gzip(_) => "Gzip",
178                CompressionEncoder::Xz(_) => "Xz",
179                CompressionEncoder::Zstd(_) => "Zstd",
180                &CompressionEncoder::None(_) => "None",
181            }
182        )
183    }
184}
185
186impl Write for CompressionEncoder<'_> {
187    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
188        match self {
189            CompressionEncoder::Bzip2(encoder) => encoder.write(buf),
190            CompressionEncoder::Gzip(encoder) => encoder.write(buf),
191            CompressionEncoder::Xz(encoder) => encoder.write(buf),
192            CompressionEncoder::Zstd(encoder) => encoder.write(buf),
193            CompressionEncoder::None(file) => file.write(buf),
194        }
195    }
196
197    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
198        match self {
199            CompressionEncoder::Bzip2(encoder) => encoder.write_vectored(bufs),
200            CompressionEncoder::Gzip(encoder) => encoder.write_vectored(bufs),
201            CompressionEncoder::Xz(encoder) => encoder.write_vectored(bufs),
202            CompressionEncoder::Zstd(encoder) => encoder.write_vectored(bufs),
203            CompressionEncoder::None(file) => file.write_vectored(bufs),
204        }
205    }
206
207    fn flush(&mut self) -> std::io::Result<()> {
208        match self {
209            CompressionEncoder::Bzip2(encoder) => encoder.flush(),
210            CompressionEncoder::Gzip(encoder) => encoder.flush(),
211            CompressionEncoder::Xz(encoder) => encoder.flush(),
212            CompressionEncoder::Zstd(encoder) => encoder.flush(),
213            CompressionEncoder::None(file) => file.flush(),
214        }
215    }
216
217    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
218        match self {
219            CompressionEncoder::Bzip2(encoder) => encoder.write_all(buf),
220            CompressionEncoder::Gzip(encoder) => encoder.write_all(buf),
221            CompressionEncoder::Xz(encoder) => encoder.write_all(buf),
222            CompressionEncoder::Zstd(encoder) => encoder.write_all(buf),
223            CompressionEncoder::None(file) => file.write_all(buf),
224        }
225    }
226
227    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
228        match self {
229            CompressionEncoder::Bzip2(encoder) => encoder.write_fmt(fmt),
230            CompressionEncoder::Gzip(encoder) => encoder.write_fmt(fmt),
231            CompressionEncoder::Xz(encoder) => encoder.write_fmt(fmt),
232            CompressionEncoder::Zstd(encoder) => encoder.write_fmt(fmt),
233            CompressionEncoder::None(file) => file.write_fmt(fmt),
234        }
235    }
236
237    fn by_ref(&mut self) -> &mut Self
238    where
239        Self: Sized,
240    {
241        self
242    }
243}
244
#[cfg(test)]
mod tests {
    use std::io::IoSlice;

    use rstest::rstest;
    use tempfile::tempfile;
    use testresult::TestResult;

    use super::*;
    use crate::compression::level::{
        Bzip2CompressionLevel,
        GzipCompressionLevel,
        XzCompressionLevel,
        ZstdCompressionLevel,
    };

    /// Ensures that the [`Write::write`] implementation works for each [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write(#[case] settings: CompressionSettings) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();
        let buf: &[u8] = &[1; 8];

        // `write` may only write part of the buffer, so only pass on the remaining bytes.
        let mut write_len = 0;
        while write_len < buf.len() {
            let len_written = ref_encoder.write(&buf[write_len..])?;
            assert_ne!(len_written, 0, "the encoder did not accept any data");
            write_len += len_written;
        }

        ref_encoder.flush()?;
        encoder.finish()?;

        Ok(())
    }

    /// Ensures that the [`Write::write_vectored`] implementation works for each
    /// [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write_vectored(
        #[case] settings: CompressionSettings,
    ) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();

        let data1: [u8; 8] = [1; 8];
        let data2: [u8; 8] = [15; 8];

        let mut write_len = 0;
        while write_len < data1.len() + data2.len() {
            // `write_vectored` may only write part of the buffers, so skip the bytes that have
            // already been written.
            let skip1 = write_len.min(data1.len());
            let skip2 = write_len - skip1;
            let io_slices = [IoSlice::new(&data1[skip1..]), IoSlice::new(&data2[skip2..])];

            let len_written = ref_encoder.write_vectored(&io_slices)?;
            assert_ne!(len_written, 0, "the encoder did not accept any data");
            write_len += len_written;
        }

        ref_encoder.flush()?;
        encoder.finish()?;

        Ok(())
    }

    /// Ensures that the [`Write::write_all`] implementation works for each [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write_all(#[case] settings: CompressionSettings) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();
        let buf: &[u8] = &[1; 8];

        ref_encoder.write_all(buf)?;

        ref_encoder.flush()?;
        encoder.finish()?;

        Ok(())
    }

    /// Ensures that the [`Write::write_fmt`] implementation works for each [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write_fmt(#[case] settings: CompressionSettings) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();

        ref_encoder.write_fmt(format_args!("{:.*}", 2, 1.234567))?;

        ref_encoder.flush()?;
        encoder.finish()?;

        Ok(())
    }
}