alpm_mtree/
path_decoder.rs

1use std::char;
2
3use winnow::{
4    ModalResult,
5    Parser,
6    combinator::{alt, cut_err, fail, preceded},
7    error::{AddContext, ContextError, ErrMode, StrContext, StrContextValue},
8    stream::{Checkpoint, Stream},
9    token::take_while,
10};
11
12/// Decodes UTF-8 characters from a string using MTREE-specific escape sequences.
13///
14/// MTREE uses various decodings.
15/// 1. the VIS_CSTYLE encoding of `strsvis(3)`, which encodes a specific set of characters. Of
16///    these, only the following control characters are allowed in filenames:
17///    - \s Space
18///    - \t Tab
19///    - \r Carriage Return
20///    - \n Line Feed
21/// 2. `#` is encoded as `\#` to differentiate between comments.
22/// 3. For all other chars, octal triplets in the style of `\360\237\214\240` are used. Check
23///    [`unicode_char`] for more info.
24///
25/// # Solution
26///
27/// To effectively decode this pattern we use winnow instead of a handwritten parser, mostly to
28/// have convenient backtracking and error messages in case we encounter invalid escape
29/// sequences or malformed escaped UTF-8.
30pub fn decode_utf8_chars(input: &mut &str) -> ModalResult<String> {
31    // This is the string we'll accumulated the decoded path into.
32    let mut path = String::new();
33
34    loop {
35        // Parse the string until we hit a `\`
36        let part = take_while(0.., |c| c != '\\').parse_next(input)?;
37        path.push_str(part);
38
39        if input.is_empty() {
40            break;
41        }
42
43        // We hit a `\`. See if it's an expected escape sequence.
44        // If none of the expected sequences are encountered, fail and throw an error.
45        let escaped = alt((
46            "\\s".map(|s: &str| s.to_string()),
47            "\\t".map(|s: &str| s.to_string()),
48            "\\r".map(|s: &str| s.to_string()),
49            "\\n".map(|s: &str| s.to_string()),
50            "\\#".map(|s: &str| s.to_string()),
51            unicode_char,
52            fail.context(StrContext::Label("escape sequence"))
53                .context(StrContext::Expected(StrContextValue::Description(
54                    "VIS_CSTYLE encoding or encoded octal triplets for unicode chars.",
55                ))),
56        ))
57        .parse_next(input)?;
58
59        let unescaped = match escaped.as_str() {
60            "\\s" => " ".to_string(),
61            "\\t" => "\t".to_string(),
62            "\\r" => "\r".to_string(),
63            "\\n" => "\n".to_string(),
64            "\\#" => "#".to_string(),
65            _ => escaped,
66        };
67
68        path.push_str(&unescaped);
69    }
70
71    Ok(path)
72}
73
74/// Parse and convert a single octal triplet string into a byte.
75///
76/// This isn't a trivial conversion as an octal has three bits and an octal triplet has thereby 9
77/// bits. The highest bit is expected to be always `0`. This is ensured via the conversion to `u8`,
78/// which would otherwise overflow and throw an error.
79fn octal_triplet(input: &mut &str) -> ModalResult<u8> {
80    preceded('\\', take_while(3, |c: char| c.is_digit(8)))
81        .verify_map(|octals| u8::from_str_radix(octals, 8).ok())
82        .parse_next(input)
83}
84
85/// Parse and decode a unicode char that's encoded as octal triplets.
86///
87/// For example, 🌠 translates to `\360\237\214\240`, which is equivalent to
88/// `0xf0 0x9f 0x8c 0xa0` hex encoding.
89///
90/// Each triplet represents a single UTF-8 byte segment, check [`octal_triplet`] for more details.
91fn unicode_char(input: &mut &str) -> ModalResult<String> {
92    // A unicode char can consist of up to 4 bytes, which is what we use this buffer for.
93    let mut unicode_bytes = Vec::new();
94
95    // Create a checkpoint in case there's an error while decoding the whole
96    // byte sequence in the very end.
97    let checkpoint = input.checkpoint();
98
99    // Parse the first octal triplet into bytes.
100    // If the input isn't an octal triplet, we hit an unknown encoding and return a backtrack error
101    // for a better error message on a higher level.
102    let first = octal_triplet(input)?;
103
104    unicode_bytes.push(first);
105
106    // Get the number of leading ones, which determines the amount of following
107    // bytes in this unicode char. This amount of leading ones can be one of `[0, 2, 3, 4]`.
108    // Other values are forbidden.
109    let leading_ones: usize = first.leading_ones() as usize;
110
111    // If there're no leading ones this char is a single byte UTF-8 char.
112    if leading_ones == 0 {
113        return bytes_to_string(input, checkpoint, unicode_bytes);
114    }
115
116    // Make sure that we didn't get an invalid amount of leading zeroes
117    if leading_ones > 4 || leading_ones == 1 {
118        let mut error = ContextError::new();
119        error = error.add_context(
120            input,
121            &checkpoint,
122            StrContext::Label("amount of leading zeroes in first UTF-8 byte"),
123        );
124        return Err(ErrMode::Cut(error));
125    }
126
127    // Due to the amount of leading ones, we know how many bytes we have to expect.
128    // Parse the amount of expected bytes and throw an error if that didn't work out.
129    for _ in 1..leading_ones {
130        let byte = cut_err(octal_triplet)
131            .context(StrContext::Label("utf8 encoded byte"))
132            .context(StrContext::Expected(StrContextValue::Description(
133                "octal triplet encoded unicode byte.",
134            )))
135            .parse_next(input)?;
136
137        unicode_bytes.push(byte);
138    }
139
140    // Read the bytes to string, which might result in another parser error.
141    bytes_to_string(input, checkpoint, unicode_bytes)
142}
143
144/// Take the UTF-8 byte sequence and parse it into a `String`.
145///
146/// # Errors
147///
148/// Returns a custom parse error if we encounter an invalid escaped UTF-8 sequence.
149fn bytes_to_string(
150    input: &mut &str,
151    checkpoint: Checkpoint<&str, &str>,
152    bytes: Vec<u8>,
153) -> ModalResult<String> {
154    match String::from_utf8(bytes) {
155        Ok(decoded) => Ok(decoded),
156        Err(_) => {
157            let mut error = ContextError::new();
158            error = error.add_context(input, &checkpoint, StrContext::Label("UTF-8 byte sequence"));
159            Err(ErrMode::Cut(error))
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Valid inputs must decode to the expected plain string.
    #[rstest]
    // No escape sequences at all.
    #[case("./plain/path", "./plain/path")]
    #[case(r"hello\sworld", "hello world")]
    #[case(r"\#", "#")]
    #[case(r"\t", "\t")]
    #[case(r"\n", "\n")]
    #[case(r"\r", "\r")]
    #[case(r"\360\237\214\240", "🌠")]
    #[case(
        r"./test\360\237\214\240\342\232\231\302\247\134test\360\237\214\240t\342\232\231e\302\247s\134t",
        "./test🌠⚙§\\test🌠t⚙e§s\\t"
    )]
    fn test_decode_utf8_chars(#[case] input: &str, #[case] expected: &str) {
        // The parser advances the `&str` it is handed, so it needs a mutable binding.
        let mut input = input;
        let result = decode_utf8_chars(&mut input);
        assert_eq!(result, Ok(expected.to_string()));
    }

    /// Malformed inputs must be rejected.
    #[rstest]
    // Unknown escape sequence
    #[case(r"invalid\escape")]
    // First octal triplet will result in u8 int overflow.
    #[case(r"\460\237\214\240")]
    // 4 byte segments are expected, 3 are passed.
    #[case(r"\360\237\214")]
    // 5 leading ones in first byte.
    #[case(r"\370\237\214\240")]
    // 1 leading one in first byte, i.e. a continuation byte without a start byte.
    #[case(r"\237")]
    // 2 byte segments are announced, but the second one isn't a continuation byte.
    #[case(r"\303\050")]
    fn test_decode_utf8_chars_invalid_escape(#[case] input: &str) {
        let mut input = input;
        let result = decode_utf8_chars(&mut input);
        assert!(result.is_err());
    }
}