MengCheng Wei

Reputation: 727

How do I mock a function that writes its result to its argument in Go?

I am writing unit tests in Go with https://github.com/stretchr/testify. Suppose I have the method below:

func DoSomething(result interface{}) error {
    // write some data to result
    return nil
}

so the caller can call DoSomething as follows:

result := &SomeStruct{}
err := DoSomething(result)

if err != nil {
  fmt.Println(err)
} else {
  fmt.Println("The result is", result)
}
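For context, one common way such a function fills result is to decode into it, the way json.Unmarshal does: the caller's pointer travels inside the interface{} value, and every write goes through it. A minimal sketch, assuming a hypothetical SomeStruct with a Name field and a hard-coded JSON payload standing in for the real data source:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// SomeStruct is a hypothetical result type used only for illustration.
type SomeStruct struct {
	Name string `json:"name"`
}

// DoSomething writes into whatever result points to. The hard-coded
// payload is an assumption standing in for a database, cache, HTTP call, etc.
func DoSomething(result interface{}) error {
	payload := []byte(`{"name":"from DoSomething"}`)
	// json.Unmarshal requires a non-nil pointer and writes through it.
	return json.Unmarshal(payload, result)
}

func main() {
	result := &SomeStruct{}
	if err := DoSomething(result); err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println("The result is", result.Name)
}
```

Passing a struct value instead of a pointer makes json.Unmarshal return a *json.InvalidUnmarshalError, which is why the caller passes &SomeStruct{}; a mock has to honour the same contract by writing through that pointer.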

I know how to use testify or other mocking tools to mock the return value (err here) with something like:

mockObj.On("DoSomething", mock.Anything).Return(errors.New("mock error"))

My question is: how do I mock the result argument in this kind of scenario?

result is not a return value but an argument: the caller passes a pointer to a struct, and the function modifies it.

Upvotes: 25

Views: 25150

Answers (4)

CodeFarmer

Reputation: 2708

If it's a pointer-type argument (e.g., a Redis client that reads an item into that address), you can try this:

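.Run(func(args mock.Arguments) {
    // args holds the mocked method's arguments in order, so Get(3) is the
    // fourth one: the pointer the real implementation would write into.
    arg := args.Get(3).(*YOURTYPE)
    *arg = expect // expect is the value the caller should end up with
})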
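To make the index explicit: the Run handler receives every argument the mocked method was called with, in order, so args.Get(3) is the fourth parameter. The sketch below imitates that with a hypothetical Arguments slice and a hypothetical four-parameter Fetch method; it is not testify's code, only the same idea (the handler runs before the canned return value is handed back):

```go
package main

import "fmt"

// Arguments is a stand-in for testify's mock.Arguments: the call's
// arguments in declaration order.
type Arguments []interface{}

// Get returns the i-th argument (0-based).
func (a Arguments) Get(i int) interface{} { return a[i] }

type Item struct{ Value string }

// fakeCache is a hypothetical mock whose Fetch has four parameters;
// dest (index 3) is the pointer a real client would fill in.
type fakeCache struct {
	run func(args Arguments) // optional handler, like (*Call).Run
	ret error                // canned return value, like Return(...)
}

func (f *fakeCache) Fetch(db int, key string, timeoutMs int, dest interface{}) error {
	if f.run != nil {
		f.run(Arguments{db, key, timeoutMs, dest}) // handler runs before returning
	}
	return f.ret
}

func main() {
	expect := Item{Value: "cached"}
	c := &fakeCache{run: func(args Arguments) {
		arg := args.Get(3).(*Item) // fourth argument: dest
		*arg = expect
	}}
	var got Item
	err := c.Fetch(0, "k", 100, &got)
	fmt.Println(got.Value, err)
}
```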

Upvotes: 0

bikbah

Reputation: 409

You can use the (*Call).Run method:

Run sets a handler to be called before returning. It can be used when mocking a method (such as an unmarshaler) that takes a pointer to a struct and sets properties in such struct

Example:

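mockObj.On("Unmarshal", mock.AnythingOfType("*map[string]interface {}")).Return().Run(func(args mock.Arguments) {
    // args.Get(0) is the pointer the caller passed in; write through it.
    arg := args.Get(0).(*map[string]interface{})
    if *arg == nil {
        *arg = map[string]interface{}{} // writing to a nil map would panic
    }
    (*arg)["foo"] = "bar" // a pointer to a map must be dereferenced before indexing
})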

Upvotes: 29

Michal K

Reputation: 21

I highly recommend getting familiar with the gomock framework and developing against interfaces. What you need would look something like this:

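// SetArg(0, v) writes v through the pointer passed as argument 0,
// so v is a value of the pointed-to type, not a pointer.
myObj.EXPECT().DoSomething(gomock.Any()).SetArg(0, <value you want to return>).Return(nil)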
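gomock's SetArg writes the given value through the pointer found at argument position n, so the value should be of the pointed-to type (a SomeStruct, not a *SomeStruct). The reflection step can be sketched on its own; setArg below is a hypothetical helper covering only the pointer case, not gomock's source, and it returns an error where a direct reflect Set with mismatched types would panic:

```go
package main

import (
	"fmt"
	"reflect"
)

type SomeStruct struct{ Name string }

// setArg writes value into the variable that args[n] points to.
// It returns an error instead of panicking when the types don't line up.
func setArg(args []interface{}, n int, value interface{}) error {
	dst := reflect.ValueOf(args[n])
	if dst.Kind() != reflect.Ptr || dst.IsNil() {
		return fmt.Errorf("argument %d is %T, not a non-nil pointer", n, args[n])
	}
	v := reflect.ValueOf(value)
	if !v.IsValid() || !v.Type().AssignableTo(dst.Elem().Type()) {
		return fmt.Errorf("cannot assign %T to %s", value, dst.Elem().Type())
	}
	dst.Elem().Set(v) // the pointee of a pointer is settable
	return nil
}

func main() {
	result := &SomeStruct{}
	args := []interface{}{result} // what DoSomething(result) would receive
	err := setArg(args, 0, SomeStruct{Name: "mocked"})
	fmt.Println(err, result.Name)
	// A *SomeStruct is the wrong type for the pointee and is rejected.
	fmt.Println(setArg(args, 0, &SomeStruct{Name: "oops"}))
}
```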

Upvotes: 0

Lin Du

Reputation: 102327

As @bikbah said, here is an example:

services/message.go:

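type messageService struct {
    // HttpClient is the project's own HTTP client abstraction (it has
    // Get/Post methods and mocks.MockedHttp satisfies it), not net/http's Client.
    HttpClient http.Client
    BaseURL    string
}

func (m *messageService) MarkAllMessages(accesstoken string) []*model.MarkedMessage {
    endpoint := m.BaseURL + "/message/mark_all"
    var res model.MarkAllMessagesResponse
    // Post decodes the response body into res through the &res pointer.
    if err := m.HttpClient.Post(endpoint, &MarkAllMessagesRequestPayload{Accesstoken: accesstoken}, &res); err != nil {
        fmt.Println(err)
        return res.MarkedMsgs
    }
    return res.MarkedMsgs
}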

We pass &res to the m.HttpClient.Post method, which populates res using json.Unmarshal.
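The real Post is not shown in the answer. As a hedged illustration of "populated with json.Unmarshal", here is a hypothetical cannedHTTP client whose Post decodes a fixed JSON body into data; the struct names mirror the answer's models, but their JSON tags are assumptions. It also hints at an alternative to Run: a fake that performs the same decoding the real client would.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Hypothetical shapes mirroring model.MarkedMessage / MarkAllMessagesResponse;
// the JSON tags are assumptions.
type MarkedMessage struct {
	ID string `json:"id"`
}

type MarkAllMessagesResponse struct {
	MarkedMsgs []*MarkedMessage `json:"marked_msgs"`
}

// cannedHTTP is a hypothetical fake client: Post ignores url and body and
// decodes a fixed response into data, as a real client would decode the
// HTTP response body.
type cannedHTTP struct {
	response string
}

func (c cannedHTTP) Post(url string, body interface{}, data interface{}) error {
	return json.Unmarshal([]byte(c.response), data)
}

func main() {
	client := cannedHTTP{response: `{"marked_msgs":[{"id":"1"},{"id":"2"}]}`}
	var res MarkAllMessagesResponse
	if err := client.Post("http://localhost/api/v1/message/mark_all", nil, &res); err != nil {
		fmt.Println(err)
		return
	}
	for _, m := range res.MarkedMsgs {
		fmt.Println(m.ID)
	}
}
```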

mocks/http.go:

package mocks

import (
    "github.com/stretchr/testify/mock"
)

type MockedHttp struct {
    mock.Mock
}

func (m *MockedHttp) Get(url string, data interface{}) error {
    args := m.Called(url, data)
    return args.Error(0)
}

func (m *MockedHttp) Post(url string, body interface{}, data interface{}) error {
    args := m.Called(url, body, data)
    return args.Error(0)
}

services/message_test.go:

package services_test

import (
    "errors"
    "reflect"
    "strconv"
    "testing"

    "github.com/stretchr/testify/mock"
    "github.com/mrdulin/gqlgen-cnode/graph/model"
    "github.com/mrdulin/gqlgen-cnode/services"
    "github.com/mrdulin/gqlgen-cnode/mocks"
)

const (
    baseURL     string = "http://localhost/api/v1"
    accesstoken string = "123"
)

func TestMessageService_MarkAllMessages(t *testing.T) {
    t.Run("should mark all messaages", func(t *testing.T) {
        testHttp := new(mocks.MockedHttp)
        var res model.MarkAllMessagesResponse
        var markedMsgs []*model.MarkedMessage
        for i := 1; i <= 3; i++ {
            markedMsgs = append(markedMsgs, &model.MarkedMessage{ID: strconv.Itoa(i)})
        }
        postBody := services.MarkAllMessagesRequestPayload{Accesstoken: accesstoken}
        testHttp.On("Post", baseURL+"/message/mark_all", &postBody, &res).Return(nil).Run(func(args mock.Arguments) {
            arg := args.Get(2).(*model.MarkAllMessagesResponse)
            arg.MarkedMsgs = markedMsgs
        })
        service := services.NewMessageService(testHttp, baseURL)
        got := service.MarkAllMessages(accesstoken)
        want := markedMsgs
        testHttp.AssertExpectations(t)
        if !reflect.DeepEqual(got, want) {
            t.Errorf("got wrong return value. got: %#v, want: %#v", got, want)
        }
    })

    t.Run("should print error and return empty slice", func(t *testing.T) {
        var res model.MarkAllMessagesResponse
        testHttp := new(mocks.MockedHttp)
        postBody := services.MarkAllMessagesRequestPayload{Accesstoken: accesstoken}
        testHttp.On("Post", baseURL+"/message/mark_all", &postBody, &res).Return(errors.New("network"))
        service := services.NewMessageService(testHttp, baseURL)
        got := service.MarkAllMessages(accesstoken)
        var want []*model.MarkedMessage
        testHttp.AssertExpectations(t)
        if !reflect.DeepEqual(got, want) {
            t.Errorf("got wrong return value. got: %#v, want: %#v", got, want)
        }
    })
}

In the first test case, we populate res in the (*Call).Run handler; service.MarkAllMessages(accesstoken) then returns res.MarkedMsgs, which we assign to got and compare with want.

Unit test result and coverage:

=== RUN   TestMessageService_MarkAllMessages
--- PASS: TestMessageService_MarkAllMessages (0.00s)
=== RUN   TestMessageService_MarkAllMessages/should_mark_all_messaages
    TestMessageService_MarkAllMessages/should_mark_all_messaages: message_test.go:39: PASS: Post(string,*services.MarkAllMessagesRequestPayload,*model.MarkAllMessagesResponse)
    --- PASS: TestMessageService_MarkAllMessages/should_mark_all_messaages (0.00s)
=== RUN   TestMessageService_MarkAllMessages/should_print_error_and_return_empty_slice
network
    TestMessageService_MarkAllMessages/should_print_error_and_return_empty_slice: message_test.go:53: PASS: Post(string,*services.MarkAllMessagesRequestPayload,*model.MarkAllMessagesResponse)
    --- PASS: TestMessageService_MarkAllMessages/should_print_error_and_return_empty_slice (0.00s)
PASS
coverage: 5.6% of statements in ../../gqlgen-cnode/...

Process finished with exit code 0


Upvotes: 5
