mirror of
https://github.com/go-gorm/gorm.git
synced 2025-03-15 17:47:28 +00:00
fix: circular reference save, close #5140
commit 2ac099a37ac7bd74f0a98a6fdc42cc8527404144 Author: Jinzhu <wosmvp@gmail.com> Date: Thu Mar 17 23:49:21 2022 +0800 Refactor #5140 commit 6e3ca2d1aa09943dcfb5d9a4b93bea28212f71be Author: a631807682 <631807682@qq.com> Date: Sun Mar 13 12:52:08 2022 +0800 test: add test for LoadOrStoreVisitMap commit 9d5c68e41000fd15dea124797dd5f2656bf6b304 Author: chenrui <chenrui@jingdaka.com> Date: Thu Mar 10 20:33:47 2022 +0800 chore: add more comment commit bfffefb179c883389b72bef8f04469c0a8418043 Author: chenrui <chenrui@jingdaka.com> Date: Thu Mar 10 20:28:48 2022 +0800 fix: should check values has been saved instead of rel.Name commit e55cdfa4b3fbcf8b80baf009e8ddb2e40d471494 Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:48:01 2022 +0800 chore: go lint commit fe4715c5bd4ac28950c97dded9848710d8becb88 Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:27:24 2022 +0800 chore: add test comment commit 326862f3f8980482a09d7d1a7f4d1011bb8a7c59 Author: chenrui <chenrui@jingdaka.com> Date: Tue Mar 8 17:22:33 2022 +0800 fix: circular reference save
This commit is contained in:
@ -69,7 +69,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
}
|
||||
|
||||
if elems.Len() > 0 {
|
||||
if saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil) == nil {
|
||||
if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil {
|
||||
for i := 0; i < elems.Len(); i++ {
|
||||
setupReferences(objs[i], elems.Index(i))
|
||||
}
|
||||
@ -82,7 +82,7 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) {
|
||||
rv = rv.Addr()
|
||||
}
|
||||
|
||||
if saveAssociations(db, rel, rv.Interface(), selectColumns, restricted, nil) == nil {
|
||||
if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil {
|
||||
setupReferences(db.Statement.ReflectValue, rv)
|
||||
}
|
||||
}
|
||||
@ -146,7 +146,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
case reflect.Struct:
|
||||
if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero {
|
||||
@ -166,7 +166,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, f.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -237,7 +237,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName)
|
||||
}
|
||||
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, assignmentColumns)
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns)
|
||||
}
|
||||
}
|
||||
|
||||
@ -304,7 +304,7 @@ func SaveAfterAssociations(create bool) func(db *gorm.DB) {
|
||||
// optimize elems of reflect value length
|
||||
if elemLen := elems.Len(); elemLen > 0 {
|
||||
if v, ok := selectColumns[rel.Name+".*"]; !ok || v {
|
||||
saveAssociations(db, rel, elems.Interface(), selectColumns, restricted, nil)
|
||||
saveAssociations(db, rel, elems, selectColumns, restricted, nil)
|
||||
}
|
||||
|
||||
for i := 0; i < elemLen; i++ {
|
||||
@ -341,11 +341,17 @@ func onConflictOption(stmt *gorm.Statement, s *schema.Schema, selectColumns map[
|
||||
return
|
||||
}
|
||||
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{}, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error {
|
||||
// stop save association loop
|
||||
if checkAssociationsSaved(db, rValues) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
selects, omits []string
|
||||
onConflict = onConflictOption(db.Statement, rel.FieldSchema, selectColumns, restricted, defaultUpdatingColumns)
|
||||
refName = rel.Name + "."
|
||||
values = rValues.Interface()
|
||||
)
|
||||
|
||||
for name, ok := range selectColumns {
|
||||
@ -390,3 +396,24 @@ func saveAssociations(db *gorm.DB, rel *schema.Relationship, values interface{},
|
||||
|
||||
return db.AddError(tx.Create(values).Error)
|
||||
}
|
||||
|
||||
// check association values has been saved
|
||||
// if values kind is Struct, check it has been saved
|
||||
// if values kind is Slice/Array, check all items have been saved
|
||||
var visitMapStoreKey = "gorm:saved_association_map"
|
||||
|
||||
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
|
||||
if visit, ok := db.Get(visitMapStoreKey); ok {
|
||||
if v, ok := visit.(*visitMap); ok {
|
||||
if loadOrStoreVisitMap(v, values) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
vistMap := make(visitMap)
|
||||
loadOrStoreVisitMap(&vistMap, values)
|
||||
db.Set(visitMapStoreKey, &vistMap)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
|
||||
"gorm.io/gorm"
|
||||
@ -120,3 +121,32 @@ func checkMissingWhereConditions(db *gorm.DB) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
type visitMap = map[reflect.Value]bool
|
||||
|
||||
// Check if circular values, return true if loaded
|
||||
func loadOrStoreVisitMap(vistMap *visitMap, v reflect.Value) (loaded bool) {
|
||||
if v.Kind() == reflect.Ptr {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Slice, reflect.Array:
|
||||
loaded = true
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
if !loadOrStoreVisitMap(vistMap, v.Index(i)) {
|
||||
loaded = false
|
||||
}
|
||||
}
|
||||
case reflect.Struct, reflect.Interface:
|
||||
if v.CanAddr() {
|
||||
p := v.Addr()
|
||||
if _, ok := (*vistMap)[p]; ok {
|
||||
return true
|
||||
}
|
||||
(*vistMap)[p] = true
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
36
callbacks/visit_map_test.go
Normal file
36
callbacks/visit_map_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package callbacks
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadOrStoreVisitMap(t *testing.T) {
|
||||
var vm visitMap
|
||||
var loaded bool
|
||||
type testM struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
t1 := testM{Name: "t1"}
|
||||
t2 := testM{Name: "t2"}
|
||||
t3 := testM{Name: "t3"}
|
||||
|
||||
vm = make(visitMap)
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
|
||||
// t1 already exist but t2 not
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded {
|
||||
t.Fatalf("loaded should be false")
|
||||
}
|
||||
|
||||
if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded {
|
||||
t.Fatalf("loaded should be true")
|
||||
}
|
||||
}
|
@ -220,3 +220,44 @@ func TestFullSaveAssociations(t *testing.T) {
|
||||
t.Errorf("Failed to preload AppliesToProduct")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveBelongsCircularReference(t *testing.T) {
|
||||
parent := Parent{}
|
||||
DB.Create(&parent)
|
||||
|
||||
child := Child{ParentID: &parent.ID, Parent: &parent}
|
||||
DB.Create(&child)
|
||||
|
||||
parent.FavChildID = child.ID
|
||||
parent.FavChild = &child
|
||||
DB.Save(&parent)
|
||||
|
||||
var parent1 Parent
|
||||
DB.First(&parent1, parent.ID)
|
||||
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||
|
||||
// Save and Updates is the same
|
||||
DB.Updates(&parent)
|
||||
DB.First(&parent1, parent.ID)
|
||||
AssertObjEqual(t, parent, parent1, "ID", "FavChildID")
|
||||
}
|
||||
|
||||
func TestSaveHasManyCircularReference(t *testing.T) {
|
||||
parent := Parent{}
|
||||
DB.Create(&parent)
|
||||
|
||||
child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"}
|
||||
child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"}
|
||||
|
||||
parent.Children = []*Child{&child, &child1}
|
||||
DB.Save(&parent)
|
||||
|
||||
var children []*Child
|
||||
DB.Where("parent_id = ?", parent.ID).Find(&children)
|
||||
if len(children) != len(parent.Children) ||
|
||||
children[0].ID != parent.Children[0].ID ||
|
||||
children[1].ID != parent.Children[1].ID {
|
||||
t.Errorf("circular reference children save not equal children:%v parent.Children:%v",
|
||||
children, parent.Children)
|
||||
}
|
||||
}
|
||||
|
@ -95,7 +95,7 @@ func OpenTestConnection() (db *gorm.DB, err error) {
|
||||
|
||||
func RunMigrations() {
|
||||
var err error
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}}
|
||||
allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}}
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] })
|
||||
|
||||
|
@ -80,3 +80,17 @@ type Order struct {
|
||||
Coupon *Coupon
|
||||
CouponID string
|
||||
}
|
||||
|
||||
type Parent struct {
|
||||
gorm.Model
|
||||
FavChildID uint
|
||||
FavChild *Child
|
||||
Children []*Child
|
||||
}
|
||||
|
||||
type Child struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
ParentID *uint
|
||||
Parent *Parent
|
||||
}
|
||||
|
Reference in New Issue
Block a user