knipknap
knipknap

Reputation: 6154

Unmarshalling YAML to ordered maps

I am trying to unmarshal the following YAML with Go YAML v3.

model:
  name: mymodel
  default-children:
  - payment

  pipeline:
    accumulator_v1:
      by-type:
        type: static
        value: false
      result-type:
        type: static
        value: 3

    item_v1:
      amount:
        type: schema-path
        value: amount
      start-date:
        type: schema-path
        value: start-date

Under pipeline is an arbitrary number of ordered items. The struct to which this should be unmarshalled looks like this:

// PipelineItemOption is one option of a pipeline item, i.e. a
// type/value pair such as "type: static" / "value: 3" in the sample YAML.
type PipelineItemOption struct {
    // Type names the kind of option ("static", "schema-path", ...).
    Type string
    // Value is the option payload; the sample YAML uses bool, int and string values.
    Value interface{}
}

// PipelineItem holds the options of a single pipeline entry.
type PipelineItem struct {
    // Options maps an option name to its type/value pair.
    Options map[string]PipelineItemOption
}

// Model is the structure the "model" section of the YAML should be
// unmarshalled into.
// NOTE(review): the Pipeline field type is pseudo code and not valid Go; it
// stands for "a map from item name to PipelineItem that keeps the YAML key order".
type Model struct {
        Name string
        DefaultChildren []string `yaml:"default-children"`
        Pipeline orderedmap[string]PipelineItem    // "pseudo code"
}

How does this work with Golang YAML v3? In v2 there was MapSlice, but that is gone in v3.

Upvotes: 7

Views: 2976

Answers (3)

Noel Yap
Noel Yap

Reputation: 19798

Building from @jws's solution and adding recursion:

// Encode serializes obj to a YAML string using an indentation of two spaces.
//
// It returns the encoded text, or an empty string and the error reported by
// the encoder when encoding or closing it fails.
func Encode(obj any) (string, error) {
    var buffer bytes.Buffer
    yamlEncoder := yaml.NewEncoder(&buffer)
    yamlEncoder.SetIndent(2)

    if err := yamlEncoder.Encode(obj); err != nil {
        return "", err
    }

    // Close writes any data the encoder still holds, so the buffer is only
    // read after it has succeeded.
    if err := yamlEncoder.Close(); err != nil {
        return "", err
    }

    return buffer.String(), nil
}

// OrderedMap is a string-keyed map together with the order of its keys.
type OrderedMap struct {
    // Map holds the value stored under each key.
    Map map[string]any
    // Order lists the keys in the sequence they should be visited or emitted.
    Order []string
}

// MarshalYAML returns the mapping node that represents om, with its keys in
// om.Order. The node is built directly instead of being wrapped in a document
// node and unwrapped again.
func (om *OrderedMap) MarshalYAML() (interface{}, error) {
    mapping, err := encodeMap(om)
    if err != nil {
        // Return an untyped nil so the interface result is really nil.
        return nil, err
    }
    return mapping, nil
}

// DecodeDocumentNode decodes a root yaml node into an OrderedMap.
//
// node must be a document node whose root is a mapping. An error is returned
// for a nil node, a non-document node, a document without content, or a
// document whose root is not a mapping (for example a top-level sequence,
// whose elements would otherwise be misread as key/value pairs).
func DecodeDocumentNode(node *yaml.Node) (*OrderedMap, error) {
    if node == nil {
        return nil, fmt.Errorf("node is nil")
    }
    if node.Kind != yaml.DocumentNode {
        return nil, fmt.Errorf("node %v is not a document node", node)
    }
    if len(node.Content) == 0 || node.Content[0] == nil {
        return nil, fmt.Errorf("document node has no content")
    }

    root := node.Content[0]
    if root.Kind != yaml.MappingNode {
        return nil, fmt.Errorf("document root on line %d is not a mapping", root.Line)
    }

    return decodeMap(root)
}

