Files
coder/enterprise/audit/diff.go

165 lines
3.7 KiB
Go

package audit
import (
"database/sql"
"fmt"
"reflect"
"github.com/google/uuid"
"github.com/coder/coder/coderd/audit"
)
// structName builds the lookup key for a type by joining its package path and
// its type name with a dot.
func structName(t reflect.Type) string {
	pkg, name := t.PkgPath(), t.Name()
	return pkg + "." + name
}
// diffValues walks the fields of right and compares each one against the field
// at the same index in left, returning an audit.Map of the fields that differ.
//
// The action for each field is looked up in table under structName of right's
// type, keyed by the field's `json` struct tag:
//   - ActionIgnore fields are skipped entirely.
//   - ActionTrack fields record their old and new values.
//   - ActionSecret fields are marked Secret and record zero values in place of
//     the real old and new values.
//
// It panics with a "dev error" message when the type has no entry in table or
// when a field's json tag has no entry for that type.
func diffValues(left, right any, table Table) audit.Map {
	typ := reflect.TypeOf(right)
	actions := table[structName(typ)]
	if actions == nil {
		panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", typ.Name(), right))
	}

	var (
		oldStruct = reflect.ValueOf(left)
		newStruct = reflect.ValueOf(right)
		changes   = audit.Map{}
	)
	for idx := 0; idx < typ.NumField(); idx++ {
		oldField, newField := oldStruct.Field(idx), newStruct.Field(idx)
		oldVal, newVal := oldField.Interface(), newField.Interface()
		name := typ.Field(idx).Tag.Get("json")

		action, known := actions[name]
		if !known {
			panic(fmt.Sprintf("dev error: field %q lacks audit information", name))
		}
		if action == ActionIgnore {
			continue
		}

		// Some struct types would produce bad diffs if compared as-is, so
		// they are first coerced into simpler representations.
		var converted bool
		oldVal, newVal, converted = convertDiffType(oldVal, newVal)
		if converted {
			oldField, newField = reflect.ValueOf(oldVal), reflect.ValueOf(newVal)
		}

		// Pointer fields are compared by the value they point to; a nil
		// pointer stands in for the zero value of its underlying type.
		if oldField.Kind() == reflect.Ptr && newField.Kind() == reflect.Ptr {
			oldField, newField = derefPointer(oldField), derefPointer(newField)
			oldVal, newVal = oldField.Interface(), newField.Interface()
		}

		if reflect.DeepEqual(oldVal, newVal) {
			continue
		}

		switch action {
		case ActionTrack:
			changes[name] = audit.OldNew{Old: oldVal, New: newVal}
		case ActionSecret:
			// Record that the field changed without exposing either value.
			redacted := reflect.Zero(newField.Type()).Interface()
			changes[name] = audit.OldNew{
				Old:    redacted,
				New:    redacted,
				Secret: true,
			}
		}
	}

	return changes
}
// convertDiffType rewrites a pair of values of certain external struct types
// into simpler representations. changed reports whether a conversion happened;
// when it is false, left and right are returned untouched.
//
// Conversions:
//   - uuid.UUID becomes its string form, with uuid.Nil becoming "".
//   - uuid.NullUUID becomes the string of its MarshalText output.
//   - sql.NullString becomes its String, or "null" when it is not Valid.
//   - sql.NullInt64 becomes a *int64, which is nil when it is not Valid.
//
// right is asserted to have the same dynamic type as left without a check, so
// a mismatch panics.
//
//nolint:forcetypeassert
func convertDiffType(left, right any) (newLeft, newRight any, changed bool) {
	switch l := left.(type) {
	case uuid.UUID:
		r := right.(uuid.UUID)
		// Nil UUIDs are rendered as empty strings.
		render := func(id uuid.UUID) string {
			if id == uuid.Nil {
				return ""
			}
			return id.String()
		}
		return render(l), render(r), true

	case uuid.NullUUID:
		r := right.(uuid.NullUUID)
		// MarshalText errors are intentionally discarded here.
		lText, _ := l.MarshalText()
		rText, _ := r.MarshalText()
		return string(lText), string(rText), true

	case sql.NullString:
		r := right.(sql.NullString)
		render := func(s sql.NullString) string {
			if s.Valid {
				return s.String
			}
			return "null"
		}
		return render(l), render(r), true

	case sql.NullInt64:
		r := right.(sql.NullInt64)
		// The result stays a typed *int64 even when nil.
		toPtr := func(n sql.NullInt64) *int64 {
			if !n.Valid {
				return nil
			}
			return ptr(n.Int64)
		}
		return toPtr(l), toPtr(r), true
	}

	return left, right, false
}
// derefPointer follows a pointer reflect.Value down to the first non-pointer
// value it reaches, unwrapping any number of nested pointer levels. Whenever a
// nil pointer is met, the zero value of the pointed-to type is used in its
// place and unwrapping continues from there.
func derefPointer(ptr reflect.Value) reflect.Value {
	cur := ptr
	for {
		if cur.IsNil() {
			// Nothing to follow: substitute the zero value of the element type.
			cur = reflect.Zero(cur.Type().Elem())
		} else {
			cur = cur.Elem()
		}

		// Stop as soon as the current value is no longer a pointer.
		if cur.Kind() != reflect.Ptr {
			return cur
		}
	}
}
// ptr returns a pointer to a copy of x, which is handy for taking the address
// of a value that is not addressable on its own.
func ptr[T any](x T) *T {
	out := new(T)
	*out = x
	return out
}