0C3D

Deserialize vector of values where one value is enum tag with Serde

I am trying to process some JSON from a WebSocket that is formatted in a way that is challenging for Serde to deserialize. There are many possible responses from this API, but at the top level each one is always an array of at least three values: first some ID, then the data as one or more objects, and finally a string naming the type of object.

Here is some example JSON:

[
  32,
  {
    "speed": 900,
    "altitude": 30000,
    "num_passengers": 200
  },
  "plane"
]


[
  42,
  {
    "num_ingredients": 12,
    "cook_time": "4 mins",
    "oven_temp": 180
  },
  {
    "num_ingredients": 4,
    "cook_time": "25 mins",
    "oven_temp": 250
  },
  "recipe"
]

I would like to be able to deserialize this into an enum.

enum Messages {
    Plane(Plane),
    Recipe(Recipe),
}

In reality, there are far more than two message types (around 20), and I expect to receive a high volume of messages, so I'm concerned about the performance of using an untagged enum. Is there another way to deserialize data with this structure?

Answers (1)

0C3D

Here's the solution I came up with.

Define the target structures as normal:

use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
struct Plane {
    speed: u32,
    altitude: i32,
    num_passengers: u16
}

#[derive(Debug, Serialize, Deserialize)]
struct Recipe {
    num_ingredients: u32,
    cook_time: String,
    oven_temp: u16
}

// A recipe message can carry several Recipe objects
type Recipes = Vec<Recipe>;

Then we need two enums: one that describes the message type and one that holds the data. The data enum has a helper function, new, which creates the matching variant from the message type and the JSON value.

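// EnumDisplay is a third-party derive (not shown here) that implements Display;
// the deserializer's error message below formats MessageType with {}.
#[derive(Debug, Serialize, Deserialize, EnumDisplay)]
// Matches the lowercase tags on the wire: "plane", "recipe"
#[serde(rename_all = "camelCase")]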
enum MessageType {
    Plane,
    Recipe
}

#[derive(Debug, Serialize)]
enum Message {
    Plane(Plane),
    Recipe(Recipes)
}

impl Message {
    // Builds the variant named by message_type from the buffered JSON value
    fn new(message_type: &MessageType, message: serde_json::Value) -> Result<Message, serde_json::Error> {
        Ok(match message_type {
            MessageType::Plane => Self::Plane(serde_json::from_value(message)?),
            MessageType::Recipe => Self::Recipe(serde_json::from_value(message)?),
        })
    }
}
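The point of the separate tag enum is that Message::new runs exactly one conversion, chosen by the tag, whereas #[serde(untagged)] tries each variant in order until one succeeds. Below is a standard-library-only sketch of the same dispatch pattern, assuming a much simplified data model: Object, field, make_object and the trimmed-down Plane and Recipe fields are hypothetical stand-ins for serde_json::Value and serde_json::from_value, and the hand-written Display impl is shown as a dependency-free alternative to the EnumDisplay derive (it prints the wire names, which may differ from what EnumDisplay prints).

```rust
use std::collections::HashMap;
use std::fmt;
use std::str::FromStr;

// Hypothetical stand-in for one JSON object: field name -> integer value.
type Object = HashMap<String, i64>;

#[derive(Debug, PartialEq)]
enum MessageType {
    Plane,
    Recipe,
}

// Parses the tag string, accepting the same names the serde attribute yields.
impl FromStr for MessageType {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "plane" => Ok(MessageType::Plane),
            "recipe" => Ok(MessageType::Recipe),
            other => Err(format!("unknown message type: {other}")),
        }
    }
}

// Dependency-free alternative to the EnumDisplay derive (prints the wire names).
impl fmt::Display for MessageType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            MessageType::Plane => "plane",
            MessageType::Recipe => "recipe",
        })
    }
}

#[derive(Debug, PartialEq)]
struct Plane {
    speed: i64,
    altitude: i64,
}

#[derive(Debug, PartialEq)]
struct Recipe {
    num_ingredients: i64,
}

#[derive(Debug, PartialEq)]
enum Message {
    Plane(Plane),
    Recipe(Vec<Recipe>),
}

// Looks up a required integer field (hypothetical stand-in for serde's field handling).
fn field(obj: &Object, name: &str) -> Result<i64, String> {
    obj.get(name).copied().ok_or_else(|| format!("missing field `{name}`"))
}

impl Message {
    // The tag selects exactly one conversion; nothing is tried and discarded.
    fn new(message_type: &MessageType, objects: &[Object]) -> Result<Message, String> {
        match message_type {
            MessageType::Plane => match objects {
                [obj] => Ok(Message::Plane(Plane {
                    speed: field(obj, "speed")?,
                    altitude: field(obj, "altitude")?,
                })),
                _ => Err(format!("expected exactly one {message_type} object")),
            },
            // Taking a slice covers one recipe and several recipes alike.
            MessageType::Recipe => objects
                .iter()
                .map(|obj| -> Result<Recipe, String> {
                    Ok(Recipe { num_ingredients: field(obj, "num_ingredients")? })
                })
                .collect::<Result<Vec<_>, String>>()
                .map(Message::Recipe),
        }
    }
}

// Hypothetical helper for building objects in examples and tests.
fn make_object(pairs: &[(&str, i64)]) -> Object {
    pairs.iter().map(|(k, v)| (k.to_string(), *v)).collect()
}
```

Because the recipe branch takes a slice of objects rather than a single pre-collapsed value, a message with one recipe and a message with several go through the same path.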

The last piece is the wrapper struct, which holds the message ID and the data. This is the struct that needs the custom deserializer, so don't derive Deserialize for it.

#[derive(Debug, Serialize)]
struct MessageWrapper {
    id: i64,
    message: Message,
}

Now the fun part: the custom Deserialize implementation.

Step by step, here's what it's doing.

  1. Parse the first value (the ID) as an i64
  2. Loop through the remaining values until the message type is found
    1. If the value is a string, it must be the message type, so stop looping
    2. If the value is an object or an array, add it to the list of messages
    3. Anything else is an error
  3. Turn the list of values into a single value
    1. If there is only one value, take it out of the list and use it directly
    2. If there is more than one, wrap them all in a new Array value
  4. Use the new function defined on the Message type to parse the message
  5. Construct and return a new MessageWrapper

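// Deserialize itself comes from the same serde import used for the derives.
use serde::de::{Deserializer, Error as DeError, SeqAccess, Visitor};
use std::fmt::Formatter;

impl<'de> Deserialize<'de> for MessageWrapper {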
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
    {
        struct MessageWrapperVisitor {}

        impl<'de> Visitor<'de> for MessageWrapperVisitor {
            type Value = MessageWrapper;

            fn expecting(&self, formatter: &mut Formatter) -> std::fmt::Result {
                formatter.write_str("[i64, Message..., String]")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
                where
                    A: SeqAccess<'de>,
            {
                let id: i64 = seq
                    .next_element()?
                    .ok_or_else(|| DeError::invalid_length(0, &self))?;

                let message_type: MessageType;
                let mut messages = Vec::with_capacity(1);
                loop {
                    let val: serde_json::Value = seq
                        .next_element()?
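                        // invalid_length reports how many elements were read: the ID plus any payloads so far
                        .ok_or_else(|| DeError::invalid_length(messages.len() + 1, &self))?;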

                    if val.is_string() {
                        message_type = serde_json::from_value(val).map_err(A::Error::custom)?;
                        break;
                    } else if val.is_object() || val.is_array() {
                        messages.push(val);
                    } else {
                        return Err(DeError::custom("unexpected value: expected a message type string or a JSON object/array"));
                    }
                }

                // A single payload is used as-is; several are wrapped in one JSON array
                let message = match messages.len() {
                    0 => return Err(DeError::custom("no data")),
                    1 => messages.remove(0),
                    _ => serde_json::Value::Array(messages),
                };

                let message = Message::new(&message_type, message).map_err(|e| {
                    A::Error::custom(format!(
                        "inner object cannot be deserialized as {} -> {}",
                        message_type, e
                    ))
                })?;
                Ok(MessageWrapper {
                    id,
                    message,
                })
            }
        }

        deserializer.deserialize_seq(MessageWrapperVisitor {})
    }
}
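Stripped of the Serde plumbing, steps 1 to 3 amount to a small splitting function. The sketch below uses only the standard library, with a hypothetical, simplified Value enum standing in for serde_json::Value (and a hypothetical split_frame function), so the control flow can be checked in isolation; it illustrates the visitor's logic rather than replacing it.

```rust
// Hypothetical, simplified stand-in for serde_json::Value.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    Number(i64),
    String(String),
    Object(Vec<(String, Value)>),
    Array(Vec<Value>),
}

/// Hypothetical helper: splits a frame shaped like [id, data..., "type"]
/// into (id, payload, type), following steps 1 to 3 of the visitor above.
fn split_frame(frame: Vec<Value>) -> Result<(i64, Value, String), String> {
    let mut items = frame.into_iter();

    // Step 1: the first element must be the integer ID.
    let id = match items.next() {
        Some(Value::Number(n)) => n,
        _ => return Err("expected an integer ID first".to_string()),
    };

    // Step 2: collect objects and arrays until the string tag appears.
    let mut messages = Vec::with_capacity(1);
    let tag = loop {
        match items.next() {
            Some(Value::String(tag)) => break tag,
            Some(v @ (Value::Object(_) | Value::Array(_))) => messages.push(v),
            Some(_) => return Err("unexpected value: expected a message type or an object".to_string()),
            None => return Err("missing message type".to_string()),
        }
    };

    // Step 3: one payload is used as-is; several are wrapped in one array.
    let payload = match messages.len() {
        0 => return Err("no data".to_string()),
        1 => messages.pop().unwrap(),
        _ => Value::Array(messages),
    };

    Ok((id, payload, tag))
}
```

One edge case worth checking: because step 3 passes a lone object through unwrapped, a recipe message carrying a single recipe would reach Message::new as an object, and serde_json::from_value into Vec<Recipe> rejects an object. If the API can send that, wrapping a lone object in an array for list-typed variants is one option. Also note that since the type string comes last, each payload is still buffered into a Value before its type is known; the saving over an untagged enum comes from running only one typed conversion per message.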
