fzgregor

Reputation: 1897

Implementing a generic length delimited hex deserializer in Serde

I want to use Serde to take as much responsibility as possible for validating well-formatted, user-submitted input. My input contains a number of fields whose values must be hex strings of specific, differing lengths.

How can I use Serde to enforce the allowed character set and the individual field's length without repetitive code?
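For illustration, the kind of struct I would like to end up with (the Hex8/Hex64 type names are made up; defining them without boilerplate is exactly the question):

#[derive(Deserialize)]
struct Transaction {
    tx_id: Hex64,    // must be exactly 64 hex characters
    short_id: Hex8,  // must be exactly 8 hex characters
}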

So far I have tried a couple of different approaches, all of which involve implementing custom deserializers. Please let me know if there is a simpler solution altogether.

A macro

A macro HexString!($name:ident, $length:expr) that produces two structs: Name, holding the resulting string, and NameVisitor, implementing the Serde deserialization Visitor.

extern crate serde;
extern crate serde_json;

#[macro_use]
extern crate serde_derive;

#[macro_use]
extern crate error_chain;

error_chain!{}

macro_rules! HexString {
    ($name:ident, $length:expr) => {
        #[derive(Debug, Serialize)]
        pub struct $name(String);

        impl<'de> serde::de::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
            where
                D: serde::de::Deserializer<'de>,
            {
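                // The sticking point: `$nameVisitor` is parsed as a single
                // metavariable named `nameVisitor`, which this macro never
                // binds, so the macro does not compile.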
                deserializer.deserialize_str($nameVisitor)
            }
        }

        struct $nameVisitor;

        impl<'de> serde::de::Visitor<'de> for $nameVisitor {
            type Value = $name;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(
                    formatter,
                    "an string of exactly {} hexadecimal characters",
                    $length
                )
            }

            fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                use serde::de;
                if s.len() != $length {
                    return Err(de::Error::invalid_value(
                        de::Unexpected::Other(&format!(
                            "String is not {} characters long",
                            $length
                        )),
                        &self,
                    ));
                }
                for c in s.chars() {
                    if !c.is_ascii_hexdigit() {
                        return Err(de::Error::invalid_value(de::Unexpected::Char(c), &self));
                    }
                }

                let mut s = s.to_owned();
                s.make_ascii_uppercase();
                Ok($name(s))
            }
        }
    };
}

HexString!(Sha256, 32);

fn main() {
    let h: Sha256 = serde_json::from_str(r#""a412""#).unwrap(); // should fail: wrong length
}

Playground

This failed because macro_rules! has no way to concatenate $name and Visitor into a new identifier in the pattern; $nameVisitor is instead parsed as a single, unbound metavariable.
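One workaround, sketched here rather than taken from the original attempt, is to sidestep concatenation entirely and let the macro accept the visitor's name as a second identifier (the paste crate would be another option):

macro_rules! HexString {
    // Take the visitor's name as an explicit second identifier instead of
    // trying to build it out of `$name`.
    ($name:ident, $visitor:ident, $length:expr) => {
        #[derive(Debug, Serialize)]
        pub struct $name(String);

        struct $visitor;

        impl<'de> serde::de::Visitor<'de> for $visitor {
            type Value = $name;

            fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
                write!(f, "a string of exactly {} hexadecimal characters", $length)
            }

            fn visit_str<E: serde::de::Error>(self, s: &str) -> std::result::Result<$name, E> {
                if s.len() != $length || !s.chars().all(|c| c.is_ascii_hexdigit()) {
                    return Err(E::invalid_value(serde::de::Unexpected::Str(s), &self));
                }
                Ok($name(s.to_ascii_uppercase()))
            }
        }

        impl<'de> serde::de::Deserialize<'de> for $name {
            fn deserialize<D>(deserializer: D) -> std::result::Result<$name, D::Error>
            where
                D: serde::de::Deserializer<'de>,
            {
                deserializer.deserialize_str($visitor)
            }
        }
    };
}

HexString!(Sha256, Sha256Visitor, 32);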

A trait

A trait HexString together with a HexStringVisitor trait, potentially combined with a macro in the end to ease usage:

extern crate serde;
extern crate serde_json;

#[macro_use]
extern crate serde_derive;

#[macro_use]
extern crate error_chain;

error_chain!{}

trait HexString {
    type T: HexString;
    fn init(s: String) -> Self::T;
    fn len() -> usize;
    fn visitor() -> HexStringVisitor<T=Self::T>;
}

impl<'de, T: HexString> serde::de::Deserialize<'de> for T {
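    // This blanket impl of serde's (foreign) `Deserialize` trait for every
    // `T: HexString` is the part the compiler rejects; see below.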
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where D: serde::de::Deserializer<'de>
    {
        deserializer.deserialize_str(T::visitor())
    }
}

trait HexStringVisitor {
    type T: HexString;
}

impl<'de, T: HexStringVisitor> serde::de::Visitor<'de> for T {
    type Value = T::T;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(formatter, "a string of exactly {} hexadecimal characters", Self::Value::len())
    }

    fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
        where E: serde::de::Error
    {
        use serde::de;
        if s.len() != Self::Value::len() {
            return Err(de::Error::invalid_value(
                de::Unexpected::Other(&format!(
                    "String is not {} characters long",
                    Self::Value::len()
                )),
                &self,
            ));
        }
        for c in s.chars() {
            if !c.is_ascii_hexdigit() {
                return Err(de::Error::invalid_value(de::Unexpected::Char(c), &self));
            }
        }

        let mut s = s.to_owned();
        s.make_ascii_uppercase();
        Ok(T::init(s))
    }
}

struct Sha256(String);
struct Sha256Visitor;

impl HexString for Sha256 {
    type T=Sha256;
    fn init(s: String) -> Sha256 {
        Sha256(s)
    }
    fn len() -> usize {
        32
    }
    fn visitor() -> Sha256Visitor {
        Sha256Visitor
    }
}

impl HexStringVisitor for Sha256Visitor {
    type T = Sha256;
}

fn main() {
    let h: Sha256 = serde_json::from_str(r#""a412""#).unwrap(); // should fail: wrong length
}

Playground

This fails because Rust's coherence rules do not allow implementing the foreign trait Deserialize for every implementer of HexString.
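The underlying rule is coherence: a foreign trait such as Deserialize may only be implemented for a local type, and a bare type parameter is not one. A minimal sketch of the shape that is allowed (names are illustrative, not from the code above):

trait HasLen {
    fn len() -> usize;
}

// Rejected: `T` is a bare type parameter, not a local type.
// impl<'de, T: HasLen> serde::de::Deserialize<'de> for T { ... }

// Accepted: `Checked` is defined in this crate, so the impl target is
// local even though it is generic over `T`.
struct Checked<T>(T);

impl<'de, T: HasLen> serde::de::Deserialize<'de> for Checked<T> {
    fn deserialize<D>(_deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        // Validation elided; the point is only that this impl compiles.
        unimplemented!()
    }
}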

Upvotes: 3

Views: 861

Answers (1)

Shepmaster

Reputation: 430290

As Boiethios mentions, this would be more obvious with const generics.
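For illustration, that phrasing would look something like this (hypothetical, since the feature is not available at the time of writing):

// Hypothetical: requires const generics, which have not stabilized yet.
struct Hex<const N: usize>([u8; N]);

type Sha256 = Hex<32>;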

Since that doesn't exist yet, there are two main alternatives: simulate such a feature, or use an array. In this case an array makes sense, because your data is a fixed number of bytes anyway.

I'd then implement Deserialize for a newtype wrapping any type that can be default-constructed and then accessed as a mutable collection of bytes:

extern crate hex;
extern crate serde;
extern crate serde_json;

use serde::de::{Deserialize, Error};

#[derive(Debug)]
struct Hex<B>(B);
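// `B` is any byte container that can be default-constructed and then
// viewed as a mutable byte slice; fixed-size arrays such as `[u8; 32]`
// satisfy both bounds.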

impl<'de, B> serde::de::Deserialize<'de> for Hex<B>
where
    B: AsMut<[u8]> + Default,
{
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        let s = String::deserialize(deserializer)?;
        let mut b = Hex(B::default());
        match hex::decode(s) {
            Ok(v) => {
                let expected_len = b.0.as_mut().len();
                if v.len() != expected_len {
                    Err(D::Error::custom(format_args!(
                        "Expected input of {} bytes, found {}",
                        expected_len,
                        v.len()
                    )))
                } else {
                    b.0.as_mut().copy_from_slice(&v);
                    Ok(b)
                }
            }
            Err(e) => Err(D::Error::custom(format_args!(
                "Unable to deserialize: {}",
                e
            ))),
        }
    }
}

type Sha16 = Hex<[u8; 2]>;
type Sha256 = Hex<[u8; 32]>;

const TWO_BYTES: &str = r#""a412""#;
const THIRTY_TWO_BYTES: &str =
    r#""2CF24DBA5FB0A30E26E83B2AC5B9E29E1B161E5C1FA7425E73043362938B9824""#;

fn main() {
    let h: Result<Sha256, _> = serde_json::from_str(TWO_BYTES);
    println!("{:?}", h);
    let h: Result<Sha16, _> = serde_json::from_str(TWO_BYTES);
    println!("{:?}", h);

    let h: Result<Sha256, _> = serde_json::from_str(THIRTY_TWO_BYTES);
    println!("{:?}", h);
    let h: Result<Sha16, _> = serde_json::from_str(THIRTY_TWO_BYTES);
    println!("{:?}", h);
}

This has two sources of potential inefficiency:

  1. We allocate an empty array and then overwrite the bytes
  2. We allocate a Vec and then copy the bytes from it

There are ways around these, but for the purposes of user input, this is probably reasonable enough.
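For example, hex::decode_to_slice (added in the hex crate's 0.4 release, so it may postdate this answer) addresses the second point by decoding straight into the array and checking the length in the same step. A sketch of the deserialize method rewritten that way, replacing the match above:

fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
    D: serde::de::Deserializer<'de>,
{
    let s = String::deserialize(deserializer)?;
    let mut b = Hex(B::default());
    // Decodes directly into the target slice; bad characters and length
    // mismatches are both reported as errors.
    hex::decode_to_slice(&s, b.0.as_mut())
        .map_err(|e| D::Error::custom(format_args!("Unable to deserialize: {}", e)))?;
    Ok(b)
}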


Upvotes: 1
