mirror of
https://github.com/coder/coder.git
synced 2025-07-23 21:32:07 +00:00
chore: Support anonymously embedded fields for audit diffs (#5746)
- Anonymously embedded structs are expanded as top level fields. - Unit tests for anonymously embedded structs Co-authored-by: Steven Masley <stevenmasley@coder.com>
This commit is contained in:
@ -6,6 +6,7 @@ import (
|
||||
"reflect"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/coderd/audit"
|
||||
"github.com/coder/coder/coderd/database"
|
||||
@ -18,11 +19,9 @@ func structName(t reflect.Type) string {
|
||||
func diffValues(left, right any, table Table) audit.Map {
|
||||
var (
|
||||
baseDiff = audit.Map{}
|
||||
|
||||
leftV = reflect.ValueOf(left)
|
||||
|
||||
rightV = reflect.ValueOf(right)
|
||||
rightT = reflect.TypeOf(right)
|
||||
rightT = reflect.TypeOf(right)
|
||||
leftV = reflect.ValueOf(left)
|
||||
rightV = reflect.ValueOf(right)
|
||||
|
||||
diffKey = table[structName(rightT)]
|
||||
)
|
||||
@ -31,19 +30,25 @@ func diffValues(left, right any, table Table) audit.Map {
|
||||
panic(fmt.Sprintf("dev error: type %q (type %T) attempted audit but not auditable", rightT.Name(), right))
|
||||
}
|
||||
|
||||
for i := 0; i < rightT.NumField(); i++ {
|
||||
if !rightT.Field(i).IsExported() {
|
||||
continue
|
||||
}
|
||||
// allFields contains all top level fields of the struct.
|
||||
allFields, err := flattenStructFields(leftV, rightV)
|
||||
if err != nil {
|
||||
// This should never happen. Only structs should be flattened. If an
|
||||
// error occurs, an unsupported or non-struct type was passed in.
|
||||
panic(fmt.Sprintf("dev error: failed to flatten struct fields: %v", err))
|
||||
}
|
||||
|
||||
for _, field := range allFields {
|
||||
var (
|
||||
leftF = leftV.Field(i)
|
||||
rightF = rightV.Field(i)
|
||||
leftF = field.LeftF
|
||||
rightF = field.RightF
|
||||
|
||||
leftI = leftF.Interface()
|
||||
rightI = rightF.Interface()
|
||||
)
|
||||
|
||||
diffName = rightT.Field(i).Tag.Get("json")
|
||||
var (
|
||||
diffName = field.FieldType.Tag.Get("json")
|
||||
)
|
||||
|
||||
atype, ok := diffKey[diffName]
|
||||
@ -145,6 +150,64 @@ func convertDiffType(left, right any) (newLeft, newRight any, changed bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// fieldDiff has all the required information to return an audit diff for a
|
||||
// given field.
|
||||
type fieldDiff struct {
|
||||
FieldType reflect.StructField
|
||||
LeftF reflect.Value
|
||||
RightF reflect.Value
|
||||
}
|
||||
|
||||
// flattenStructFields will return all top level fields for a given structure.
|
||||
// Only anonymously embedded structs will be recursively flattened such that their
|
||||
// fields are returned as top level fields. Named nested structs will be returned
|
||||
// as a single field.
|
||||
// Conflicting field names need to be handled by the caller.
|
||||
func flattenStructFields(leftV, rightV reflect.Value) ([]fieldDiff, error) {
|
||||
// Dereference pointers if the field is a pointer field.
|
||||
if leftV.Kind() == reflect.Ptr {
|
||||
leftV = derefPointer(leftV)
|
||||
rightV = derefPointer(rightV)
|
||||
}
|
||||
|
||||
if leftV.Kind() != reflect.Struct {
|
||||
return nil, xerrors.Errorf("%q is not a struct, kind=%s", leftV.String(), leftV.Kind())
|
||||
}
|
||||
|
||||
var allFields []fieldDiff
|
||||
rightT := rightV.Type()
|
||||
|
||||
// Loop through all top level fields of the struct.
|
||||
for i := 0; i < rightT.NumField(); i++ {
|
||||
if !rightT.Field(i).IsExported() {
|
||||
continue
|
||||
}
|
||||
|
||||
var (
|
||||
leftF = leftV.Field(i)
|
||||
rightF = rightV.Field(i)
|
||||
)
|
||||
|
||||
if rightT.Field(i).Anonymous {
|
||||
// Anonymous fields are recursively flattened.
|
||||
anonFields, err := flattenStructFields(leftF, rightF)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("flatten anonymous field %q: %w", rightT.Field(i).Name, err)
|
||||
}
|
||||
allFields = append(allFields, anonFields...)
|
||||
continue
|
||||
}
|
||||
|
||||
// Single fields append as is.
|
||||
allFields = append(allFields, fieldDiff{
|
||||
LeftF: leftF,
|
||||
RightF: rightF,
|
||||
FieldType: rightT.Field(i),
|
||||
})
|
||||
}
|
||||
return allFields, nil
|
||||
}
|
||||
|
||||
// derefPointer deferences a reflect.Value that is a pointer to its underlying
|
||||
// value. It dereferences recursively until it finds a non-pointer value. If the
|
||||
// pointer is nil, it will be coerced to the zero value of the underlying type.
|
||||
|
Reference in New Issue
Block a user