alpm_compress/compression/
encoder.rs

1//! Encoder for compression which supports multiple backends.
2
3use std::{fmt::Debug, fs::File, io::Write};
4
5use alpm_types::CompressionAlgorithmFileExtension;
6use bzip2::write::BzEncoder;
7use flate2::write::GzEncoder;
8use liblzma::write::XzEncoder;
9use zstd::Encoder;
10
11use crate::{
12    Error,
13    compression::{CompressionSettings, ZstdThreads, level::ZstdCompressionLevel},
14};
15
16/// Creates and configures an [`Encoder`].
17///
18/// Uses a dedicated `compression_level` and amount of `threads` to construct and configure an
19/// encoder for zstd compression.
20/// The `settings` are merely used for additional context in cases of error.
21///
22/// # Errors
23///
24/// Returns an error if
25///
26/// - the encoder cannot be created using the `file` and `compression_level`,
27/// - the encoder cannot be configured to use checksums at the end of each frame,
28/// - the amount of physical CPU cores can not be turned into a `u32`,
29/// - or multithreading can not be enabled based on the provided `threads` settings.
30fn create_zstd_encoder(
31    file: File,
32    compression_level: &ZstdCompressionLevel,
33    threads: &ZstdThreads,
34    settings: &CompressionSettings,
35) -> Result<Encoder<'static, File>, Error> {
36    let mut encoder = Encoder::new(file, compression_level.into()).map_err(|source| {
37        Error::CreateZstandardEncoder {
38            context: "initializing",
39            compression_settings: settings.clone(),
40            source,
41        }
42    })?;
43    // Include a context checksum at the end of each frame.
44    encoder
45        .include_checksum(true)
46        .map_err(|source| Error::CreateZstandardEncoder {
47            context: "setting checksums to be added",
48            compression_settings: settings.clone(),
49            source,
50        })?;
51
52    // Get amount of threads to use.
53    let threads = match threads {
54        // Use available physical CPU cores if the special value `0` is used.
55        // NOTE: For the zstd executable `0` means "use all available threads", while for the zstd
56        // crate this means "disable multithreading".
57        ZstdThreads(0) => {
58            u32::try_from(num_cpus::get_physical()).map_err(Error::IntegerConversion)?
59        }
60        ZstdThreads(threads) => *threads,
61    };
62
63    // Use multi-threading if it is available.
64    encoder
65        .multithread(threads)
66        .map_err(|source| Error::CreateZstandardEncoder {
67            context: "setting checksums to be added",
68            compression_settings: settings.clone(),
69            source,
70        })?;
71
72    Ok(encoder)
73}
74
75/// Encoder for compression which supports multiple backends.
76///
77/// Wraps [`BzEncoder`], [`GzEncoder`], [`XzEncoder`] and [`Encoder`].
78/// Provides a unified [`Write`] implementation across all of them.
79pub enum CompressionEncoder<'a> {
80    /// The bzip2 compression encoder.
81    Bzip2(BzEncoder<File>),
82
83    /// The gzip compression encoder.
84    Gzip(GzEncoder<File>),
85
86    /// The xz compression encoder.
87    Xz(XzEncoder<File>),
88
89    /// The zstd compression encoder.
90    Zstd(Encoder<'a, File>),
91
92    /// No compression.
93    None(File),
94}
95
96impl CompressionEncoder<'_> {
97    /// Creates a new [`CompressionEncoder`].
98    ///
99    /// Uses a [`File`] to stream to and initializes a specific backend based on the provided
100    /// [`CompressionSettings`].
101    ///
102    /// # Errors
103    ///
104    /// Returns an error if creating the encoder for zstd compression fails.
105    /// All other encoder initializations are infallible.
106    pub fn new(file: File, settings: &CompressionSettings) -> Result<Self, Error> {
107        Ok(match settings {
108            CompressionSettings::Bzip2 { compression_level } => Self::Bzip2(BzEncoder::new(
109                file,
110                bzip2::Compression::new(compression_level.into()),
111            )),
112            CompressionSettings::Gzip { compression_level } => Self::Gzip(GzEncoder::new(
113                file,
114                flate2::Compression::new(compression_level.into()),
115            )),
116            CompressionSettings::Xz { compression_level } => {
117                Self::Xz(XzEncoder::new_parallel(file, compression_level.into()))
118            }
119            CompressionSettings::Zstd {
120                compression_level,
121                threads,
122            } => Self::Zstd(create_zstd_encoder(
123                file,
124                compression_level,
125                threads,
126                settings,
127            )?),
128            CompressionSettings::None => Self::None(file),
129        })
130    }
131
132    /// Finishes the compression stream.
133    ///
134    /// # Error
135    ///
136    /// Returns an error if the wrapped encoder fails.
137    pub fn finish(self) -> Result<File, Error> {
138        match self {
139            CompressionEncoder::Bzip2(encoder) => {
140                encoder.finish().map_err(|source| Error::FinishEncoder {
141                    compression_type: CompressionAlgorithmFileExtension::Bzip2,
142                    source,
143                })
144            }
145            CompressionEncoder::Gzip(encoder) => {
146                encoder.finish().map_err(|source| Error::FinishEncoder {
147                    compression_type: CompressionAlgorithmFileExtension::Gzip,
148                    source,
149                })
150            }
151            CompressionEncoder::Xz(encoder) => {
152                encoder.finish().map_err(|source| Error::FinishEncoder {
153                    compression_type: CompressionAlgorithmFileExtension::Xz,
154                    source,
155                })
156            }
157            CompressionEncoder::Zstd(encoder) => {
158                encoder.finish().map_err(|source| Error::FinishEncoder {
159                    compression_type: CompressionAlgorithmFileExtension::Zstd,
160                    source,
161                })
162            }
163            CompressionEncoder::None(file) => Ok(file),
164        }
165    }
166}
167
168impl Debug for CompressionEncoder<'_> {
169    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
170        write!(
171            f,
172            "CompressionEncoder({})",
173            match self {
174                CompressionEncoder::Bzip2(_) => "Bzip2",
175                CompressionEncoder::Gzip(_) => "Gzip",
176                CompressionEncoder::Xz(_) => "Xz",
177                CompressionEncoder::Zstd(_) => "Zstd",
178                &CompressionEncoder::None(_) => "None",
179            }
180        )
181    }
182}
183
184impl Write for CompressionEncoder<'_> {
185    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
186        match self {
187            CompressionEncoder::Bzip2(encoder) => encoder.write(buf),
188            CompressionEncoder::Gzip(encoder) => encoder.write(buf),
189            CompressionEncoder::Xz(encoder) => encoder.write(buf),
190            CompressionEncoder::Zstd(encoder) => encoder.write(buf),
191            CompressionEncoder::None(file) => file.write(buf),
192        }
193    }
194
195    fn write_vectored(&mut self, bufs: &[std::io::IoSlice<'_>]) -> std::io::Result<usize> {
196        match self {
197            CompressionEncoder::Bzip2(encoder) => encoder.write_vectored(bufs),
198            CompressionEncoder::Gzip(encoder) => encoder.write_vectored(bufs),
199            CompressionEncoder::Xz(encoder) => encoder.write_vectored(bufs),
200            CompressionEncoder::Zstd(encoder) => encoder.write_vectored(bufs),
201            CompressionEncoder::None(file) => file.write_vectored(bufs),
202        }
203    }
204
205    fn flush(&mut self) -> std::io::Result<()> {
206        match self {
207            CompressionEncoder::Bzip2(encoder) => encoder.flush(),
208            CompressionEncoder::Gzip(encoder) => encoder.flush(),
209            CompressionEncoder::Xz(encoder) => encoder.flush(),
210            CompressionEncoder::Zstd(encoder) => encoder.flush(),
211            CompressionEncoder::None(file) => file.flush(),
212        }
213    }
214
215    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
216        match self {
217            CompressionEncoder::Bzip2(encoder) => encoder.write_all(buf),
218            CompressionEncoder::Gzip(encoder) => encoder.write_all(buf),
219            CompressionEncoder::Xz(encoder) => encoder.write_all(buf),
220            CompressionEncoder::Zstd(encoder) => encoder.write_all(buf),
221            CompressionEncoder::None(file) => file.write_all(buf),
222        }
223    }
224
225    fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> std::io::Result<()> {
226        match self {
227            CompressionEncoder::Bzip2(encoder) => encoder.write_fmt(fmt),
228            CompressionEncoder::Gzip(encoder) => encoder.write_fmt(fmt),
229            CompressionEncoder::Xz(encoder) => encoder.write_fmt(fmt),
230            CompressionEncoder::Zstd(encoder) => encoder.write_fmt(fmt),
231            CompressionEncoder::None(file) => file.write_fmt(fmt),
232        }
233    }
234
235    fn by_ref(&mut self) -> &mut Self
236    where
237        Self: Sized,
238    {
239        self
240    }
241}
242
#[cfg(test)]
mod tests {
    use std::io::IoSlice;

