Nabeel Qaiser
Nabeel Qaiser

Reputation: 11

Hibernate's Envers equivalent for golang (GORM) for auditing

We want to track the history of operations (audits) on all the entries of various tables in the database (PostgreSQL). The programming language is Go, with GORM ( https://gorm.io/ ) as the ORM.

Basically, we want to capture the CREATE, UPDATE and DELETE operations on these tables in a separate table (say, audit_logs).

The desired schema of the audit_logs table is reflected by the AuditLog struct shown below.

Hibernate Envers provides exactly this functionality ( https://www.baeldung.com/database-auditing-jpa#hibernate ).

However, GORM for Go offers no such built-in functionality.

Has anyone run into such a requirement in Go? If so, what approach did you take, or what solution did you find or build? Any leads would be helpful.

We've tried registering callbacks (hooks) on the gorm.DB object.

Say the model for the audit log is:

import (
    "reflect"
    "time"

    "gorm.io/gorm"
)

// AuditLog is one row of the audit_logs table. The callback below fills it as:
// EntityClass = table name, EntityID = primary key of the affected record,
// Action = "CREATE"/"UPDATE"/"DELETE", ActionTaker = acting user,
// PrevData/CurData = JSON dumps of the record before/after the change,
// At = time the entry was written.
//
// NOTE(review): GORM cannot derive a column type for interface{} fields, so
// AutoMigrate is expected to fail on PrevData/CurData; a concrete type such as
// string (JSON text) or []byte is likely needed — confirm and change.
type AuditLog struct {
    EntityClass  string `gorm:"column:entity_class"`
    EntityID     string `gorm:"column:entity_id"`
    Action       string `gorm:"column:action"`
    ActionTaker  string `gorm:"column:action_taker"`
    PrevData     interface{} `gorm:"column:prev_data"`
    CurData      interface{} `gorm:"column:cur_data"`
    At           time.Time `gorm:"column:at"`
}

Then, in the database initialisation function, we registered the callbacks:

    // Create the audit_logs table if it does not exist yet.
    // NOTE(review): the AutoMigrate error is discarded; a failed migration only
    // surfaces later, when the callbacks try to insert audit rows.
    _ = db.AutoMigrate(&AuditLog{})  // creating the table if not exists

    // Hook the audit callbacks in after GORM's own create/update/delete steps,
    // so they run once the corresponding SQL has been executed.
    // NOTE(review): Register returns an error, which is discarded here.
    _ = db.Callback().Create().After("gorm:create").Register("audit_log:create", auditLogCreateCallback)
    _ = db.Callback().Update().After("gorm:update").Register("audit_log:update", auditLogUpdateCallback)
    _ = db.Callback().Delete().After("gorm:delete").Register("audit_log:delete", auditLogDeleteCallback)

and defined the respective callback functions, e.g.:

// auditLogCreateCallback is a GORM create callback (registered after
// "gorm:create") that records an AuditLog entry for each successful insert.
//
// TODO: EntityID, PrevData and CurData are still placeholders. They have to be
// extracted from db.Statement, but db.Statement.Dest and db.Statement.Model
// hold data in different forms depending on how the statement was issued
// (a model struct, a map, a slice for batch inserts, or a conditional update
// such as "update product discount where category = 'xyz'").
func auditLogCreateCallback(db *gorm.DB) {
	// Do not audit inserts that failed.
	if db.Error != nil {
		return
	}

	entityName := db.Statement.Table

	// Skip inserts into the audit table itself. If auditDB shares this
	// callback chain, auditDB.Create below would otherwise re-enter this
	// callback indefinitely. "audit_logs" is GORM's default table name for
	// AuditLog; adjust if a custom naming strategy or TableName() is used.
	if entityName == "audit_logs" {
		return
	}

	// A missing or non-string value yields an empty ActionTaker.
	username, _ := db.Statement.Context.Value(constant.UserId).(string)

	auditLog := AuditLog{
		EntityClass: entityName,
		EntityID:    "<id here>",
		Action:      "CREATE",
		ActionTaker: username,
		PrevData:    "<json dump of prev data>",
		CurData:     "<json dump of cur data>",
		At:          time.Now(),
	}

	// Reuse the request context for the audit insert, and report failures
	// instead of silently losing the audit record.
	if err := auditDB.WithContext(db.Statement.Context).Create(&auditLog).Error; err != nil {
		db.Logger.Error(db.Statement.Context, "audit log: creating entry for table %s: %v", entityName, err)
	}
}

Upvotes: 1

Views: 720

Answers (1)

I have implemented an audit log using GORM callback functions. When I put the audit_logs table in the same database it did not work, but when I put the audit_logs table in a separate database it worked fine.

Here are the code samples.

schema

// AuditLog is one audit-trail record written by createAuditLogEntry.
//
// It embeds gorm.Model (the original `gorm.Modal` does not exist and does not
// compile), which supplies the ID primary key and the
// CreatedAt/UpdatedAt/DeletedAt columns.
type AuditLog struct {
	gorm.Model

	UserID      uuid.UUID `json:"user_id"`
	UserEmail   string    `json:"user_email"`
	UserName    string    `json:"user_name"`
	PhcID       uuid.UUID `json:"phc_id"`
	Role        string    `json:"role"`
	EventName   string    `json:"event_name"`   // e.g. "Create Event", "Update Event"
	Action      string    `json:"action"`       // e.g. "CREATE", "UPDATE"
	EntityID    string    `json:"entity_id"`    // primary key of the audited row, formatted with %v
	EntityModel string    `json:"entity_model"` // table name of the audited row
	OldData     string    `json:"old_data"`     // JSON snapshot before the change (empty for creates)
	NewData     string    `json:"new_data"`     // JSON snapshot after the change
	Diff        string    `json:"diff"`         // JSON object {field: {"old": ..., "new": ...}}
	Domain      string    `json:"domain"`
}

This function registers the callbacks:

// RegisterAuditLogCallbacks hooks the audit-log writers into db's callback
// chains:
//   - logCreate runs after "gorm:create", once the row (and its primary key)
//     has been written;
//   - logUpdate runs before "gorm:update", so it can still read the row's
//     previous values from the database.
//
// Deletes are not audited. Register returns an error; it is logged here rather
// than returned so the function's signature stays unchanged.
func RegisterAuditLogCallbacks(db *gorm.DB) {
	if err := db.Callback().Create().After("gorm:create").Register("audit_log_create", logCreate); err != nil {
		logrus.Error("Error registering audit_log_create callback: ", err)
	}
	if err := db.Callback().Update().Before("gorm:update").Register("audit_log_update", logUpdate); err != nil {
		logrus.Error("Error registering audit_log_update callback: ", err)
	}
}

logging functions

create

// logCreate is the GORM create callback (registered after "gorm:create"). It
// writes one audit entry per inserted row, using the row's JSON snapshot
// (with "password"/"Password" keys removed) as the new data.
//
// It skips inserts that failed, statements without a parsed model schema,
// inserts into the audit_logs table itself, and statements already running
// under the audit-log context flag.
func logCreate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil {
		return
	}
	// Skip logging if already inside audit log creation.
	if db.Statement.Context.Value(auditLogFlagKey) != nil || db.Statement.Table == "audit_logs" {
		return
	}

	userID := user_id
	role := user_role

	// logRow writes the audit entry for a single inserted value (a struct, or
	// a map for map-based creates).
	logRow := func(row reflect.Value) {
		row = reflect.Indirect(row)
		if !row.IsValid() {
			return // nil pointer element in a batch
		}

		// A primary key can only be read from a struct; ValueOf on anything
		// else (e.g. a slice) panics.
		var entityID interface{}
		if pk := db.Statement.Schema.PrioritizedPrimaryField; pk != nil && row.Kind() == reflect.Struct {
			entityID, _ = pk.ValueOf(db.Statement.Context, row)
		}

		value := row.Interface()
		if row.CanAddr() {
			// Marshal through a pointer so pointer-receiver MarshalJSON methods apply.
			value = row.Addr().Interface()
		}

		newData := make(map[string]interface{})
		newJSON, err := json.Marshal(value)
		if err == nil {
			err = json.Unmarshal(newJSON, &newData)
		}
		if err != nil {
			logrus.Error("Error marshalling new data: ", err)
		}

		// Never store password hashes in the audit trail.
		delete(newData, "password")
		delete(newData, "Password")

		newJSON, err = json.Marshal(newData)
		if err != nil {
			logrus.Error("Error marshalling new data: ", err)
		}

		createAuditLogEntry(db, userID, role, "Create Event", "CREATE", db.Statement.Table, entityID, nil, newJSON)
	}

	rv := reflect.Indirect(db.Statement.ReflectValue)
	switch rv.Kind() {
	case reflect.Slice, reflect.Array:
		// Batch insert: one audit entry per element.
		for i := 0; i < rv.Len(); i++ {
			logRow(rv.Index(i))
		}
	default:
		logRow(rv)
	}
}

update

// logUpdate is the GORM update callback (registered before "gorm:update").
// Running before the UPDATE lets it read the row's current database state as
// the "old" snapshot. The "new" snapshot is the JSON of db.Statement.Model.
// Password keys are removed from both snapshots before the entry is written.
//
// NOTE(review): for Update/Updates called with a column name or a map, GORM
// copies the new values onto the model inside "gorm:update", i.e. after this
// hook, so the "new" snapshot may still show pre-update values — verify.
func logUpdate(db *gorm.DB) {
	if db.Error != nil || db.Statement.Schema == nil {
		return
	}
	// Skip logging if already inside audit log creation.
	if db.Statement.Context.Value(auditLogFlagKey) != nil || db.Statement.Table == "audit_logs" {
		return
	}

	userID := user_id
	role := user_role

	// Identify the row by its primary key. This only works when the statement
	// targets a single struct; batch or conditional updates
	// (Model(&T{}).Where(...)) carry no key, and ValueOf on a non-struct panics.
	var entityID interface{}
	hasID := false
	pk := db.Statement.Schema.PrioritizedPrimaryField
	if rv := reflect.Indirect(db.Statement.ReflectValue); pk != nil && rv.Kind() == reflect.Struct {
		var isZero bool
		entityID, isZero = pk.ValueOf(db.Statement.Context, rv)
		hasID = !isZero
	}

	// Fetch the old data from the database. Filter on the schema's real
	// primary-key column instead of assuming it is named "id".
	var oldData map[string]interface{}
	if hasID {
		err := db.Session(&gorm.Session{NewDB: true}).
			Table(db.Statement.Table).
			Where(map[string]interface{}{pk.DBName: entityID}).
			Find(&oldData).Error
		if err != nil {
			logrus.Error("Error fetching old data: ", err)
		}
	}

	newData := make(map[string]interface{})
	newJSON, err := json.Marshal(db.Statement.Model)
	if err == nil {
		err = json.Unmarshal(newJSON, &newData)
	}
	if err != nil {
		logrus.Error("Error marshalling new data: ", err)
	}

	// Never store password hashes in the audit trail.
	delete(newData, "password")
	delete(newData, "Password")
	delete(oldData, "password")
	delete(oldData, "Password")

	oldJSON, err := json.Marshal(oldData)
	if err != nil {
		logrus.Error("Error marshalling old data: ", err)
	}

	newJSON, err = json.Marshal(newData)
	if err != nil {
		logrus.Error("Error marshalling new data: ", err)
	}

	createAuditLogEntry(db, userID, role, "Update Event", "UPDATE", db.Statement.Table, entityID, oldJSON, newJSON)
}

This function saves the audit log entry:

// auditLogFlagKey is the context key that marks a statement as an audit-log
// write. logCreate, logUpdate and createAuditLogEntry check it to avoid
// auditing their own inserts.
// NOTE(review): staticcheck SA1029 recommends an unexported key type over a
// plain string to avoid collisions; changing it would break any code that sets
// the raw "audit_log" string key — check callers first.
const auditLogFlagKey = "audit_log"

// createAuditLogEntry builds an AuditLog row and writes it through auditLogDB.
//
// oldData and newData are JSON documents; either may be nil. When newData is
// non-nil, Diff holds calculateDiff(oldData, newData). The insert runs with
// auditLogFlagKey set on its context, so the audit callbacks (if registered on
// auditLogDB) skip it instead of recursing. Failures are logged, not returned:
// auditing is best-effort and does not fail the caller's operation.
func createAuditLogEntry(db *gorm.DB, userID uuid.UUID, role, eventName, action, entityModel string, entityID interface{}, oldData,
	newData []byte) {
	ctx := db.Statement.Context

	// Already inside an audit-log write: do not log the log.
	if ctx.Value(auditLogFlagKey) != nil {
		return
	}

	var diffJSON []byte
	if newData != nil {
		diffJSON = calculateDiff(oldData, newData)
	}

	auditLog := AuditLog{
		UserID:      userID,
		Role:        role,
		EventName:   eventName,
		Action:      action,
		EntityID:    fmt.Sprintf("%v", entityID),
		EntityModel: entityModel,
		OldData:     string(oldData),
		NewData:     string(newData),
		Diff:        string(diffJSON),
		UserEmail:   user_email,
		UserName:    user_name,
		PhcID:       phc_id,
		Domain:      domain,
	}

	// The flag must travel with the handle that performs the insert; setting
	// it on db (as before) left auditLogDB's callbacks unaware of it.
	flaggedCtx := context.WithValue(ctx, auditLogFlagKey, true)
	if err := auditLogDB.WithContext(flaggedCtx).Create(&auditLog).Error; err != nil {
		logrus.Error("Error creating audit log entry: ", err)
	}
}

I also created these helper functions to normalize keys and compute the diff:


// Patterns used by convertToSnakeCase. They are compiled once at package
// initialisation rather than on every call.
var (
	// snakeFirstCapRe matches any character followed by a capitalised word ("xFoo").
	snakeFirstCapRe = regexp.MustCompile("(.)([A-Z][a-z]+)")
	// snakeAllCapRe matches a lowercase letter or digit followed by a capital ("aB", "1B").
	snakeAllCapRe = regexp.MustCompile("([a-z0-9])([A-Z])")
)

// convertToSnakeCase converts a CamelCase/mixedCase identifier to snake_case,
// keeping acronyms together: "UserID" -> "user_id", "CreatedAt" -> "created_at",
// "HTTPServer" -> "http_server".
func convertToSnakeCase(str string) string {
	// Insert an underscore before each capitalised word, then between a
	// lowercase letter/digit and a following capital.
	snake := snakeFirstCapRe.ReplaceAllString(str, "${1}_${2}")
	snake = snakeAllCapRe.ReplaceAllString(snake, "${1}_${2}")

	return strings.ToLower(snake)
}

// normalizeKeys returns a new map holding the entries of data, with every key
// rewritten to snake_case by convertToSnakeCase. The input map is not modified.
// If two keys normalise to the same snake_case form, which value is kept
// depends on Go's unspecified map iteration order.
func normalizeKeys(data map[string]interface{}) map[string]interface{} {
	out := map[string]interface{}{}
	for k := range data {
		out[convertToSnakeCase(k)] = data[k]
	}
	return out
}

// calculateDiff compares two JSON objects field by field and returns a JSON
// object that maps each changed field to {"old": ..., "new": ...}.
//
// Keys on both sides are first normalised with normalizeKeys, so struct-derived
// JSON ("CreatedAt") lines up with column-derived JSON ("created_at"). A key
// missing on one side is treated as null there, so added and removed fields are
// reported too. The "password" key is never included. Inputs that are nil or
// not JSON objects are treated as empty.
func calculateDiff(oldData, newData []byte) []byte {
	oldMap := make(map[string]interface{})
	newMap := make(map[string]interface{})
	// Best-effort decoding: nil oldData (creates) or malformed input simply
	// leaves the corresponding map empty.
	_ = json.Unmarshal(oldData, &oldMap)
	_ = json.Unmarshal(newData, &newMap)

	oldMap = normalizeKeys(oldMap)
	newMap = normalizeKeys(newMap)

	diffMap := make(map[string]interface{})
	compare := func(key string) {
		if key == "password" {
			return
		}
		oldVal, newVal := oldMap[key], newMap[key]
		// reflect.DeepEqual rather than !=: decoded JSON objects and arrays are
		// maps and slices, and comparing those through interfaces with != panics.
		if !reflect.DeepEqual(oldVal, newVal) {
			diffMap[key] = map[string]interface{}{
				"old": oldVal,
				"new": newVal,
			}
		}
	}

	for key := range oldMap {
		compare(key)
	}
	// Fields that exist only in the new data.
	for key := range newMap {
		if _, inOld := oldMap[key]; !inOld {
			compare(key)
		}
	}

	// Values come straight from json.Unmarshal, so re-marshalling cannot fail.
	diffJSON, _ := json.Marshal(diffMap)
	return diffJSON
}

Upvotes: 2

Related Questions