alpm_types/version/requirement.rs

//! Version requirement declarations and comparisons based on them.
2
3use std::{
4    cmp::Ordering,
5    fmt::{Display, Formatter},
6    str::FromStr,
7};
8
9use alpm_parsers::iter_str_context;
10use serde::{Deserialize, Serialize};
11use strum::VariantNames;
12use winnow::{
13    ModalResult,
14    Parser,
15    combinator::{alt, eof, fail, seq},
16    error::{StrContext, StrContextValue},
17    token::take_while,
18};
19
20use crate::{Error, Version};
21
22/// A version requirement, e.g. for a dependency package.
23///
24/// It consists of a target version and a comparison function. A version requirement of `>=1.5` has
25/// a target version of `1.5` and a comparison function of [`VersionComparison::GreaterOrEqual`].
26/// See [alpm-comparison] for details on the format.
27///
28/// ## Examples
29///
30/// ```
31/// use std::str::FromStr;
32///
33/// use alpm_types::{Version, VersionComparison, VersionRequirement};
34///
35/// # fn main() -> Result<(), alpm_types::Error> {
36/// let requirement = VersionRequirement::from_str(">=1.5")?;
37///
38/// assert_eq!(requirement.comparison, VersionComparison::GreaterOrEqual);
39/// assert_eq!(requirement.version, Version::from_str("1.5")?);
40/// # Ok(())
41/// # }
42/// ```
43///
44/// [alpm-comparison]: https://alpm.archlinux.page/specifications/alpm-comparison.7.html
45#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
46pub struct VersionRequirement {
47    /// Version comparison function
48    pub comparison: VersionComparison,
49    /// Target version
50    pub version: Version,
51}
52
53impl VersionRequirement {
54    /// Create a new `VersionRequirement`
55    pub fn new(comparison: VersionComparison, version: Version) -> Self {
56        VersionRequirement {
57            comparison,
58            version,
59        }
60    }
61
62    /// Returns `true` if the requirement is satisfied by the given package version.
63    ///
64    /// ## Examples
65    ///
66    /// ```
67    /// use std::str::FromStr;
68    ///
69    /// use alpm_types::{Version, VersionRequirement};
70    ///
71    /// # fn main() -> Result<(), alpm_types::Error> {
72    /// let requirement = VersionRequirement::from_str(">=1.5-3")?;
73    ///
74    /// assert!(!requirement.is_satisfied_by(&Version::from_str("1.5")?));
75    /// assert!(requirement.is_satisfied_by(&Version::from_str("1.5-3")?));
76    /// assert!(requirement.is_satisfied_by(&Version::from_str("1.6")?));
77    /// assert!(requirement.is_satisfied_by(&Version::from_str("2:1.0")?));
78    /// assert!(!requirement.is_satisfied_by(&Version::from_str("1.0")?));
79    /// # Ok(())
80    /// # }
81    /// ```
82    pub fn is_satisfied_by(&self, ver: &Version) -> bool {
83        self.comparison.is_compatible_with(ver.cmp(&self.version))
84    }
85
86    /// Recognizes a [`VersionRequirement`] in a string slice.
87    ///
88    /// Consumes all of its input.
89    ///
90    /// # Errors
91    ///
92    /// Returns an error if `input` is not a valid _alpm-comparison_.
93    pub fn parser(input: &mut &str) -> ModalResult<Self> {
94        seq!(Self {
95            comparison: take_while(1.., ('<', '>', '='))
96                // add context here because otherwise take_while can fail and provide no information
97                .context(StrContext::Expected(StrContextValue::Description(
98                    "version comparison operator"
99                )))
100                .and_then(VersionComparison::parser),
101            version: Version::parser,
102        })
103        .parse_next(input)
104    }
105}
106
107impl Display for VersionRequirement {
108    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
109        write!(f, "{}{}", self.comparison, self.version)
110    }
111}
112
113impl FromStr for VersionRequirement {
114    type Err = Error;
115
116    /// Creates a new [`VersionRequirement`] from a string slice.
117    ///
118    /// Delegates to [`VersionRequirement::parser`].
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if [`VersionRequirement::parser`] fails.
123    fn from_str(s: &str) -> Result<Self, Self::Err> {
124        Ok(Self::parser.parse(s)?)
125    }
126}
127
128/// Specifies the comparison function for a [`VersionRequirement`].
129///
130/// The package version can be required to be:
131/// - less than (`<`)
132/// - less than or equal to (`<=`)
133/// - equal to (`=`)
134/// - greater than or equal to (`>=`)
135/// - greater than (`>`)
136///
137/// the specified version.
138///
139/// See [alpm-comparison] for details on the format.
140///
141/// ## Note
142///
143/// The variants of this enum are sorted in a way, that prefers the two-letter comparators over
144/// the one-letter ones.
145/// This is because when splitting a string on the string representation of [`VersionComparison`]
146/// variant and relying on the ordering of [`strum::EnumIter`], the two-letter comparators must be
147/// checked before checking the one-letter ones to yield robust results.
148///
149/// [alpm-comparison]: https://alpm.archlinux.page/specifications/alpm-comparison.7.html
150#[derive(
151    strum::AsRefStr,
152    Clone,
153    Copy,
154    Debug,
155    strum::Display,
156    strum::EnumIter,
157    PartialEq,
158    Eq,
159    strum::VariantNames,
160    Serialize,
161    Deserialize,
162)]
163pub enum VersionComparison {
164    /// Less than or equal to
165    #[strum(to_string = "<=")]
166    LessOrEqual,
167
168    /// Greater than or equal to
169    #[strum(to_string = ">=")]
170    GreaterOrEqual,
171
172    /// Equal to
173    #[strum(to_string = "=")]
174    Equal,
175
176    /// Less than
177    #[strum(to_string = "<")]
178    Less,
179
180    /// Greater than
181    #[strum(to_string = ">")]
182    Greater,
183}
184
185impl VersionComparison {
186    /// Returns `true` if the result of a comparison between the actual and required package
187    /// versions satisfies the comparison function.
188    fn is_compatible_with(self, ord: Ordering) -> bool {
189        match (self, ord) {
190            (VersionComparison::Less, Ordering::Less)
191            | (VersionComparison::LessOrEqual, Ordering::Less | Ordering::Equal)
192            | (VersionComparison::Equal, Ordering::Equal)
193            | (VersionComparison::GreaterOrEqual, Ordering::Greater | Ordering::Equal)
194            | (VersionComparison::Greater, Ordering::Greater) => true,
195
196            (VersionComparison::Less, Ordering::Equal | Ordering::Greater)
197            | (VersionComparison::LessOrEqual, Ordering::Greater)
198            | (VersionComparison::Equal, Ordering::Less | Ordering::Greater)
199            | (VersionComparison::GreaterOrEqual, Ordering::Less)
200            | (VersionComparison::Greater, Ordering::Less | Ordering::Equal) => false,
201        }
202    }
203
204    /// Recognizes a [`VersionComparison`] in a string slice.
205    ///
206    /// Consumes all of its input.
207    ///
208    /// # Errors
209    ///
210    /// Returns an error if `input` is not a valid _alpm-comparison_.
211    pub fn parser(input: &mut &str) -> ModalResult<Self> {
212        alt((
213            // insert eofs here (instead of after alt call) so correct error message is thrown
214            ("<=", eof).value(Self::LessOrEqual),
215            (">=", eof).value(Self::GreaterOrEqual),
216            ("=", eof).value(Self::Equal),
217            ("<", eof).value(Self::Less),
218            (">", eof).value(Self::Greater),
219            fail.context(StrContext::Label("comparison operator"))
220                .context_with(iter_str_context!([VersionComparison::VARIANTS])),
221        ))
222        .parse_next(input)
223    }
224}
225
226impl FromStr for VersionComparison {
227    type Err = Error;
228
229    /// Creates a new [`VersionComparison`] from a string slice.
230    ///
231    /// Delegates to [`VersionComparison::parser`].
232    ///
233    /// # Errors
234    ///
235    /// Returns an error if [`VersionComparison::parser`] fails.
236    fn from_str(s: &str) -> Result<Self, Self::Err> {
237        Ok(Self::parser.parse(s)?)
238    }
239}
240
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Ensure that valid version comparison strings can be parsed and that they are displayed
    /// exactly as written.
    #[rstest]
    #[case("<", VersionComparison::Less)]
    #[case("<=", VersionComparison::LessOrEqual)]
    #[case("=", VersionComparison::Equal)]
    #[case(">=", VersionComparison::GreaterOrEqual)]
    #[case(">", VersionComparison::Greater)]
    fn valid_version_comparison(#[case] comparison: &str, #[case] expected: VersionComparison) {
        assert_eq!(comparison.parse(), Ok(expected));
        assert_eq!(expected.to_string(), comparison);
    }

    /// Ensure that invalid version comparisons will throw an error.
    #[rstest]
    #[case("", "invalid comparison operator")]
    #[case("<<", "invalid comparison operator")]
    #[case("==", "invalid comparison operator")]
    #[case("!=", "invalid comparison operator")]
    #[case(" =", "invalid comparison operator")]
    #[case("= ", "invalid comparison operator")]
    #[case("<1", "invalid comparison operator")]
    fn invalid_version_comparison(#[case] comparison: &str, #[case] err_snippet: &str) {
        let Err(Error::ParseError(err_msg)) = VersionComparison::from_str(comparison) else {
            panic!("'{comparison}' did not fail as expected")
        };
        assert!(
            err_msg.contains(err_snippet),
            "Error:\n=====\n{err_msg}\n=====\nshould contain snippet:\n\n{err_snippet}"
        );
    }

    /// Test successful parsing for version requirement strings.
    #[rstest]
    #[case("=1", VersionRequirement {
        comparison: VersionComparison::Equal,
        version: Version::from_str("1").unwrap(),
    })]
    #[case("<=42:abcd-2.4", VersionRequirement {
        comparison: VersionComparison::LessOrEqual,
        version: Version::from_str("42:abcd-2.4").unwrap(),
    })]
    #[case(">3.1", VersionRequirement {
        comparison: VersionComparison::Greater,
        version: Version::from_str("3.1").unwrap(),
    })]
    fn valid_version_requirement(#[case] requirement: &str, #[case] expected: VersionRequirement) {
        assert_eq!(
            requirement.parse(),
            Ok(expected),
            "Expected successful parse for version requirement '{requirement}'"
        );
    }

    /// Ensure that version requirements with a missing or malformed operator, or a missing
    /// version, are rejected with a descriptive error.
    #[rstest]
    #[case::bad_operator("<>3.1", "invalid comparison operator")]
    #[case::no_operator("3.1", "expected version comparison operator")]
    #[case::arrow_operator("=>3.1", "invalid comparison operator")]
    #[case::no_version("<=", "expected pkgver string")]
    fn invalid_version_requirement(#[case] requirement: &str, #[case] err_snippet: &str) {
        let Err(Error::ParseError(err_msg)) = VersionRequirement::from_str(requirement) else {
            panic!("'{requirement}' erroneously parsed as VersionRequirement")
        };
        assert!(
            err_msg.contains(err_snippet),
            "Error:\n=====\n{err_msg}\n=====\nshould contain snippet:\n\n{err_snippet}"
        );
    }

    /// Ensure that errors in the version part of a requirement (here: a second operator inside
    /// the version) are reported by the version parser.
    #[rstest]
    #[case("<3.1>3.2", "invalid pkgver character")]
    fn invalid_version_requirement_pkgver_parse(
        #[case] requirement: &str,
        #[case] err_snippet: &str,
    ) {
        let Err(Error::ParseError(err_msg)) = VersionRequirement::from_str(requirement) else {
            panic!("'{requirement}' erroneously parsed as VersionRequirement")
        };
        assert!(
            err_msg.contains(err_snippet),
            "Error:\n=====\n{err_msg}\n=====\nshould contain snippet:\n\n{err_snippet}"
        );
    }

    /// Check whether a version requirement is fulfilled by a given version string, for every
    /// comparison function.
    #[rstest]
    #[case("=1", "1", true)]
    #[case("=1", "1.0", false)]
    #[case("=1", "1-1", false)]
    #[case("=1", "1:1", false)]
    #[case("=1", "0.9", false)]
    #[case("<42", "41", true)]
    #[case("<42", "42", false)]
    #[case("<42", "43", false)]
    #[case("<42", "1:1", false)]
    #[case("<=42", "41", true)]
    #[case("<=42", "42", true)]
    #[case("<=42", "43", false)]
    #[case(">42", "41", false)]
    #[case(">42", "42", false)]
    #[case(">42", "43", true)]
    #[case(">=42", "41", false)]
    #[case(">=42", "42", true)]
    #[case(">=42", "43", true)]
    #[case(">=42", "1:1", true)]
    fn version_requirement_satisfied(
        #[case] requirement: &str,
        #[case] version: &str,
        #[case] result: bool,
    ) {
        let requirement = VersionRequirement::from_str(requirement).unwrap();
        let version = Version::from_str(version).unwrap();
        assert_eq!(requirement.is_satisfied_by(&version), result);
    }
}