    use rstest::rstest;
    use tempfile::tempfile;
    use testresult::TestResult;

    use super::*;
    use crate::compression::level::{
        Bzip2CompressionLevel,
        GzipCompressionLevel,
        XzCompressionLevel,
        ZstdCompressionLevel,
    };

    /// Ensures that the [`Write::write`] implementation works for each [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write(#[case] settings: CompressionSettings) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();
        let buf = [1u8; 8];

        let mut write_len = 0;
        while write_len < buf.len() {
            // `write` may accept fewer bytes than offered, so only offer the remainder instead
            // of sending already written bytes a second time.
            let len_written = ref_encoder.write(&buf[write_len..])?;
            write_len += len_written;
        }

        ref_encoder.flush()?;

        Ok(())
    }

    /// Ensures that the [`Write::write_vectored`] implementation works for each
    /// [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write_vectored(
        #[case] settings: CompressionSettings,
    ) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();

        let data1 = [1u8; 8];
        let data2 = [15u8; 8];
        let total_len = data1.len() + data2.len();

        let mut write_len = 0;
        while write_len < total_len {
            // `write_vectored` may accept fewer bytes than offered, so skip the bytes that have
            // already been written: first from `data1`, then from `data2`.
            let remaining1 = &data1[write_len.min(data1.len())..];
            let remaining2 = &data2[write_len.saturating_sub(data1.len())..];

            let len_written = ref_encoder
                .write_vectored(&[IoSlice::new(remaining1), IoSlice::new(remaining2)])?;
            write_len += len_written;
        }

        ref_encoder.flush()?;

        Ok(())
    }

    /// Ensures that the [`Write::write_all`] implementation works for each [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write_all(#[case] settings: CompressionSettings) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();
        let buf = &[1; 8];

        ref_encoder.write_all(buf)?;

        ref_encoder.flush()?;

        Ok(())
    }

    /// Ensures that the [`Write::write_fmt`] implementation works for each [`CompressionEncoder`].
    #[rstest]
    #[case::bzip2(CompressionSettings::Bzip2 { compression_level: Bzip2CompressionLevel::default()})]
    #[case::gzip(CompressionSettings::Gzip { compression_level: GzipCompressionLevel::default()})]
    #[case::xz(CompressionSettings::Xz { compression_level: XzCompressionLevel::default()})]
    #[case::zstd_all_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(0) })]
    #[case::zstd_one_thread(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(1) })]
    #[case::zstd_crazy_threads(CompressionSettings::Zstd { compression_level: ZstdCompressionLevel::default(), threads: ZstdThreads::new(99999) })]
    #[case::no_compression(CompressionSettings::None)]
    fn test_compression_encoder_write_fmt(#[case] settings: CompressionSettings) -> TestResult {
        let file = tempfile()?;
        let mut encoder = CompressionEncoder::new(file, &settings)?;
        let ref_encoder = encoder.by_ref();

        ref_encoder.write_fmt(format_args!("{:.*}", 2, 1.234567))?;

        ref_encoder.flush()?;

        Ok(())
    }
}