1use std::{
2 fmt::{Display, Formatter},
3 str::FromStr,
4};
5
6use serde::{Deserialize, Serialize};
7use winnow::{
8 ModalResult,
9 Parser,
10 ascii::{alpha1, space0},
11 combinator::{alt, cut_err, eof, fail, opt, peek, repeat_till, terminated},
12 error::{StrContext, StrContextValue},
13 token::{any, rest},
14};
15
16use crate::Error;
17
/// A URL, stored as a newtype wrapper around [`url::Url`].
///
/// Construct it with [`Url::new`] or parse it from a string via [`FromStr`]; the serialized
/// form is available through [`Url::as_str`] and [`Display`].
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct Url(url::Url);
44
45impl Url {
46 pub fn new(url: url::Url) -> Result<Self, Error> {
48 Ok(Self(url))
49 }
50
51 pub fn as_str(&self) -> &str {
53 self.0.as_str()
54 }
55
56 pub fn into_inner(self) -> url::Url {
58 self.0
59 }
60
61 pub fn inner(&self) -> &url::Url {
63 &self.0
64 }
65}
66
67impl AsRef<str> for Url {
68 fn as_ref(&self) -> &str {
69 self.as_str()
70 }
71}
72
73impl FromStr for Url {
74 type Err = Error;
75
76 fn from_str(s: &str) -> Result<Self, Self::Err> {
92 let url = url::Url::parse(s).map_err(Error::InvalidUrl)?;
93 Self::new(url)
94 }
95}
96
97impl Display for Url {
98 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
99 write!(f, "{}", self.as_str())
100 }
101}
102
/// A source URL, optionally annotated with version control information.
///
/// Parsed from strings such as `https://example/project`,
/// `git+https://example/project?signed#tag=v1.0.0` or `svn://example/project#revision=1`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
pub struct SourceUrl {
    /// The URL itself.
    ///
    /// For VCS sources this excludes the `<protocol>+` prefix as well as the VCS query
    /// (`?signed`) and fragment (`#...`); for plain sources it is the complete input.
    pub url: Url,
    /// VCS details, or `None` if no VCS protocol was detected.
    pub vcs_info: Option<VcsInfo>,
}
158
159impl FromStr for SourceUrl {
160 type Err = Error;
161
162 fn from_str(s: &str) -> Result<Self, Self::Err> {
182 Ok(Self::parser.parse(s)?)
183 }
184}
185
186impl Display for SourceUrl {
187 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
188 let Some(vcs_info) = &self.vcs_info else {
190 return write!(f, "{}", self.url.as_str());
191 };
192
193 let mut prefix = None;
194 let url = self.url.as_str();
195 let mut formatted_fragment = String::new();
196 let mut query = String::new();
197
198 match vcs_info {
200 VcsInfo::Bzr { fragment } => {
201 prefix = Some(VcsProtocol::Bzr);
202 if let Some(fragment) = fragment {
203 formatted_fragment = format!("#{fragment}");
204 }
205 }
206 VcsInfo::Fossil { fragment } => {
207 prefix = Some(VcsProtocol::Fossil);
208 if let Some(fragment) = fragment {
209 formatted_fragment = format!("#{fragment}");
210 }
211 }
212 VcsInfo::Git { fragment, signed } => {
213 if !url.starts_with("git://") {
215 prefix = Some(VcsProtocol::Git);
216 }
217 if *signed {
218 query = "?signed".to_string();
219 }
220 if let Some(fragment) = fragment {
221 formatted_fragment = format!("#{fragment}");
222 }
223 }
224 VcsInfo::Hg { fragment } => {
225 prefix = Some(VcsProtocol::Hg);
226 if let Some(fragment) = fragment {
227 formatted_fragment = format!("#{fragment}");
228 }
229 }
230 VcsInfo::Svn { fragment } => {
231 if !url.starts_with("svn://") {
233 prefix = Some(VcsProtocol::Svn);
234 }
235 if let Some(fragment) = fragment {
236 formatted_fragment = format!("#{fragment}");
237 }
238 }
239 }
240
241 let prefix = if let Some(prefix) = prefix {
242 format!("{prefix}+")
243 } else {
244 String::new()
245 };
246
247 write!(f, "{prefix}{url}{query}{formatted_fragment}",)
248 }
249}
250
impl SourceUrl {
    /// Parses a complete source URL.
    ///
    /// If no VCS protocol is detected, the remaining input is parsed as a [`Url`] and
    /// `vcs_info` is `None`. If a VCS protocol is detected (a `<protocol>+` prefix, or a
    /// `git://` / `svn://` scheme), the URL is read up to the first `#` or `?` and the
    /// VCS-specific query/fragment is parsed by [`VcsInfo::parser`]. Any leftover `?...`
    /// query is rejected, and only whitespace may follow.
    ///
    /// # Errors
    ///
    /// Returns a cut error if the URL is invalid, if a VCS fragment is malformed, if an
    /// unsupported or duplicate query follows the VCS information, or if other content
    /// trails the URL.
    fn parser(input: &mut &str) -> ModalResult<SourceUrl> {
        // Detect the VCS protocol; `opt` yields `None` when none is recognized.
        let vcs = opt(VcsProtocol::parser).parse_next(input)?;

        // Plain URL: everything left, including any query or fragment, belongs to the URL.
        let Some(vcs) = vcs else {
            let url = cut_err(rest.try_map(Url::from_str))
                .context(StrContext::Label("url"))
                .parse_next(input)?;
            return Ok(SourceUrl {
                url,
                vcs_info: None,
            });
        };

        // VCS URL: the URL stops before the first `#` (fragment) or `?` (query), which are
        // handled by the VCS-specific parser below.
        let url = cut_err(SourceUrl::inner_url_parser.try_map(|url| Url::from_str(&url)))
            .context(StrContext::Label("url"))
            .parse_next(input)?;

        let vcs_info = VcsInfo::parser(vcs).parse_next(input)?;

        // A `?` still present here is a query the VCS parser did not accept (e.g. a second
        // `?signed`, or any query on a non-git source). It is matched only so that it can be
        // turned into a hard error with a dedicated message; the `String` type just satisfies
        // `fail`'s generic output.
        let _: Option<String> =
            opt(("?", rest)
                .take()
                .and_then(cut_err(fail.context(StrContext::Label(
                    "or duplicate query parameter for detected VCS.",
                )))))
            .parse_next(input)?;

        // Only trailing whitespace may follow.
        cut_err((space0, eof))
            .context(StrContext::Label("unexpected trailing content in URL."))
            .context(StrContext::Expected(StrContextValue::Description(
                "end of input.",
            )))
            .parse_next(input)?;

        Ok(SourceUrl {
            url,
            vcs_info: Some(vcs_info),
        })
    }

