alpm_mtree/path_decoder.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
use std::char;
use winnow::{
ModalResult,
Parser,
combinator::{alt, cut_err, fail, preceded},
error::{AddContext, ContextError, ErrMode, StrContext, StrContextValue},
stream::{Checkpoint, Stream},
token::take_while,
};
/// Decodes UTF-8 characters from a string using MTREE-specific escape sequences.
///
/// MTREE uses several encodings for paths:
/// 1. The `VIS_CSTYLE` encoding of `strsvis(3)`, which encodes a specific set of characters. Of
///    these, only the following control characters are allowed in filenames:
///    - `\s` Space
///    - `\t` Tab
///    - `\r` Carriage Return
///    - `\n` Line Feed
/// 2. `#` is encoded as `\#` so that it cannot be mistaken for the start of a comment.
/// 3. All other characters are encoded as octal triplets in the style of `\360\237\214\240`.
///    Check [`unicode_char`] for more info.
///
/// On success, the whole `input` is consumed.
///
/// # Solution
///
/// To effectively decode this pattern we use winnow instead of a handwritten parser, mostly to
/// have convenient backtracking and error messages in case we encounter invalid escape
/// sequences or malformed escaped UTF-8.
///
/// # Errors
///
/// Returns an error if `input` contains a `\` that does not start one of the escape sequences
/// listed above, or if a sequence of octal triplets does not form a valid UTF-8 character.
pub fn decode_utf8_chars(input: &mut &str) -> ModalResult<String> {
    // Every escape sequence decodes to fewer bytes than it occupies and plain characters are
    // copied as-is, so the decoded path never outgrows the escaped input.
    let mut path = String::with_capacity(input.len());
    loop {
        // Copy everything up to the next `\` verbatim.
        let part = take_while(0.., |c| c != '\\').parse_next(input)?;
        path.push_str(part);
        if input.is_empty() {
            break;
        }
        // We hit a `\`. Decode the escape sequence directly into its unescaped form.
        // If none of the expected sequences match, fail with a descriptive error.
        let unescaped = alt((
            "\\s".map(|_: &str| " ".to_string()),
            "\\t".map(|_: &str| "\t".to_string()),
            "\\r".map(|_: &str| "\r".to_string()),
            "\\n".map(|_: &str| "\n".to_string()),
            "\\#".map(|_: &str| "#".to_string()),
            unicode_char,
            fail.context(StrContext::Label("escape sequence"))
                .context(StrContext::Expected(StrContextValue::Description(
                    "VIS_CSTYLE encoding or encoded octal triplets for unicode chars.",
                ))),
        ))
        .parse_next(input)?;
        path.push_str(&unescaped);
    }
    Ok(path)
}
/// Parse a single `\`-prefixed octal triplet (e.g. `\360`) and convert it into a byte.
///
/// Every octal digit carries three bits, so a triplet carries nine bits while a byte only has
/// eight. The most significant of the nine bits must therefore be `0` (the first digit must be
/// in `0..=3`). Larger values do not fit into a `u8`; the conversion then fails and the parser
/// backtracks.
fn octal_triplet(input: &mut &str) -> ModalResult<u8> {
    // Exactly three octal digits must follow the backslash.
    let three_octal_digits = take_while(3, |c: char| c.is_digit(8));
    preceded('\\', three_octal_digits)
        .verify_map(|digits| u8::from_str_radix(digits, 8).ok())
        .parse_next(input)
}
/// Parse and decode a unicode char that's encoded as octal triplets.
///
/// For example, 🌠translates to `\360\237\214\240`, which is equivalent to
/// `0xf0 0x9f 0x8c 0xa0` hex encoding.
///
/// Each triplet represents a single UTF-8 byte segment, check [`octal_triplet`] for more details.
fn unicode_char(input: &mut &str) -> ModalResult<String> {
// A unicode char can consist of up to 4 bytes, which is what we use this buffer for.
let mut unicode_bytes = Vec::new();
// Create a checkpoint in case there's an error while decoding the whole
// byte sequence in the very end.
let checkpoint = input.checkpoint();
// Parse the first octal triplet into bytes.
// If the input isn't an octal triplet, we hit an unknown encoding and return a backtrack error
// for a better error message on a higher level.
let first = octal_triplet(input)?;
unicode_bytes.push(first);
// Get the number of leading ones, which determines the amount of following
// bytes in this unicode char. This amount of leading ones can be one of `[0, 2, 3, 4]`.
// Other values are forbidden.
let leading_ones: usize = first.leading_ones() as usize;
// If there're no leading ones this char is a single byte UTF-8 char.
if leading_ones == 0 {
return bytes_to_string(input, checkpoint, unicode_bytes);
}
// Make sure that we didn't get an invalid amount of leading zeroes
if leading_ones > 4 || leading_ones == 1 {
let mut error = ContextError::new();
error = error.add_context(
input,
&checkpoint,
StrContext::Label("amount of leading zeroes in first UTF-8 byte"),
);
return Err(ErrMode::Cut(error));
}
// Due to the amount of leading ones, we know how many bytes we have to expect.
// Parse the amount of expected bytes and throw an error if that didn't work out.
for _ in 1..leading_ones {
let byte = cut_err(octal_triplet)
.context(StrContext::Label("utf8 encoded byte"))
.context(StrContext::Expected(StrContextValue::Description(
"octal triplet encoded unicode byte.",
)))
.parse_next(input)?;
unicode_bytes.push(byte);
}
// Read the bytes to string, which might result in another parser error.
bytes_to_string(input, checkpoint, unicode_bytes)
}
/// Convert a collected UTF-8 byte sequence into a `String`.
///
/// `input` and `checkpoint` are only used to attach parser context to the error.
///
/// # Errors
///
/// Returns a cut error labelled `UTF-8 byte sequence` if `bytes` is not valid UTF-8.
fn bytes_to_string(
    input: &mut &str,
    checkpoint: Checkpoint<&str, &str>,
    bytes: Vec<u8>,
) -> ModalResult<String> {
    String::from_utf8(bytes).map_err(|_invalid_utf8| {
        let label = StrContext::Label("UTF-8 byte sequence");
        ErrMode::Cut(ContextError::new().add_context(input, &checkpoint, label))
    })
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case("", "")]
    #[case(r"hello\sworld", "hello world")]
    #[case(r"\#", "#")]
    #[case(r"\t", "\t")]
    #[case(r"\n", "\n")]
    #[case(r"\r", "\r")]
    // Single byte char encoded as an octal triplet (backslash).
    #[case(r"\134", "\\")]
    #[case(r"\360\237\214\240", "🌠")]
    #[case(
        r"./test\360\237\214\240\342\232\231\302\247\134test\360\237\214\240t\342\232\231e\302\247s\134t",
        "./test🌠⚙§\\test🌠t⚙e§s\\t"
    )]
    fn test_decode_utf8_chars(#[case] input: &str, #[case] expected: &str) {
        let mut remaining = input;
        let result = decode_utf8_chars(&mut remaining);
        assert_eq!(result, Ok(expected.to_string()));
    }

    #[rstest]
    // Unknown escape sequence.
    #[case(r"invalid\escape")]
    // Trailing lone backslash.
    #[case("trailing\\")]
    // First octal triplet will result in u8 int overflow.
    #[case(r"\460\237\214\240")]
    // 4 byte segments are expected, 3 are passed.
    #[case(r"\360\237\214")]
    // 5 leading ones in first byte.
    #[case(r"\370\237\214\240")]
    // 1 leading one in first byte, i.e. a lone continuation byte.
    #[case(r"\237\214")]
    // The last announced byte is not a valid UTF-8 continuation byte.
    #[case(r"\360\237\214\100")]
    fn test_decode_utf8_chars_invalid_escape(#[case] input: &str) {
        let mut remaining = input;
        let result = decode_utf8_chars(&mut remaining);
        assert!(result.is_err());
    }
}