// decode converts a yaml node into a Go value, dispatching on the node tag.
//
// Nulls become nil, strings become string, mappings become *OrderedMap and
// sequences become []any. Any other scalar (for example !!bool, !!int or
// !!float) is decoded by yaml itself into the matching Go value. Nodes that
// are none of the above are rejected with an "unknown node tag" error.
func decode(node *yaml.Node) (any, error) {
    if node == nil {
        return nil, fmt.Errorf("cannot decode a nil node")
    }

    switch node.Tag {
    case "!!null":
        return decodeNull(node)
    case "!!str":
        return decodeStr(node)
    case "!!map":
        return decodeMap(node)
    case "!!seq":
        return decodeSeq(node)
    }

    // Remaining scalars carry no nested keys, so ordering is not a concern
    // and the library's own decoding can be used as is.
    if node.Kind == yaml.ScalarNode {
        var value any
        if err := node.Decode(&value); err != nil {
            return nil, fmt.Errorf("decode error for %s scalar on line %d: %w", node.Tag, node.Line, err)
        }
        return value, nil
    }

    return nil, fmt.Errorf("unknown node tag %s", node.Tag)
}

// decodeNull yields nil for a null node; the node itself is not inspected.
func decodeNull(_ *yaml.Node) (any, error) {
    var nothing any
    return nothing, nil
}

// decodeStr decodes node into a Go string. On failure it returns the empty
// string and an error that mentions the node and the underlying cause.
func decodeStr(node *yaml.Node) (string, error) {
    var text string
    err := node.Decode(&text)
    if err == nil {
        return text, nil
    }
    return "", fmt.Errorf("decode error for %v: %v", node, err)
}

// decodeMap decodes a mapping node into an OrderedMap whose Order lists the
// keys in the order they appear in node.Content.
//
// node must be a non-nil mapping node. Keys must be strings and must be
// unique; values are decoded recursively via decode. All problems found in
// the mapping are collected and reported together in a single error.
func decodeMap(node *yaml.Node) (*OrderedMap, error) {
    // Without this guard a sequence or scalar node would have its children
    // misread as alternating keys and values.
    if node == nil || node.Kind != yaml.MappingNode {
        return nil, fmt.Errorf("invalid yaml; expected a mapping node")
    }

    // Keys already decoded, used to reject duplicates. lo.Map visits the
    // chunks in order, so the closure can safely update this set.
    seen := make(map[string]struct{}, len(node.Content)/2)

    // A mapping node stores key, value, key, value, ... so it is read in pairs.
    keyValuePairs := lo.Map(lo.Chunk(node.Content, 2), func(c []*yaml.Node, _ int) mo.Result[lo.Entry[string, any]] {
        if len(c) != 2 {
            return mo.Err[lo.Entry[string, any]](fmt.Errorf("invalid yaml; expected key/value pair"))
        }

        keyNode := c[0]
        valueNode := c[1]

        if keyNode.Tag != "!!str" {
            return mo.Err[lo.Entry[string, any]](fmt.Errorf("expected a string key but got %s on line %d", keyNode.Tag, keyNode.Line))
        }

        key, err := decodeStr(keyNode)
        if err != nil {
            return mo.Err[lo.Entry[string, any]](fmt.Errorf("key decode error: %v", err))
        }

        // A repeated key would appear twice in Order while Map keeps only
        // the last value, so it is reported instead.
        if _, dup := seen[key]; dup {
            return mo.Err[lo.Entry[string, any]](fmt.Errorf("duplicate key %q on line %d", key, keyNode.Line))
        }
        seen[key] = struct{}{}

        value, err := decode(valueNode)
        if err != nil {
            return mo.Err[lo.Entry[string, any]](fmt.Errorf("value decode error: %v", err))
        }

        return mo.Ok(lo.Entry[string, any]{
            Key:   key,
            Value: value,
        })
    })

    // Split the results into successes (true) and failures (false).
    validErrGroups := lo.GroupBy(keyValuePairs, func(kvp mo.Result[lo.Entry[string, any]]) bool {
        return kvp.IsOk()
    })

    errs := validErrGroups[false]
    if len(errs) != 0 {
        return nil, fmt.Errorf("%v", lo.Map(errs, func(e mo.Result[lo.Entry[string, any]], _ int) error {
            return e.Error()
        }))
    }

    kvps := lo.Map(validErrGroups[true], func(kvp mo.Result[lo.Entry[string, any]], _ int) lo.Entry[string, any] {
        return kvp.MustGet()
    })
    return &OrderedMap{
        Map: lo.FromEntries(kvps),
        Order: lo.Map(kvps, func(kvp lo.Entry[string, any], _ int) string {
            return kvp.Key
        }),
    }, nil
}