    /// Reads the URL part of a VCS source: every character up to (but not including) the
    /// first `#` or `?`, or the rest of the input. The terminator is only peeked, so it is
    /// left for the following parsers.
    fn inner_url_parser(input: &mut &str) -> ModalResult<String> {
        let (url, _) = repeat_till(0.., any, peek(alt(("#", "?", eof)))).parse_next(input)?;
        Ok(url)
    }
}
316
/// Version control information attached to a [`SourceUrl`].
///
/// Each variant holds the optional fragment that is valid for that VCS. Serialized as an
/// internally tagged object whose `protocol` field is the lowercase variant name.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(tag = "protocol", rename_all = "lowercase")]
pub enum VcsInfo {
    /// A `bzr` source, written with a `bzr+` prefix.
    Bzr {
        /// The optional `#revision=...` fragment.
        fragment: Option<BzrFragment>,
    },
    /// A `fossil` source, written with a `fossil+` prefix.
    Fossil {
        /// The optional `#branch=...`, `#commit=...` or `#tag=...` fragment.
        fragment: Option<FossilFragment>,
    },
    /// A `git` source, written with a `git+` prefix or a `git://` scheme.
    Git {
        /// The optional `#branch=...`, `#commit=...` or `#tag=...` fragment.
        fragment: Option<GitFragment>,
        /// Whether the `?signed` query was present.
        signed: bool,
    },
    /// An `hg` source, written with an `hg+` prefix.
    Hg {
        /// The optional `#branch=...`, `#revision=...` or `#tag=...` fragment.
        fragment: Option<HgFragment>,
    },
    /// An `svn` source, written with an `svn+` prefix or an `svn://` scheme.
    Svn {
        /// The optional `#revision=...` fragment.
        fragment: Option<SvnFragment>,
    },
}
342
343impl VcsInfo {
344 fn parser(vcs: VcsProtocol) -> impl FnMut(&mut &str) -> ModalResult<VcsInfo> {
349 move |input: &mut &str| match vcs {
350 VcsProtocol::Bzr => {
351 let fragment = opt(BzrFragment::parser).parse_next(input)?;
352 Ok(VcsInfo::Bzr { fragment })
353 }
354 VcsProtocol::Fossil => {
355 let fragment = opt(FossilFragment::parser).parse_next(input)?;
356 Ok(VcsInfo::Fossil { fragment })
357 }
358 VcsProtocol::Git => {
359 let mut signed = git_query(input)?;
363 let fragment = opt(GitFragment::parser).parse_next(input)?;
364 if !signed {
365 signed = git_query(input)?;
368 }
369 Ok(VcsInfo::Git { fragment, signed })
370 }
371 VcsProtocol::Hg => {
372 let fragment = opt(HgFragment::parser).parse_next(input)?;
373 Ok(VcsInfo::Hg { fragment })
374 }
375 VcsProtocol::Svn => {
376 let fragment = opt(SvnFragment::parser).parse_next(input)?;
377 Ok(VcsInfo::Svn { fragment })
378 }
379 }
380 }
381}
382
383#[derive(strum::EnumString, strum::Display)]
390#[strum(serialize_all = "lowercase")]
391enum VcsProtocol {
392 Bzr,
393 Fossil,
394 Git,
395 Hg,
396 Svn,
397}
398
399impl VcsProtocol {
400 fn parser(input: &mut &str) -> ModalResult<VcsProtocol> {
411 let protocol =
413 opt(terminated(alpha1.try_map(VcsProtocol::from_str), "+")).parse_next(input)?;
414
415 if let Some(protocol) = protocol {
416 return Ok(protocol);
417 }
418
419 let protocol = peek(alt(("git://", "svn://"))).parse_next(input)?;
425
426 match protocol {
427 "git://" => Ok(VcsProtocol::Git),
428 "svn://" => Ok(VcsProtocol::Svn),
429 _ => unreachable!(),
430 }
431 }
432}
433
434fn fragment_value(input: &mut &str) -> ModalResult<String> {
442 let _ = cut_err("=")
444 .context(StrContext::Label("fragment separator"))
445 .context(StrContext::Expected(StrContextValue::Description(
446 "a literal '='",
447 )))
448 .parse_next(input)?;
449
450 let (value, _) = repeat_till(0.., any, peek(alt(("?", "#", eof)))).parse_next(input)?;
452
453 Ok(value)
454}
455
/// The fragment of a `bzr` source URL, written as `#revision=<value>`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum BzrFragment {
    /// A revision, written as `revision=<value>`.
    Revision(String),
}
462
463impl Display for BzrFragment {
464 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
465 match self {
466 BzrFragment::Revision(revision) => write!(f, "revision={revision}"),
467 }
468 }
469}
470
471impl BzrFragment {
472 fn parser(input: &mut &str) -> ModalResult<BzrFragment> {
476 let _ = "#".parse_next(input)?;
478
479 cut_err("revision")
481 .context(StrContext::Label("bzr revision type"))
482 .context(StrContext::Expected(StrContextValue::Description(
483 "revision keyword",
484 )))
485 .parse_next(input)?;
486
487 let value = fragment_value.parse_next(input)?;
488
489 Ok(BzrFragment::Revision(value))
490 }
491}
492
/// The fragment of a `fossil` source URL: `#branch=<value>`, `#commit=<value>` or
/// `#tag=<value>`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum FossilFragment {
    /// A branch, written as `branch=<value>`.
    Branch(String),
    /// A commit, written as `commit=<value>`.
    Commit(String),
    /// A tag, written as `tag=<value>`.
    Tag(String),
}
501
502impl Display for FossilFragment {
503 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
504 match self {
505 FossilFragment::Branch(revision) => write!(f, "branch={revision}"),
506 FossilFragment::Commit(revision) => write!(f, "commit={revision}"),
507 FossilFragment::Tag(revision) => write!(f, "tag={revision}"),
508 }
509 }
510}
511
512impl FossilFragment {
513 fn parser(input: &mut &str) -> ModalResult<FossilFragment> {
518 let _ = "#".parse_next(input)?;
520
521 let version_type = cut_err(alt(("branch", "commit", "tag")))
523 .context(StrContext::Label("fossil revision type"))
524 .context(StrContext::Expected(StrContextValue::Description(
525 "branch, commit or tag keyword",
526 )))
527 .parse_next(input)?;
528
529 let value = fragment_value.parse_next(input)?;
530
531 match version_type {
532 "branch" => Ok(FossilFragment::Branch(value.to_string())),
533 "commit" => Ok(FossilFragment::Commit(value.to_string())),
534 "tag" => Ok(FossilFragment::Tag(value.to_string())),
535 _ => unreachable!(),
536 }
537 }
538}
539
/// The fragment of a `git` source URL: `#branch=<value>`, `#commit=<value>` or
/// `#tag=<value>`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GitFragment {
    /// A branch, written as `branch=<value>`.
    Branch(String),
    /// A commit, written as `commit=<value>`.
    Commit(String),
    /// A tag, written as `tag=<value>`.
    Tag(String),
}
548
549impl Display for GitFragment {
550 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
551 match self {
552 GitFragment::Branch(revision) => write!(f, "branch={revision}"),
553 GitFragment::Commit(revision) => write!(f, "commit={revision}"),
554 GitFragment::Tag(revision) => write!(f, "tag={revision}"),
555 }
556 }
557}
558
559impl GitFragment {
560 fn parser(input: &mut &str) -> ModalResult<GitFragment> {
565 let _ = "#".parse_next(input)?;
567
568 let version_type = cut_err(alt(("branch", "commit", "tag")))
570 .context(StrContext::Label("git revision type"))
571 .context(StrContext::Expected(StrContextValue::Description(
572 "branch, commit or tag keyword",
573 )))
574 .parse_next(input)?;
575
576 let value = fragment_value.parse_next(input)?;
577
578 match version_type {
579 "branch" => Ok(GitFragment::Branch(value.to_string())),
580 "commit" => Ok(GitFragment::Commit(value.to_string())),
581 "tag" => Ok(GitFragment::Tag(value.to_string())),
582 _ => unreachable!(),
583 }
584 }
585}
586
587fn git_query(input: &mut &str) -> ModalResult<bool> {
591 let query = opt("?signed").parse_next(input)?;
592 Ok(query.is_some())
593}
594
/// The fragment of an `hg` source URL: `#branch=<value>`, `#revision=<value>` or
/// `#tag=<value>`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum HgFragment {
    /// A branch, written as `branch=<value>`.
    Branch(String),
    /// A revision, written as `revision=<value>`.
    Revision(String),
    /// A tag, written as `tag=<value>`.
    Tag(String),
}
603
604impl Display for HgFragment {
605 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
606 match self {
607 HgFragment::Branch(revision) => write!(f, "branch={revision}"),
608 HgFragment::Revision(revision) => write!(f, "revision={revision}"),
609 HgFragment::Tag(revision) => write!(f, "tag={revision}"),
610 }
611 }
612}
613
614impl HgFragment {
615 fn parser(input: &mut &str) -> ModalResult<HgFragment> {
620 let _ = "#".parse_next(input)?;
622
623 let version_type = cut_err(alt(("branch", "revision", "tag")))
625 .context(StrContext::Label("hg revision type"))
626 .context(StrContext::Expected(StrContextValue::Description(
627 "branch, revision or tag keyword",
628 )))
629 .parse_next(input)?;
630
631 let value = fragment_value.parse_next(input)?;
632
633 match version_type {
634 "branch" => Ok(HgFragment::Branch(value.to_string())),
635 "revision" => Ok(HgFragment::Revision(value.to_string())),
636 "tag" => Ok(HgFragment::Tag(value.to_string())),
637 _ => unreachable!(),
638 }
639 }
640}
641
/// The fragment of an `svn` source URL, written as `#revision=<value>`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum SvnFragment {
    /// A revision, written as `revision=<value>`.
    Revision(String),
}
648
649impl Display for SvnFragment {
650 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
651 match self {
652 SvnFragment::Revision(revision) => write!(f, "revision={revision}"),
653 }
654 }
655}
656
657impl SvnFragment {
658 fn parser(input: &mut &str) -> ModalResult<SvnFragment> {
663 let _ = "#".parse_next(input)?;
665
666 cut_err("revision")
668 .context(StrContext::Label("svn revision type"))
669 .context(StrContext::Expected(StrContextValue::Description(
670 "revision keyword",
671 )))
672 .parse_next(input)?;
673
674 let value = fragment_value.parse_next(input)?;
675
676 Ok(SvnFragment::Revision(value))
677 }
678}
679
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use testresult::TestResult;

    use super::*;

    /// Checks that plain URLs parse (and round-trip) or fail with the expected error.
    #[rstest]
    #[case("https://example.com/", Ok("https://example.com/"))]
    #[case(
        "https://example.com/path?query=1",
        Ok("https://example.com/path?query=1")
    )]
    #[case("ftp://example.com/", Ok("ftp://example.com/"))]
    #[case("not-a-url", Err(url::ParseError::RelativeUrlWithoutBase.into()))]
    fn test_url_parsing(#[case] input: &str, #[case] expected: Result<&str, Error>) {
        let result = input.parse::<Url>();
        assert_eq!(
            result.as_ref().map(|v| v.to_string()),
            expected.as_ref().map(|v| v.to_string())
        );

        if let Ok(url) = result {
            assert_eq!(url.as_str(), input);
        }
    }

    /// Parses valid source URLs, compares them with the expected structure and checks that
    /// displaying them yields `expected_to_string`, or the input itself when that is `None`.
    #[rstest]
    #[case(
        "git+https://example/project#tag=v1.0.0?signed",
        Some("git+https://example/project?signed#tag=v1.0.0"),
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Git {
                fragment: Some(GitFragment::Tag("v1.0.0".to_string())),
                signed: true
            })
        }
    )]
    #[case(
        "git+https://example/project?signed#tag=v1.0.0",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Git {
                fragment: Some(GitFragment::Tag("v1.0.0".to_string())),
                signed: true
            })
        }
    )]
    #[case(
        "git://example/project#commit=a51720b",
        None,
        SourceUrl {
            url: Url::from_str("git://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Git {
                fragment: Some(GitFragment::Commit("a51720b".to_string())),
                signed: false
            })
        }
    )]
    #[case(
        "svn+https://example/project#revision=a51720b",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Svn {
                fragment: Some(SvnFragment::Revision("a51720b".to_string())),
            })
        }
    )]
    #[case(
        "bzr+https://example/project#revision=a51720b",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Bzr {
                fragment: Some(BzrFragment::Revision("a51720b".to_string())),
            })
        }
    )]
    #[case(
        "hg+https://example/project#branch=feature",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Hg {
                fragment: Some(HgFragment::Branch("feature".to_string())),
            })
        }
    )]
    #[case(
        "fossil+https://example/project#branch=feature",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Fossil {
                fragment: Some(FossilFragment::Branch("feature".to_string())),
            })
        }
    )]
    #[case(
        "https://example/project#branch=feature?signed",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project#branch=feature?signed").unwrap(),
            vcs_info: None,
        }
    )]
    // An `svn://` scheme is detected as svn without a `svn+` prefix.
    #[case(
        "svn://example/project#revision=42",
        None,
        SourceUrl {
            url: Url::from_str("svn://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Svn {
                fragment: Some(SvnFragment::Revision("42".to_string())),
            })
        }
    )]
    // A redundant `git+` prefix in front of a `git://` URL is dropped when displayed.
    #[case(
        "git+git://example/project",
        Some("git://example/project"),
        SourceUrl {
            url: Url::from_str("git://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Git {
                fragment: None,
                signed: false
            })
        }
    )]
    #[case(
        "git+https://example/project?signed",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Git {
                fragment: None,
                signed: true
            })
        }
    )]
    #[case(
        "hg+https://example/project",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Hg { fragment: None })
        }
    )]
    #[case(
        "fossil+https://example/project#tag=v1.0.0",
        None,
        SourceUrl {
            url: Url::from_str("https://example/project").unwrap(),
            vcs_info: Some(VcsInfo::Fossil {
                fragment: Some(FossilFragment::Tag("v1.0.0".to_string())),
            })
        }
    )]
    fn test_source_url_parsing_success(
        #[case] input: &str,
        #[case] expected_to_string: Option<&str>,
        #[case] expected: SourceUrl,
    ) -> TestResult {
        let source_url = SourceUrl::from_str(input)?;
        assert_eq!(
            source_url, expected,
            "Parsed source_url should resemble the expected output."
        );

        let expected_to_string = expected_to_string.unwrap_or(input);
        assert_eq!(
            source_url.to_string(),
            expected_to_string,
            "Parsed and displayed source_url should resemble original."
        );

        Ok(())
    }

    /// Checks that invalid source URLs are rejected with a descriptive error message.
    #[rstest]
    #[case(
        "git+https://example/project#revision=v1.0.0?signed",
        "invalid git revision type\nexpected branch, commit or tag keyword"
    )]
    #[case(
        "git+https://example/project#branch=feature#branch=feature",
        "invalid unexpected trailing content in URL."
    )]
    #[case(
        "git+https://example/project#branch=feature?signed?signed",
        "invalid or duplicate query parameter for detected VCS."
    )]
    #[case(
        "bzr+https://example/project#branch=feature",
        "invalid bzr revision type\nexpected revision keyword"
    )]
    #[case(
        "svn+https://example/project#branch=feature",
        "invalid svn revision type\nexpected revision keyword"
    )]
    #[case(
        "hg+https://example/project#commit=154021a",
        "invalid hg revision type\nexpected branch, revision or tag keyword"
    )]
    #[case(
        "hg+https://example/project#branch=feature?signed",
        "invalid or duplicate query parameter for detected VCS."
    )]
    #[case(
        "fossil+https://example/project#revision=1",
        "invalid fossil revision type\nexpected branch, commit or tag keyword"
    )]
    #[case(
        "git+https://example/project#branch",
        "invalid fragment separator"
    )]
    #[case(
        "bzr+https://example/project?foo=bar",
        "invalid or duplicate query parameter for detected VCS."
    )]
    fn test_source_url_parsing_failure(#[case] input: &str, #[case] error_snippet: &str) {
        let result = SourceUrl::from_str(input);
        assert!(result.is_err(), "Invalid source_url should fail to parse.");
        let err = result.unwrap_err();
        let pretty_error = err.to_string();
        assert!(
            pretty_error.contains(error_snippet),
            "Error:\n=====\n{pretty_error}\n=====\nshould contain snippet:\n\n{error_snippet}"
        );
    }
}