Reputation: 11
We want to track the history of operations (audits) on all the entries of various tables in the database (PostgreSQL). The programming language is Go, with GORM ( https://gorm.io/ ) as the ORM.
Basically capturing the CREATE, UPDATE and DELETE operations on the tables in a separate (say audit_logs) table.
The desired schema of audit_logs table is:
Such functionality is available in Hibernate Envers ( https://www.baeldung.com/database-auditing-jpa#hibernate ).
However, for golang's GORM there isn't any such functionality.
Has anyone run into such a requirement in Go? If so, what approach did you take, and was any solution found or built? Any leads would be helpful.
We've tried registering hooks on the gorm.DB object.
Say the model for the audit log is:
import (
	"reflect"
	"time"

	"gorm.io/gorm"
)
// AuditLog is one row of the audit log: the affected entity (table name and
// ID), the action performed, who performed it, the data before and after the
// change, and when it happened.
//
// PrevData and CurData hold JSON dumps as strings. An interface{} field has no
// column type GORM can map (schema parsing reports an unsupported data type),
// so AutoMigrate and Create would reject it.
type AuditLog struct {
	EntityClass string    `gorm:"column:entity_class"`
	EntityID    string    `gorm:"column:entity_id"`
	Action      string    `gorm:"column:action"`
	ActionTaker string    `gorm:"column:action_taker"`
	PrevData    string    `gorm:"column:prev_data"`
	CurData     string    `gorm:"column:cur_data"`
	At          time.Time `gorm:"column:at"`
}
Then, in the database initialisation function, we registered the hooks:
// Excerpt from the database initialisation function: create the audit table,
// then hook the audit callbacks in after GORM's built-in create, update and
// delete callbacks.
// NOTE(review): every error below is discarded with `_ =`. If AutoMigrate fails
// or a callback fails to register, auditing silently does not happen — consider
// checking these errors.
// NOTE(review): all three hooks run After the corresponding gorm:* callback;
// capturing PrevData for updates/deletes probably needs a Before hook that reads
// the row first — confirm.
_ = db.AutoMigrate(&AuditLog{}) // creating the table if not exists
_ = db.Callback().Create().After("gorm:create").Register("audit_log:create", auditLogCreateCallback)
_ = db.Callback().Update().After("gorm:update").Register("audit_log:update", auditLogUpdateCallback)
_ = db.Callback().Delete().After("gorm:delete").Register("audit_log:delete", auditLogDeleteCallback)
We then defined the respective callback functions, e.g.:
// auditLogCreateCallback is a GORM create callback (registered after
// "gorm:create") that writes one AuditLog row via auditDB for the insert that
// just ran.
//
// The entity ID and the previous/current data are still placeholders. They have
// to be extracted from db.Statement (or elsewhere), but db.Statement.Dest and
// db.Statement.Model hold the data in different forms depending on how the write
// was issued (e.g. via a map or via the model itself). Bulk inserts and
// conditional updates (e.g. updating a product's discount where the product
// category is 'xyz') store it differently again.
func auditLogCreateCallback(db *gorm.DB) {
	// A failed insert created nothing, so there is nothing to audit.
	if db.Error != nil {
		return
	}
	entityName := db.Statement.Table
	// Never audit the audit table itself ("audit_logs" is GORM's default table
	// name for AuditLog). If auditDB shares this callback chain, the Create below
	// would otherwise re-enter this callback and recurse without end.
	if entityName == "audit_logs" {
		return
	}

	// A missing or non-string user in the context yields an empty ActionTaker.
	username, _ := db.Statement.Context.Value(constant.UserId).(string)

	auditLog := AuditLog{
		EntityClass: entityName,
		EntityID:    "<id here>",
		Action:      "CREATE",
		ActionTaker: username,
		PrevData:    "<json dump of prev data>",
		CurData:     "<json dump of cur data>",
		At:          time.Now(),
	}
	// Report a failed audit write instead of dropping it silently.
	if err := auditDB.Create(&auditLog).Error; err != nil {
		db.Logger.Error(db.Statement.Context, "audit_log: creating entry for %s: %v", entityName, err)
	}
}
Upvotes: 1
Views: 720
Reputation: 61
I have implemented an audit log using the callback approach. At first I put the audit_log table in the same database and it did not work; after I moved the audit_log table to a separate database, it worked fine.
Here are the code samples:
schema
// AuditLog is one audit-trail row. The embedded gorm.Model supplies the ID and
// the CreatedAt/UpdatedAt/DeletedAt columns. The remaining fields record:
//   - who acted: user, role, PhcID, domain
//   - what happened: event name and action
//   - which entity was affected
//   - the JSON snapshots before and after the change, and their field-level diff
type AuditLog struct {
	gorm.Model
	UserID      uuid.UUID `json:"user_id"`
	UserEmail   string    `json:"user_email"`
	UserName    string    `json:"user_name"`
	PhcID       uuid.UUID `json:"phc_id"`
	Role        string    `json:"role"`
	EventName   string    `json:"event_name"`
	Action      string    `json:"action"`
	EntityID    string    `json:"entity_id"`
	EntityModel string    `json:"entity_model"`
	OldData     string    `json:"old_data"`
	NewData     string    `json:"new_data"`
	Diff        string    `json:"diff"`
	Domain      string    `json:"domain"`
}
This function registers the callbacks:
// RegisterAuditLogCallbacks hooks the audit-log callbacks into db's callback
// chain:
//   - logCreate runs after "gorm:create", once the insert has been executed.
//   - logUpdate runs before "gorm:update", so the row's previous state can still
//     be read from the table.
//
// Deletes are not audited. Registration failures are logged rather than
// silently ignored.
func RegisterAuditLogCallbacks(db *gorm.DB) {
	callbacks := db.Callback()
	if err := callbacks.Create().After("gorm:create").Register("audit_log_create", logCreate); err != nil {
		logrus.Error("Error registering audit_log_create callback: ", err)
	}
	if err := callbacks.Update().Before("gorm:update").Register("audit_log_update", logUpdate); err != nil {
		logrus.Error("Error registering audit_log_update callback: ", err)
	}
}
logging functions
create
// logCreate is a GORM create callback (registered after "gorm:create") that
// writes one "CREATE" audit entry per inserted record. A single-record create
// yields one entry; a batch create (a slice or array of records) yields one
// entry per element. Creates whose destination is neither a struct nor a
// slice/array of structs (e.g. a map) are not audited.
func logCreate(db *gorm.DB) {
	// Nothing to audit if the insert failed, if there is no parsed schema
	// (raw SQL), or if the target is the audit table itself.
	if db.Error != nil || db.Statement.Schema == nil || db.Statement.Table == "audit_logs" {
		return
	}
	// Skip logging if already inside audit log creation (flag set by createAuditLogEntry).
	if db.Statement.Context.Value(auditLogFlagKey) != nil {
		return
	}

	records := reflect.Indirect(db.Statement.ReflectValue)
	switch records.Kind() {
	case reflect.Struct:
		logCreatedRecord(db, records)
	case reflect.Slice, reflect.Array:
		// Batch insert: a schema field's ValueOf panics when given the slice
		// itself, so each element is audited separately.
		for i := 0; i < records.Len(); i++ {
			logCreatedRecord(db, reflect.Indirect(records.Index(i)))
		}
	}
}

// logCreatedRecord writes the "CREATE" audit entry for one inserted record.
// Anything that is not a struct value (e.g. the zero Value left by a nil
// pointer element) is skipped.
func logCreatedRecord(db *gorm.DB, record reflect.Value) {
	if record.Kind() != reflect.Struct {
		return
	}

	// Get the primary key value.
	var entityID interface{}
	if pk := db.Statement.Schema.PrioritizedPrimaryField; pk != nil {
		entityID, _ = pk.ValueOf(db.Statement.Context, record)
	}

	// Marshal through a pointer when possible so MarshalJSON methods declared
	// on the pointer type are honoured.
	value := record.Interface()
	if record.CanAddr() {
		value = record.Addr().Interface()
	}
	newData := make(map[string]interface{})
	raw, err := json.Marshal(value)
	if err == nil {
		err = json.Unmarshal(raw, &newData)
	}
	if err != nil {
		logrus.Error("Error marshalling new data: ", err)
	}

	// Remove the password so even encrypted passwords are never exposed in the audit log.
	delete(newData, "password")
	delete(newData, "Password")
	newJSON, err := json.Marshal(newData)
	if err != nil {
		logrus.Error("Error marshalling new data: ", err)
	}

	createAuditLogEntry(db, user_id, user_role, "Create Event", "CREATE", db.Statement.Table, entityID, nil, newJSON)
}
update
// logUpdate is a GORM update callback (registered before "gorm:update") that
// writes one "UPDATE" audit entry. Because it runs before the UPDATE is issued,
// the row as currently stored can still be read; it is recorded as the old
// data. The statement's model is recorded as the new data.
//
// NOTE(review): for Update(column, value) / Updates(map) calls the changed
// values live in db.Statement.Dest rather than db.Statement.Model. In that case
// the recorded new data (and diff) may not reflect the change — confirm, and
// consider merging Dest into newData.
func logUpdate(db *gorm.DB) {
	stmt := db.Statement
	// Skip when an earlier callback already failed, for raw SQL without a
	// schema, and for writes to the audit table itself.
	if db.Error != nil || stmt.Schema == nil || stmt.Table == "audit_logs" {
		return
	}
	// Skip logging if already inside audit log creation.
	if stmt.Context.Value(auditLogFlagKey) != nil {
		return
	}

	// The primary key can only be read from a single struct record (ValueOf
	// panics on slices and maps). A zero key means a conditional update such as
	// Model(&T{}).Where(...), which has no single row to snapshot.
	pk := stmt.Schema.PrioritizedPrimaryField
	var entityID interface{}
	hasID := false
	if record := reflect.Indirect(stmt.ReflectValue); pk != nil && record.Kind() == reflect.Struct {
		var isZero bool
		entityID, isZero = pk.ValueOf(stmt.Context, record)
		hasID = !isZero
	}

	// Fetch the old data from the database, matching on the schema's actual
	// primary-key column instead of assuming it is named "id".
	var oldData map[string]interface{}
	if hasID {
		err := db.Session(&gorm.Session{NewDB: true}).
			Table(stmt.Table).
			Where(stmt.Quote(pk.DBName)+" = ?", entityID).
			Find(&oldData).Error
		if err != nil {
			logrus.Error("Error fetching old data: ", err)
		}
	}

	newData := make(map[string]interface{})
	raw, err := json.Marshal(stmt.Model)
	if err == nil {
		err = json.Unmarshal(raw, &newData)
	}
	if err != nil {
		logrus.Error("Error marshalling new data: ", err)
	}

	// Remove passwords so they never reach the audit log.
	for _, key := range []string{"password", "Password"} {
		delete(oldData, key)
		delete(newData, key)
	}

	oldJSON, err := json.Marshal(oldData)
	if err != nil {
		logrus.Error("Error marshalling old data: ", err)
	}
	newJSON, err := json.Marshal(newData)
	if err != nil {
		logrus.Error("Error marshalling new data: ", err)
	}

	createAuditLogEntry(db, user_id, user_role, "Update Event", "UPDATE", stmt.Table, entityID, oldJSON, newJSON)
}
Function that saves the audit log entry:
// auditLogContextKey is an unexported type for context keys set by the audit
// log code. Using a dedicated type instead of a built-in string means the key
// cannot collide with context values set by other packages.
type auditLogContextKey string

// auditLogFlagKey marks a context as belonging to an audit-log write, so the
// audit callbacks skip it instead of recursively logging their own inserts.
const auditLogFlagKey auditLogContextKey = "audit_log"
// createAuditLogEntry writes one AuditLog row through auditLogDB describing an
// action on entityModel/entityID.
//
// oldData and newData are JSON documents; oldData is nil for creates. When
// newData is present, a field-level diff of the two is stored as well. The
// remaining user details come from user_email, user_name, phc_id and domain,
// which are defined outside this function. Failures are logged, never returned,
// so auditing cannot fail the caller's statement.
func createAuditLogEntry(db *gorm.DB, userID uuid.UUID, role, eventName, action, entityModel string, entityID interface{}, oldData,
	newData []byte) {
	// Never audit a write that is itself an audit-log write (avoids recursive logging).
	if db.Statement.Context.Value(auditLogFlagKey) != nil {
		return
	}

	var diffJSON []byte
	if newData != nil {
		diffJSON = calculateDiff(oldData, newData)
	}

	auditLog := AuditLog{
		UserID:      userID,
		Role:        role,
		EventName:   eventName,
		Action:      action,
		EntityID:    fmt.Sprintf("%v", entityID),
		EntityModel: entityModel,
		OldData:     string(oldData),
		NewData:     string(newData),
		Diff:        string(diffJSON),
		UserEmail:   user_email,
		UserName:    user_name,
		PhcID:       phc_id,
		Domain:      domain,
	}

	// Mark the audit write's context so that any audit callbacks it triggers
	// (if they are registered on auditLogDB) return early instead of recursing.
	// The flag must travel with the handle that performs the insert; attaching
	// it to db would have no effect, because db is not used for the write.
	ctx := context.WithValue(db.Statement.Context, auditLogFlagKey, true)
	if err := auditLogDB.WithContext(ctx).Create(&auditLog).Error; err != nil {
		logrus.Error("Error creating audit log entry: ", err)
	}
}
I also created these helper functions to normalize keys and compute the diff:
// Patterns used by convertToSnakeCase. They are compiled once at package
// initialisation rather than on every call.
var (
	// snakeFirstCapRE matches any character followed by a capitalized word ("xFoo").
	snakeFirstCapRE = regexp.MustCompile("(.)([A-Z][a-z]+)")
	// snakeAllCapRE matches a lowercase letter or digit followed by an uppercase letter ("xF").
	snakeAllCapRE = regexp.MustCompile("([a-z0-9])([A-Z])")
)

// convertToSnakeCase converts a CamelCase or mixedCase identifier to
// snake_case and keeps runs of capitals (acronyms) together. For example,
// "UserID" becomes "user_id", "HTTPServer" becomes "http_server" and
// "CreatedAt" becomes "created_at".
func convertToSnakeCase(str string) string {
	// Insert an underscore before each capitalized word, then between a
	// lowercase letter/digit and a following capital.
	snake := snakeFirstCapRE.ReplaceAllString(str, "${1}_${2}")
	snake = snakeAllCapRE.ReplaceAllString(snake, "${1}_${2}")
	return strings.ToLower(snake)
}
// normalizeKeys returns a new map holding every entry of data, with each key
// rewritten to snake_case by convertToSnakeCase. data itself is left untouched.
// If two keys normalize to the same name, only one of their values survives.
func normalizeKeys(data map[string]interface{}) map[string]interface{} {
	out := map[string]interface{}{}
	for original, v := range data {
		out[convertToSnakeCase(original)] = v
	}
	return out
}
// calculateDiff compares two JSON objects field by field. It returns a JSON
// object that maps every changed field to {"old": <value>, "new": <value>}.
//
// Keys are normalized to snake_case before comparison, and a field missing on
// one side counts as null there. The "password" field is never reported. Input
// that does not decode as a JSON object (for example the nil oldData passed for
// creates) is treated as an empty object.
func calculateDiff(oldData, newData []byte) []byte {
	oldMap := make(map[string]interface{})
	newMap := make(map[string]interface{})
	// Decoding is best effort: undecodable input simply yields an empty map.
	_ = json.Unmarshal(oldData, &oldMap)
	_ = json.Unmarshal(newData, &newMap)
	oldMap = normalizeKeys(oldMap)
	newMap = normalizeKeys(newMap)

	diffMap := make(map[string]interface{})
	record := func(key string, oldValue, newValue interface{}) {
		// Decoded JSON values can be objects or arrays (maps/slices). Comparing
		// those with == panics at runtime, so compare them structurally.
		if key == "password" || reflect.DeepEqual(oldValue, newValue) {
			return
		}
		diffMap[key] = map[string]interface{}{
			"old": oldValue,
			"new": newValue,
		}
	}
	for key, oldValue := range oldMap {
		record(key, oldValue, newMap[key])
	}
	// Fields that exist only in the new data (e.g. every field on create).
	for key, newValue := range newMap {
		if _, inOld := oldMap[key]; !inOld {
			record(key, nil, newValue)
		}
	}

	// Cannot fail: every value in diffMap came from json.Unmarshal.
	diffJSON, _ := json.Marshal(diffMap)
	return diffJSON
}
Upvotes: 2