// decodeSeq decodes every child of a sequence node and returns the decoded
// values in their original order.
//
// Each element is decoded with decode, so a sequence accepts exactly the
// same kinds of values as a mapping value does (including nulls and nested
// sequences). All element errors are collected and reported together.
func decodeSeq(node *yaml.Node) ([]any, error) {
    seq := lo.Map(node.Content, func(n *yaml.Node, _ int) mo.Result[any] {
        return mo.Try(func() (any, error) {
            // Delegating keeps sequences and mappings in agreement about
            // which node tags are supported.
            return decode(n)
        })
    })

    // Split the results into successes (true) and failures (false).
    validErrGroups := lo.GroupBy(seq, func(kvp mo.Result[any]) bool {
        return kvp.IsOk()
    })

    errs := validErrGroups[false]
    if len(errs) != 0 {
        return nil, fmt.Errorf("%v", lo.Map(errs, func(e mo.Result[any], _ int) error {
            return e.Error()
        }))
    }

    oms := validErrGroups[true]
    return lo.Map(oms, func(om mo.Result[any], _ int) any {
        return om.MustGet()
    }), nil
}

// EncodeDocumentNode turns om into a yaml document node.
//
// The returned node has Kind DocumentNode, is positioned at line 1 / column 1
// and has the encoded mapping as its only child. Errors from encoding the
// mapping are returned unchanged.
func EncodeDocumentNode(om *OrderedMap) (*yaml.Node, error) {
    root, err := encodeMap(om)
    if err != nil {
        return nil, err
    }

    doc := yaml.Node{Kind: yaml.DocumentNode, Line: 1, Column: 1}
    doc.Content = []*yaml.Node{root}
    return &doc, nil
}

// encode converts a Go value into a yaml node.
//
// nil and nil *OrderedMap values become a null node, strings a string node,
// *OrderedMap a mapping node and []any a sequence node. Booleans, integers,
// floats and named string types are encoded by yaml itself. Every other
// value is rejected with an error rather than a panic.
func encode(x any) (*yaml.Node, error) {
    // A type switch only matches the exact types the helpers accept, unlike
    // a reflect.Kind check followed by an unchecked type assertion.
    switch v := x.(type) {
    case nil:
        return encodeNull()
    case string:
        return encodeStr(v)
    case *OrderedMap:
        if v == nil {
            return encodeNull()
        }
        return encodeMap(v)
    case []any:
        return encodeSeq(v)
    }

    switch reflect.ValueOf(x).Kind() {
    case reflect.Bool, reflect.String,
        reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
        reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
        reflect.Float32, reflect.Float64:
        // Plain scalars have no key order to preserve.
        var node yaml.Node
        if err := node.Encode(x); err != nil {
            return nil, fmt.Errorf("unable to encode %v with kind %v: %w", x, reflect.ValueOf(x).Kind(), err)
        }
        return &node, nil
    }

    return nil, fmt.Errorf("unable to encode %v with kind %v", x, reflect.ValueOf(x).Kind())
}

