1use std::{
4 cmp::Ordering,
5 fmt::{Display, Formatter},
6 str::FromStr,
7};
8
9use alpm_parsers::iter_str_context;
10use serde::{Deserialize, Serialize};
11use strum::VariantNames;
12use winnow::{
13 ModalResult,
14 Parser,
15 combinator::{alt, eof, fail, seq},
16 error::{StrContext, StrContextValue},
17 token::take_while,
18};
19
20use crate::{Error, Version};
21
22#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
46pub struct VersionRequirement {
47 pub comparison: VersionComparison,
49 pub version: Version,
51}
52
53impl VersionRequirement {
54 pub fn new(comparison: VersionComparison, version: Version) -> Self {
56 VersionRequirement {
57 comparison,
58 version,
59 }
60 }
61
62 pub fn is_satisfied_by(&self, ver: &Version) -> bool {
83 self.comparison.is_compatible_with(ver.cmp(&self.version))
84 }
85
86 pub fn parser(input: &mut &str) -> ModalResult<Self> {
94 seq!(Self {
95 comparison: take_while(1.., ('<', '>', '='))
96 .context(StrContext::Expected(StrContextValue::Description(
98 "version comparison operator"
99 )))
100 .and_then(VersionComparison::parser),
101 version: Version::parser,
102 })
103 .parse_next(input)
104 }
105}
106
107impl Display for VersionRequirement {
108 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
109 write!(f, "{}{}", self.comparison, self.version)
110 }
111}
112
113impl FromStr for VersionRequirement {
114 type Err = Error;
115
116 fn from_str(s: &str) -> Result<Self, Self::Err> {
124 Ok(Self::parser.parse(s)?)
125 }
126}
127
128#[derive(
151 strum::AsRefStr,
152 Clone,
153 Copy,
154 Debug,
155 strum::Display,
156 strum::EnumIter,
157 PartialEq,
158 Eq,
159 strum::VariantNames,
160 Serialize,
161 Deserialize,
162)]
163pub enum VersionComparison {
164 #[strum(to_string = "<=")]
166 LessOrEqual,
167
168 #[strum(to_string = ">=")]
170 GreaterOrEqual,
171
172 #[strum(to_string = "=")]
174 Equal,
175
176 #[strum(to_string = "<")]
178 Less,
179
180 #[strum(to_string = ">")]
182 Greater,
183}
184
185impl VersionComparison {
186 fn is_compatible_with(self, ord: Ordering) -> bool {
189 match (self, ord) {
190 (VersionComparison::Less, Ordering::Less)
191 | (VersionComparison::LessOrEqual, Ordering::Less | Ordering::Equal)
192 | (VersionComparison::Equal, Ordering::Equal)
193 | (VersionComparison::GreaterOrEqual, Ordering::Greater | Ordering::Equal)
194 | (VersionComparison::Greater, Ordering::Greater) => true,
195
196 (VersionComparison::Less, Ordering::Equal | Ordering::Greater)
197 | (VersionComparison::LessOrEqual, Ordering::Greater)
198 | (VersionComparison::Equal, Ordering::Less | Ordering::Greater)
199 | (VersionComparison::GreaterOrEqual, Ordering::Less)
200 | (VersionComparison::Greater, Ordering::Less | Ordering::Equal) => false,
201 }
202 }
203
204 pub fn parser(input: &mut &str) -> ModalResult<Self> {
212 alt((
213 ("<=", eof).value(Self::LessOrEqual),
215 (">=", eof).value(Self::GreaterOrEqual),
216 ("=", eof).value(Self::Equal),
217 ("<", eof).value(Self::Less),
218 (">", eof).value(Self::Greater),
219 fail.context(StrContext::Label("comparison operator"))
220 .context_with(iter_str_context!([VersionComparison::VARIANTS])),
221 ))
222 .parse_next(input)
223 }
224}
225
226impl FromStr for VersionComparison {
227 type Err = Error;
228
229 fn from_str(s: &str) -> Result<Self, Self::Err> {
237 Ok(Self::parser.parse(s)?)
238 }
239}
240
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Asserts that a parse error message contains the expected snippet.
    fn assert_error_contains(err_msg: &str, err_snippet: &str) {
        assert!(
            err_msg.contains(err_snippet),
            "Error:\n=====\n{err_msg}\n=====\nshould contain snippet:\n\n{err_snippet}"
        );
    }

    /// Every supported operator string parses into its matching variant.
    #[rstest]
    #[case("<", VersionComparison::Less)]
    #[case("<=", VersionComparison::LessOrEqual)]
    #[case("=", VersionComparison::Equal)]
    #[case(">=", VersionComparison::GreaterOrEqual)]
    #[case(">", VersionComparison::Greater)]
    fn valid_version_comparison(#[case] comparison: &str, #[case] expected: VersionComparison) {
        assert_eq!(VersionComparison::from_str(comparison), Ok(expected));
    }

    /// Empty, doubled, unknown, padded or trailing input is rejected.
    #[rstest]
    #[case("", "invalid comparison operator")]
    #[case("<<", "invalid comparison operator")]
    #[case("==", "invalid comparison operator")]
    #[case("!=", "invalid comparison operator")]
    #[case(" =", "invalid comparison operator")]
    #[case("= ", "invalid comparison operator")]
    #[case("<1", "invalid comparison operator")]
    fn invalid_version_comparison(#[case] comparison: &str, #[case] err_snippet: &str) {
        let parsed = VersionComparison::from_str(comparison);
        let Err(Error::ParseError(err_msg)) = parsed else {
            panic!("'{comparison}' did not fail as expected")
        };
        assert_error_contains(&err_msg, err_snippet);
    }

    /// Well-formed requirements parse into the expected operator and version.
    #[rstest]
    #[case(
        "=1",
        VersionRequirement::new(VersionComparison::Equal, Version::from_str("1").unwrap())
    )]
    #[case(
        "<=42:abcd-2.4",
        VersionRequirement::new(
            VersionComparison::LessOrEqual,
            Version::from_str("42:abcd-2.4").unwrap(),
        )
    )]
    #[case(
        ">3.1",
        VersionRequirement::new(VersionComparison::Greater, Version::from_str("3.1").unwrap())
    )]
    fn valid_version_requirement(#[case] requirement: &str, #[case] expected: VersionRequirement) {
        let parsed: Result<VersionRequirement, Error> = requirement.parse();
        assert_eq!(
            parsed,
            Ok(expected),
            "Expected successful parse for version requirement '{requirement}'"
        );
    }

    /// Requirements with a bad or missing operator, or a missing version, are rejected.
    #[rstest]
    #[case::bad_operator("<>3.1", "invalid comparison operator")]
    #[case::no_operator("3.1", "expected version comparison operator")]
    #[case::arrow_operator("=>3.1", "invalid comparison operator")]
    #[case::no_version("<=", "expected pkgver string")]
    fn invalid_version_requirement(#[case] requirement: &str, #[case] err_snippet: &str) {
        let parsed = VersionRequirement::from_str(requirement);
        let Err(Error::ParseError(err_msg)) = parsed else {
            panic!("'{requirement}' erroneously parsed as VersionRequirement")
        };
        assert_error_contains(&err_msg, err_snippet);
    }

    /// A second operator after the version is reported by the version parser.
    #[rstest]
    #[case("<3.1>3.2", "invalid pkgver character")]
    fn invalid_version_requirement_pkgver_parse(
        #[case] requirement: &str,
        #[case] err_snippet: &str,
    ) {
        let parsed = VersionRequirement::from_str(requirement);
        let Err(Error::ParseError(err_msg)) = parsed else {
            panic!("'{requirement}' erroneously parsed as VersionRequirement")
        };
        assert_error_contains(&err_msg, err_snippet);
    }

    /// Each operator accepts exactly the versions on the expected side of the boundary.
    #[rstest]
    #[case("=1", "1", true)]
    #[case("=1", "1.0", false)]
    #[case("=1", "1-1", false)]
    #[case("=1", "1:1", false)]
    #[case("=1", "0.9", false)]
    #[case("<42", "41", true)]
    #[case("<42", "42", false)]
    #[case("<42", "43", false)]
    #[case("<=42", "41", true)]
    #[case("<=42", "42", true)]
    #[case("<=42", "43", false)]
    #[case(">42", "41", false)]
    #[case(">42", "42", false)]
    #[case(">42", "43", true)]
    #[case(">=42", "41", false)]
    #[case(">=42", "42", true)]
    #[case(">=42", "43", true)]
    fn version_requirement_satisfied(
        #[case] requirement: &str,
        #[case] version: &str,
        #[case] result: bool,
    ) {
        let requirement: VersionRequirement = requirement.parse().unwrap();
        let version: Version = version.parse().unwrap();
        assert_eq!(requirement.is_satisfied_by(&version), result);
    }
}