1use std::{
2 fmt::{Display, Formatter},
3 str::FromStr,
4};
5
6use alpm_parsers::iter_str_context;
7use serde::{Deserialize, Serialize};
8use winnow::{
9 ModalResult,
10 Parser,
11 ascii::{alpha1, space0},
12 combinator::{alt, cut_err, eof, fail, opt, peek, repeat_till, terminated},
13 error::{StrContext, StrContextValue},
14 token::{any, rest},
15};
16
17use crate::Error;
18
19#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
44pub struct Url(url::Url);
45
46impl Url {
47 pub fn new(url: url::Url) -> Result<Self, Error> {
49 Ok(Self(url))
50 }
51
52 pub fn as_str(&self) -> &str {
54 self.0.as_str()
55 }
56
57 pub fn into_inner(self) -> url::Url {
59 self.0
60 }
61
62 pub fn inner(&self) -> &url::Url {
64 &self.0
65 }
66}
67
68impl AsRef<str> for Url {
69 fn as_ref(&self) -> &str {
70 self.as_str()
71 }
72}
73
74impl FromStr for Url {
75 type Err = Error;
76
77 fn from_str(s: &str) -> Result<Self, Self::Err> {
93 let url = url::Url::parse(s).map_err(Error::InvalidUrl)?;
94 Self::new(url)
95 }
96}
97
98impl Display for Url {
99 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100 write!(f, "{}", self.as_str())
101 }
102}
103
104#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
153pub struct SourceUrl {
154 pub url: Url,
156 pub vcs_info: Option<VcsInfo>,
158}
159
160impl FromStr for SourceUrl {
161 type Err = Error;
162
163 fn from_str(s: &str) -> Result<Self, Self::Err> {
183 Ok(Self::parser.parse(s)?)
184 }
185}
186
187impl Display for SourceUrl {
188 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
189 let Some(vcs_info) = &self.vcs_info else {
191 return write!(f, "{}", self.url.as_str());
192 };
193
194 let mut prefix = None;
195 let url = self.url.as_str();
196 let mut formatted_fragment = String::new();
197 let mut query = String::new();
198
199 match vcs_info {
201 VcsInfo::Bzr { fragment } => {
202 prefix = Some(VcsProtocol::Bzr);
203 if let Some(fragment) = fragment {
204 formatted_fragment = format!("#{fragment}");
205 }
206 }
207 VcsInfo::Fossil { fragment } => {
208 prefix = Some(VcsProtocol::Fossil);
209 if let Some(fragment) = fragment {
210 formatted_fragment = format!("#{fragment}");
211 }
212 }
213 VcsInfo::Git { fragment, signed } => {
214 if !url.starts_with("git://") {
216 prefix = Some(VcsProtocol::Git);
217 }
218 if *signed {
219 query = "?signed".to_string();
220 }
221 if let Some(fragment) = fragment {
222 formatted_fragment = format!("#{fragment}");
223 }
224 }
225 VcsInfo::Hg { fragment } => {
226 prefix = Some(VcsProtocol::Hg);
227 if let Some(fragment) = fragment {
228 formatted_fragment = format!("#{fragment}");
229 }
230 }
231 VcsInfo::Svn { fragment } => {
232 if !url.starts_with("svn://") {
234 prefix = Some(VcsProtocol::Svn);
235 }
236 if let Some(fragment) = fragment {
237 formatted_fragment = format!("#{fragment}");
238 }
239 }
240 }
241
242 let prefix = if let Some(prefix) = prefix {
243 format!("{prefix}+")
244 } else {
245 String::new()
246 };
247
248 write!(f, "{prefix}{url}{query}{formatted_fragment}",)
249 }
250}
251
252impl SourceUrl {
253 fn parser(input: &mut &str) -> ModalResult<SourceUrl> {
255 let vcs = opt(VcsProtocol::parser).parse_next(input)?;
257
258 let Some(vcs) = vcs else {
259 let url = cut_err(rest.try_map(Url::from_str))
264 .context(StrContext::Label("url"))
265 .parse_next(input)?;
266 return Ok(SourceUrl {
267 url,
268 vcs_info: None,
269 });
270 };
271
272 let url = cut_err(SourceUrl::inner_url_parser.try_map(|url| Url::from_str(&url)))
275 .context(StrContext::Label("url"))
276 .parse_next(input)?;
277
278 let vcs_info = VcsInfo::parser(vcs).parse_next(input)?;
279
280 let _: Option<String> =
283 opt(("?", rest)
284 .take()
285 .and_then(cut_err(fail.context(StrContext::Label(
286 "or duplicate query parameter for detected VCS.",
287 )))))
288 .parse_next(input)?;
289
290 cut_err((space0, eof))
291 .context(StrContext::Label("unexpected trailing content in URL."))
292 .context(StrContext::Expected(StrContextValue::Description(
293 "end of input.",
294 )))
295 .parse_next(input)?;
296
297 Ok(SourceUrl {
298 url,
299 vcs_info: Some(vcs_info),
300 })
301 }
302
303 fn inner_url_parser(input: &mut &str) -> ModalResult<String> {
313 let (url, _) = repeat_till(0.., any, peek(alt(("#", "?", eof)))).parse_next(input)?;
314 Ok(url)
315 }
316}
317
318#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
324#[serde(tag = "protocol", rename_all = "lowercase")]
325pub enum VcsInfo {
326 Bzr {
327 fragment: Option<BzrFragment>,
328 },
329 Fossil {
330 fragment: Option<FossilFragment>,
331 },
332 Git {
333 fragment: Option<GitFragment>,
334 signed: bool,
335 },
336 Hg {
337 fragment: Option<HgFragment>,
338 },
339 Svn {
340 fragment: Option<SvnFragment>,
341 },
342}
343
344impl VcsInfo {
345 fn parser(vcs: VcsProtocol) -> impl FnMut(&mut &str) -> ModalResult<VcsInfo> {
350 move |input: &mut &str| match vcs {
351 VcsProtocol::Bzr => {
352 let fragment = opt(BzrFragment::parser).parse_next(input)?;
353 Ok(VcsInfo::Bzr { fragment })
354 }
355 VcsProtocol::Fossil => {
356 let fragment = opt(FossilFragment::parser).parse_next(input)?;
357 Ok(VcsInfo::Fossil { fragment })
358 }
359 VcsProtocol::Git => {
360 let mut signed = git_query(input)?;
364 let fragment = opt(GitFragment::parser).parse_next(input)?;
365 if !signed {
366 signed = git_query(input)?;
369 }
370 Ok(VcsInfo::Git { fragment, signed })
371 }
372 VcsProtocol::Hg => {
373 let fragment = opt(HgFragment::parser).parse_next(input)?;
374 Ok(VcsInfo::Hg { fragment })
375 }
376 VcsProtocol::Svn => {
377 let fragment = opt(SvnFragment::parser).parse_next(input)?;
378 Ok(VcsInfo::Svn { fragment })
379 }
380 }
381 }
382}
383
384#[derive(strum::EnumString, strum::Display)]
391#[strum(serialize_all = "lowercase")]
392enum VcsProtocol {
393 Bzr,
394 Fossil,
395 Git,
396 Hg,
397 Svn,
398}
399
400impl VcsProtocol {
401 fn parser(input: &mut &str) -> ModalResult<VcsProtocol> {
412 let protocol =
414 opt(terminated(alpha1.try_map(VcsProtocol::from_str), "+")).parse_next(input)?;
415
416 if let Some(protocol) = protocol {
417 return Ok(protocol);
418 }
419
420 let protocol = peek(alt(("git://", "svn://"))).parse_next(input)?;
426
427 match protocol {
428 "git://" => Ok(VcsProtocol::Git),
429 "svn://" => Ok(VcsProtocol::Svn),
430 _ => unreachable!(),
431 }
432 }
433}
434
435fn fragment_value(input: &mut &str) -> ModalResult<String> {
443 let _ = cut_err("=")
445 .context(StrContext::Label("fragment separator"))
446 .context(StrContext::Expected(StrContextValue::Description(
447 "a literal '='",
448 )))
449 .parse_next(input)?;
450
451 let (value, _) = repeat_till(0.., any, peek(alt(("?", "#", eof)))).parse_next(input)?;
453
454 Ok(value)
455}
456
457#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
459#[serde(rename_all = "snake_case")]
460pub enum BzrFragment {
461 Revision(String),
462}
463
464impl Display for BzrFragment {
465 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
466 match self {
467 BzrFragment::Revision(revision) => write!(f, "revision={revision}"),
468 }
469 }
470}
471
472impl BzrFragment {
473 fn parser(input: &mut &str) -> ModalResult<BzrFragment> {
477 let _ = "#".parse_next(input)?;
479
480 cut_err("revision")
482 .context(StrContext::Label("bzr revision type"))
483 .context(StrContext::Expected(StrContextValue::Description(
484 "revision keyword",
485 )))
486 .parse_next(input)?;
487
488 let value = fragment_value.parse_next(input)?;
489
490 Ok(BzrFragment::Revision(value))
491 }
492}
493
494#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
496#[serde(rename_all = "snake_case")]
497pub enum FossilFragment {
498 Branch(String),
499 Commit(String),
500 Tag(String),
501}
502
503impl Display for FossilFragment {
504 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
505 match self {
506 FossilFragment::Branch(revision) => write!(f, "branch={revision}"),
507 FossilFragment::Commit(revision) => write!(f, "commit={revision}"),
508 FossilFragment::Tag(revision) => write!(f, "tag={revision}"),
509 }
510 }
511}
512
513impl FossilFragment {
514 fn parser(input: &mut &str) -> ModalResult<FossilFragment> {
519 let _ = "#".parse_next(input)?;
521
522 let version_keywords = ["branch", "commit", "tag"];
524 let version_type = cut_err(alt(version_keywords))
525 .context(StrContext::Label("fossil revision type"))
526 .context_with(iter_str_context!([version_keywords]))
527 .parse_next(input)?;
528
529 let value = fragment_value.parse_next(input)?;
530
531 match version_type {
532 "branch" => Ok(FossilFragment::Branch(value.to_string())),
533 "commit" => Ok(FossilFragment::Commit(value.to_string())),
534 "tag" => Ok(FossilFragment::Tag(value.to_string())),
535 _ => unreachable!(),
536 }
537 }
538}
539
540#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
542#[serde(rename_all = "snake_case")]
543pub enum GitFragment {
544 Branch(String),
545 Commit(String),
546 Tag(String),
547}
548
549impl Display for GitFragment {
550 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
551 match self {
552 GitFragment::Branch(revision) => write!(f, "branch={revision}"),
553 GitFragment::Commit(revision) => write!(f, "commit={revision}"),
554 GitFragment::Tag(revision) => write!(f, "tag={revision}"),
555 }
556 }
557}
558
559impl GitFragment {
560 fn parser(input: &mut &str) -> ModalResult<GitFragment> {
565 let _ = "#".parse_next(input)?;
567
568 let version_keywords = ["branch", "commit", "tag"];
570 let version_type = cut_err(alt(version_keywords))
571 .context(StrContext::Label("git revision type"))
572 .context_with(iter_str_context!([version_keywords]))
573 .parse_next(input)?;
574
575 let value = fragment_value.parse_next(input)?;
576
577 match version_type {
578 "branch" => Ok(GitFragment::Branch(value.to_string())),
579 "commit" => Ok(GitFragment::Commit(value.to_string())),
580 "tag" => Ok(GitFragment::Tag(value.to_string())),
581 _ => unreachable!(),
582 }
583 }
584}
585
586fn git_query(input: &mut &str) -> ModalResult<bool> {
590 let query = opt("?signed").parse_next(input)?;
591 Ok(query.is_some())
592}
593
594#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
596#[serde(rename_all = "snake_case")]
597pub enum HgFragment {
598 Branch(String),
599 Revision(String),
600 Tag(String),
601}
602
603impl Display for HgFragment {
604 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
605 match self {
606 HgFragment::Branch(revision) => write!(f, "branch={revision}"),
607 HgFragment::Revision(revision) => write!(f, "revision={revision}"),
608 HgFragment::Tag(revision) => write!(f, "tag={revision}"),
609 }
610 }
611}
612
613impl HgFragment {
614 fn parser(input: &mut &str) -> ModalResult<HgFragment> {
619 let _ = "#".parse_next(input)?;
621
622 let version_keywords = ["branch", "revision", "tag"];
624 let version_type = cut_err(alt(version_keywords))
625 .context(StrContext::Label("hg revision type"))
626 .context_with(iter_str_context!([version_keywords]))
627 .parse_next(input)?;
628
629 let value = fragment_value.parse_next(input)?;
630
631 match version_type {
632 "branch" => Ok(HgFragment::Branch(value.to_string())),
633 "revision" => Ok(HgFragment::Revision(value.to_string())),
634 "tag" => Ok(HgFragment::Tag(value.to_string())),
635 _ => unreachable!(),
636 }
637 }
638}
639
640#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
642#[serde(rename_all = "snake_case")]
643pub enum SvnFragment {
644 Revision(String),
645}
646
647impl Display for SvnFragment {
648 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
649 match self {
650 SvnFragment::Revision(revision) => write!(f, "revision={revision}"),
651 }
652 }
653}
654
655impl SvnFragment {
656 fn parser(input: &mut &str) -> ModalResult<SvnFragment> {
661 let _ = "#".parse_next(input)?;
663
664 cut_err("revision")
666 .context(StrContext::Label("svn revision type"))
667 .context(StrContext::Expected(StrContextValue::Description(
668 "revision keyword",
669 )))
670 .parse_next(input)?;
671
672 let value = fragment_value.parse_next(input)?;
673
674 Ok(SvnFragment::Revision(value))
675 }
676}
677
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use testresult::TestResult;

    use super::*;

    /// Parses a known-good URL literal for use in test fixtures.
    fn parsed_url(input: &str) -> Url {
        Url::from_str(input).unwrap()
    }

    #[rstest]
    #[case("https://example.com/", Ok("https://example.com/"))]
    #[case(
        "https://example.com/path?query=1",
        Ok("https://example.com/path?query=1")
    )]
    #[case("ftp://example.com/", Ok("ftp://example.com/"))]
    #[case("not-a-url", Err(url::ParseError::RelativeUrlWithoutBase.into()))]
    fn test_url_parsing(#[case] input: &str, #[case] expected: Result<&str, Error>) {
        let parsed = input.parse::<Url>();

        // Successful parses are compared by their `Display` output, errors are compared directly.
        let rendered = parsed.as_ref().map(ToString::to_string);
        let expected_rendered = expected.as_ref().map(ToString::to_string);
        assert_eq!(rendered, expected_rendered);

        // A successfully parsed URL must serialize back to its exact input.
        if let Ok(url) = parsed {
            assert_eq!(url.as_str(), input);
        }
    }

    #[rstest]
    #[case(
        "git+https://example/project#tag=v1.0.0?signed",
        Some("git+https://example/project?signed#tag=v1.0.0"),
        SourceUrl {
            url: parsed_url("https://example/project"),
            vcs_info: Some(VcsInfo::Git {
                fragment: Some(GitFragment::Tag("v1.0.0".into())),
                signed: true
            })
        }
    )]
    #[case(
        "git+https://example/project?signed#tag=v1.0.0",
        None,
        SourceUrl {
            url: parsed_url("https://example/project"),
            vcs_info: Some(VcsInfo::Git {
                fragment: Some(GitFragment::Tag("v1.0.0".into())),
                signed: true
            })
        }
    )]
    #[case(
        "git://example/project#commit=a51720b",
        None,
        SourceUrl {
            url: parsed_url("git://example/project"),
            vcs_info: Some(VcsInfo::Git {
                fragment: Some(GitFragment::Commit("a51720b".into())),
                signed: false
            })
        }
    )]
    #[case(
        "svn+https://example/project#revision=a51720b",
        None,
        SourceUrl {
            url: parsed_url("https://example/project"),
            vcs_info: Some(VcsInfo::Svn {
                fragment: Some(SvnFragment::Revision("a51720b".into())),
            })
        }
    )]
    #[case(
        "bzr+https://example/project#revision=a51720b",
        None,
        SourceUrl {
            url: parsed_url("https://example/project"),
            vcs_info: Some(VcsInfo::Bzr {
                fragment: Some(BzrFragment::Revision("a51720b".into())),
            })
        }
    )]
    #[case(
        "hg+https://example/project#branch=feature",
        None,
        SourceUrl {
            url: parsed_url("https://example/project"),
            vcs_info: Some(VcsInfo::Hg {
                fragment: Some(HgFragment::Branch("feature".into())),
            })
        }
    )]
    #[case(
        "fossil+https://example/project#branch=feature",
        None,
        SourceUrl {
            url: parsed_url("https://example/project"),
            vcs_info: Some(VcsInfo::Fossil {
                fragment: Some(FossilFragment::Branch("feature".into())),
            })
        }
    )]
    #[case(
        "https://example/project#branch=feature?signed",
        None,
        SourceUrl {
            url: parsed_url("https://example/project#branch=feature?signed"),
            vcs_info: None,
        }
    )]
    fn test_source_url_parsing_success(
        #[case] input: &str,
        #[case] expected_to_string: Option<&str>,
        #[case] expected: SourceUrl,
    ) -> TestResult {
        let parsed: SourceUrl = input.parse()?;
        assert_eq!(
            parsed, expected,
            "Parsed source_url should resemble the expected output."
        );

        // Inputs already in canonical form must round-trip verbatim through `Display`.
        let canonical = expected_to_string.unwrap_or(input);
        assert_eq!(
            parsed.to_string(),
            canonical,
            "Parsed and displayed source_url should resemble original."
        );

        Ok(())
    }

    #[rstest]
    #[case(
        "git+https://example/project#revision=v1.0.0?signed",
        "invalid git revision type\nexpected `branch`, `commit`, `tag`"
    )]
    #[case(
        "git+https://example/project#branch=feature#branch=feature",
        "invalid unexpected trailing content in URL."
    )]
    #[case(
        "git+https://example/project#branch=feature?signed?signed",
        "invalid or duplicate query parameter for detected VCS."
    )]
    #[case(
        "bzr+https://example/project#branch=feature",
        "invalid bzr revision type\nexpected revision keyword"
    )]
    #[case(
        "svn+https://example/project#branch=feature",
        "invalid svn revision type\nexpected revision keyword"
    )]
    #[case(
        "hg+https://example/project#commit=154021a",
        "invalid hg revision type\nexpected `branch`, `revision`, `tag`"
    )]
    #[case(
        "hg+https://example/project#branch=feature?signed",
        "invalid or duplicate query parameter for detected VCS."
    )]
    fn test_source_url_parsing_failure(#[case] input: &str, #[case] error_snippet: &str) {
        let Err(err) = SourceUrl::from_str(input) else {
            panic!("Invalid source_url should fail to parse.");
        };

        let pretty_error = err.to_string();
        assert!(
            pretty_error.contains(error_snippet),
            "Error:\n=====\n{pretty_error}\n=====\nshould contain snippet:\n\n{error_snippet}"
        );
    }
}