// encodeNull builds a scalar node tagged "!!null" with an empty value.
func encodeNull() (*yaml.Node, error) {
    node := new(yaml.Node)
    node.Kind = yaml.ScalarNode
    node.Tag = "!!null"
    return node, nil
}

// encodeStr builds a scalar node tagged "!!str" whose value is s.
func encodeStr(s string) (*yaml.Node, error) {
    node := new(yaml.Node)
    node.Kind = yaml.ScalarNode
    node.Tag = "!!str"
    node.Value = s
    return node, nil
}

// encodeMap builds a mapping node from om, emitting the entries in the order
// given by om.Order. All encoding errors are collected and reported together
// in one error.
func encodeMap(om *OrderedMap) (*yaml.Node, error) {
    // Every key yields two results, in this order: its key node and the node
    // for the value stored under it.
    encodePair := func(key string, _ int) []mo.Result[*yaml.Node] {
        keyResult := mo.Try(func() (*yaml.Node, error) {
            return encodeStr(key)
        })
        valueResult := mo.Try(func() (*yaml.Node, error) {
            return encode(om.Map[key])
        })
        return []mo.Result[*yaml.Node]{keyResult, valueResult}
    }
    results := lo.FlatMap(om.Order, encodePair)

    // true -> successfully encoded nodes, false -> failures.
    byOutcome := lo.GroupBy(results, func(r mo.Result[*yaml.Node]) bool {
        return r.IsOk()
    })

    if failed := byOutcome[false]; len(failed) != 0 {
        return nil, fmt.Errorf("%v", lo.Map(failed, func(r mo.Result[*yaml.Node], _ int) error {
            return r.Error()
        }))
    }

    children := lo.Map(byOutcome[true], func(r mo.Result[*yaml.Node], _ int) *yaml.Node {
        return r.MustGet()
    })
    return &yaml.Node{
        Kind:    yaml.MappingNode,
        Tag:     "!!map",
        Content: children,
    }, nil
}

// encodeSeq builds a sequence node with one child per element of objs, in
// the same order.
//
// Each element is encoded with encode, so a sequence accepts exactly the
// same kinds of values as a mapping value does (including nil and nested
// sequences). All element errors are collected and reported together.
func encodeSeq(objs []any) (*yaml.Node, error) {
    content := lo.Map(objs, func(obj any, _ int) mo.Result[*yaml.Node] {
        return mo.Try(func() (*yaml.Node, error) {
            // Delegating avoids unchecked type assertions here and keeps
            // sequences and mappings in agreement about supported values.
            return encode(obj)
        })
    })

    // Split the results into successes (true) and failures (false).
    validErrGroups := lo.GroupBy(content, func(kvp mo.Result[*yaml.Node]) bool {
        return kvp.IsOk()
    })

    errs := validErrGroups[false]
    if len(errs) != 0 {
        return nil, fmt.Errorf("%v", lo.Map(errs, func(e mo.Result[*yaml.Node], _ int) error {
            return e.Error()
        }))
    }

    nodes := validErrGroups[true]
    return &yaml.Node{
        Kind: yaml.SequenceNode,
        Tag:  "!!seq",
        Content: lo.Map(nodes, func(c mo.Result[*yaml.Node], _ int) *yaml.Node {
            return c.MustGet()
        }),
    }, nil
}

End-to-end test:

// TestDecodeEncodeE2E parses a YAML document into a yaml.Node, decodes it
// into an OrderedMap and checks that both encoding paths reproduce the
// original text: encoding the OrderedMap directly, and encoding the document
// node built by EncodeDocumentNode.
func TestDecodeEncodeE2E(t *testing.T) {
    // The keys are deliberately not in sorted order, so a plain map would
    // not reproduce this text.
    y := heredoc.Doc(`
        root:
          key-9:
            - value-8
            - key-7:
                key-6: value-6
              key-5: value-5
            - key-4: value-4
          key-3:
            key-2: value-2
          key-1: value-1
    `)

    var documentNode yaml.Node
    err := yaml.Unmarshal([]byte(y), &documentNode)
    require.NoError(t, err)

    decodeActual, decodeErr := DecodeDocumentNode(&documentNode)
    require.NoError(t, decodeErr)

    // Path 1: encode the OrderedMap itself.
    stringifiedOrderedMap, stringifiedOrderedMapErr := Encode(decodeActual)

    assert.NoError(t, stringifiedOrderedMapErr)
    assert.Equal(t, y, stringifiedOrderedMap)

    // Path 2: build a document node first, then encode that node.
    encodeActual, encodeErr := EncodeDocumentNode(decodeActual)
    require.NoError(t, encodeErr)

    // for troubleshooting purposes; commented out because lines and columns don't match
    // assert.Equal(t, &documentNode, encodeActual)

    stringifiedNode, stringifiedNodeErr := Encode(encodeActual)

    assert.NoError(t, stringifiedNodeErr)
    assert.Equal(t, y, stringifiedNode)
}

Upvotes: 1

jws
jws

Reputation: 2764

For me it was a bit of a learning curve to figure out what v3 expects instead of MapSlice. Similar to the answer from @flyx, the yaml.Node tree needs to be walked, in particular the Content slice of each mapping node, which holds the keys and values alternately in document order.

Here is a utility to provide an ordered map[string]interface{} that is a little more reusable and tidy. (Though it is not as constrained as the question specified.)

Per structure above, redefine Pipeline generically:

// Model describes the YAML model; the pipeline is kept as a raw node so that
// the order of its keys can still be read from the node's Content.
type Model struct {
    Name            string
    DefaultChildren []string `yaml:"default-children"`
    // Pipeline is left undecoded; pass it to getOrderedMap to read it in order.
    Pipeline *yaml.Node
}

Use a utility fn to traverse yaml.Node content:

// fragment: meant to run inside a function that returns an error, with data
// holding the YAML document as a []byte.
var model Model
if err := yaml.Unmarshal(data, &model); err != nil {
    return err
}

// Pipeline stays nil when the document has no pipeline entry.
if model.Pipeline == nil {
    return fmt.Errorf("yaml document has no pipeline section")
}

om, err := getOrderedMap(model.Pipeline)
if err != nil {
    return err
}

// Visit the entries in document order.
for _, k := range om.Order {
    v := om.Map[k]
    fmt.Printf("%s=%v\n", k, v)
}

The utility fn:

// OrderedMap couples a string-keyed map with the order of its keys.
type OrderedMap struct {
    // Map holds the decoded value for every key.
    Map map[string]interface{}
    // Order holds the keys in the order they were added.
    Order []string
}

// getOrderedMap decodes the top level of a mapping node into an OrderedMap
// whose Order lists the keys as they appear in node.Content.
//
// Only the top level keeps its order: each value is decoded by yaml into a
// generic Go value. Keys must be strings and must be unique. On any error
// the returned map is nil.
func getOrderedMap(node *yaml.Node) (om *OrderedMap, err error) {
    if node == nil {
        return nil, fmt.Errorf("expected a mapping node but got nil")
    }
    if node.Kind != yaml.MappingNode {
        return nil, fmt.Errorf("expected a mapping node but got kind %v on line %d", node.Kind, node.Line)
    }

    // A mapping node stores key, value, key, value, ... so the content must
    // come in pairs.
    content := node.Content
    if len(content)%2 != 0 {
        return nil, fmt.Errorf("mapping on line %d has a key without a value", node.Line)
    }
    count := len(content) / 2

    result := &OrderedMap{
        Map:   make(map[string]interface{}, count),
        Order: make([]string, 0, count),
    }

    for pos := 0; pos < len(content); pos += 2 {
        keyNode := content[pos]
        valueNode := content[pos+1]

        if keyNode.Tag != "!!str" {
            return nil, fmt.Errorf("expected a string key but got %s on line %d", keyNode.Tag, keyNode.Line)
        }

        var k string
        if err = keyNode.Decode(&k); err != nil {
            return nil, err
        }

        // A repeated key would be listed twice in Order while Map keeps only
        // the last value.
        if _, dup := result.Map[k]; dup {
            return nil, fmt.Errorf("duplicate key %q on line %d", k, keyNode.Line)
        }

        var v interface{}
        if err = valueNode.Decode(&v); err != nil {
            return nil, err
        }

        result.Map[k] = v
        result.Order = append(result.Order, k)
    }

    return result, nil
}

Upvotes: 2

flyx
flyx

Reputation: 39718

You claim that unmarshaling into an intermediate yaml.Node is highly non-generic, but I don't really see why. It looks like this:

package main

import (
    "fmt"
    "gopkg.in/yaml.v3"
)

// PipelineItemOption is a single type/value pair belonging to a pipeline item.
type PipelineItemOption struct {
    // Type names the kind of option, e.g. "static" or "schema-path" in the sample input.
    Type string
    // Value is the option payload; the sample input uses bool, int and string values.
    Value interface{}
}

// PipelineItem is one entry of the pipeline mapping.
type PipelineItem struct {
    // Name is the entry's key in the pipeline mapping; Pipeline.UnmarshalYAML fills it in.
    Name string
    // Options holds the entry's options keyed by option name.
    Options map[string]PipelineItemOption
}

// Pipeline is the list of pipeline items; a slice is used so that the items
// keep the order in which UnmarshalYAML reads them from the mapping.
type Pipeline []PipelineItem

// Model is the decoded form of the "model" section of the input.
type Model struct {
    Name            string
    DefaultChildren []string `yaml:"default-children"`
    // Pipeline is decoded by Pipeline.UnmarshalYAML, which keeps the item order.
    Pipeline Pipeline
}

// UnmarshalYAML fills p from a YAML mapping node, producing one PipelineItem
// per key/value pair in the order the pairs appear in value.Content. The key
// is decoded into the item's Name and the value into its Options.
//
// It returns an error if value is not a mapping or if a key or value cannot
// be decoded; in that case p is left untouched.
func (p *Pipeline) UnmarshalYAML(value *yaml.Node) error {
    if value.Kind != yaml.MappingNode {
        return fmt.Errorf("pipeline must contain YAML mapping, has %v", value.Kind)
    }

    // Decode into a local slice so that a failure half way through does not
    // leave the receiver partially populated.
    items := make(Pipeline, 0, len(value.Content)/2)
    for i := 0; i+1 < len(value.Content); i += 2 {
        var item PipelineItem
        // Decode errors are returned unwrapped.
        // NOTE(review): yaml.v3 appears to treat a *yaml.TypeError returned
        // from UnmarshalYAML specially, which wrapping would hide - confirm
        // before adding context to these errors.
        if err := value.Content[i].Decode(&item.Name); err != nil {
            return err
        }
        if err := value.Content[i+1].Decode(&item.Options); err != nil {
            return err
        }
        items = append(items, item)
    }

    *p = items
    return nil
}


// input is the sample YAML document from the question; main unmarshals it.
// The pipeline entries are listed in the order that must be preserved.
var input []byte = []byte(`
model:
  name: mymodel
  default-children:
  - payment

  pipeline:
    accumulator_v1:
      by-type:
        type: static
        value: false
      result-type:
        type: static
        value: 3

    item_v1:
      amount:
        type: schema-path
        value: amount
      start-date:
        type: schema-path
        value: start-date`)

// main unmarshals input into a wrapper struct whose single Model field
// matches the document's top-level "model" key, then prints the result.
// It panics if unmarshalling fails.
func main() {
    var doc struct {
        Model Model
    }
    if err := yaml.Unmarshal(input, &doc); err != nil {
        panic(err)
    }
    fmt.Printf("%v", doc)
}

Upvotes: 3

Related Questions