// Source snapshot: 2024-11-21 17:03:31 +08:00 (522 lines, 12 KiB).

package gorm
import (
const (
	// preparedStmtDBKey is the key under which the shared *PreparedStmtDB is
	// kept in Config.cacheStore.
	preparedStmtDBKey = "preparedStmt"
)
// Config GORM config
type Config struct {
// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
// You can disable it by setting `SkipDefaultTransaction` to true
SkipDefaultTransaction bool
// NamingStrategy tables, columns naming strategy
NamingStrategy schema.Namer
// FullSaveAssociations full save associations
FullSaveAssociations bool
// Logger
Logger logger.Interface
// NowFunc the function to be used when creating a new timestamp
NowFunc func() time.Time
// DryRun generate sql without execute
DryRun bool
// PrepareStmt executes the given query in cached statement
PrepareStmt bool
// DisableAutomaticPing
DisableAutomaticPing bool
// DisableForeignKeyConstraintWhenMigrating
DisableForeignKeyConstraintWhenMigrating bool
// IgnoreRelationshipsWhenMigrating
IgnoreRelationshipsWhenMigrating bool
// DisableNestedTransaction disable nested transaction
DisableNestedTransaction bool
// AllowGlobalUpdate allow global update
AllowGlobalUpdate bool
// QueryFields executes the SQL query with all fields of the table
QueryFields bool
// CreateBatchSize default create batch size
CreateBatchSize int
// TranslateError enabling error translation
TranslateError bool
// PropagateUnscoped propagate Unscoped to every other nested statement
PropagateUnscoped bool
// ClauseBuilders clause builder
ClauseBuilders map[string]clause.ClauseBuilder
// ConnPool db conn pool
ConnPool ConnPool
// Dialector database dialector
// Plugins registered plugins
Plugins map[string]Plugin
callbacks *callbacks
cacheStore *sync.Map
// Apply update config to new config
func (c *Config) Apply(config *Config) error {
if config != c {
*config = *c
return nil
// AfterInitialize initialize plugins after db connected
func (c *Config) AfterInitialize(db *DB) error {
if db != nil {
for _, plugin := range c.Plugins {
if err := plugin.Initialize(db); err != nil {
return err
return nil
// Option gorm option interface
type Option interface {
Apply(*Config) error
AfterInitialize(*DB) error
// DB GORM DB definition
type DB struct {
Error error
RowsAffected int64
Statement *Statement
clone int
// Session session config when create session with Session() method
type Session struct {
DryRun bool
PrepareStmt bool
NewDB bool
Initialized bool
SkipHooks bool
SkipDefaultTransaction bool
DisableNestedTransaction bool
AllowGlobalUpdate bool
FullSaveAssociations bool
PropagateUnscoped bool
QueryFields bool
Context context.Context
Logger logger.Interface
NowFunc func() time.Time
CreateBatchSize int
// Open initializes a DB session based on dialector and the given options.
//
// Steps, in order:
//  1. Apply every non-nil option to a new Config. *Config options are applied
//     before all others, and the caller's order is kept within each group.
//     Each option's AfterInitialize is deferred and runs when Open returns.
//  2. Let the dialector apply its own settings, if it has an Apply method.
//  3. Fill in defaults for anything still unset.
//  4. Build the DB and its callbacks, then initialize the dialector.
//  5. Wrap the pool in a prepared-statement cache if PrepareStmt is set.
//  6. Ping the database unless DisableAutomaticPing is set.
//
// Any error is logged through the configured Logger and returned. After the
// DB has been built, db is returned non-nil even when err is non-nil.
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
	config := &Config{}

	// Sort a private copy so a slice the caller passed as opts... is not
	// reordered behind their back.
	opts = append([]Option(nil), opts...)
	// *Config options go first so the remaining options are applied on top of
	// them (Config.Apply overwrites the whole Config). SliceStable keeps the
	// caller's relative order within each group; sort.Slice does not
	// guarantee that.
	sort.SliceStable(opts, func(i, j int) bool {
		_, isConfig := opts[i].(*Config)
		_, isConfig2 := opts[j].(*Config)
		return isConfig && !isConfig2
	})

	for _, opt := range opts {
		if opt == nil {
			continue
		}
		if applyErr := opt.Apply(config); applyErr != nil {
			return nil, applyErr
		}
		// Deferred on purpose: it sees the final value of db. It runs on every
		// return path for the options registered so far.
		defer func(opt Option) {
			if errr := opt.AfterInitialize(db); errr != nil {
				if err != nil {
					// Keep the earlier failure visible instead of replacing it,
					// using the same format as AddError.
					err = fmt.Errorf("%v; %w", err, errr)
				} else {
					err = errr
				}
			}
		}(opt)
	}

	// A dialector may contribute its own settings to the config.
	if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
		if err = d.Apply(config); err != nil {
			return
		}
	}

	// Defaults for anything the options left unset.
	if config.NamingStrategy == nil {
		config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
	}
	if config.Logger == nil {
		config.Logger = logger.Default
	}
	if config.NowFunc == nil {
		config.NowFunc = func() time.Time { return time.Now().Local() }
	}
	if dialector != nil {
		config.Dialector = dialector
	}
	if config.Plugins == nil {
		config.Plugins = map[string]Plugin{}
	}
	if config.cacheStore == nil {
		config.cacheStore = &sync.Map{}
	}

	db = &DB{Config: config, clone: 1}

	db.callbacks = initializeCallbacks(db)

	if config.ClauseBuilders == nil {
		config.ClauseBuilders = map[string]clause.ClauseBuilder{}
	}

	if config.Dialector != nil {
		err = config.Dialector.Initialize(db)
		if err != nil {
			// Best-effort cleanup of whatever pool the dialector opened. The
			// Close error is deliberately ignored in favour of err.
			if sqlDB, _ := db.DB(); sqlDB != nil {
				_ = sqlDB.Close()
			}
		}

		if config.TranslateError {
			if _, ok := db.Dialector.(ErrorTranslator); !ok {
				config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
			}
		}
	}

	if config.PrepareStmt {
		preparedStmt := NewPreparedStmtDB(db.ConnPool)
		db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		db.ConnPool = preparedStmt
	}

	db.Statement = &Statement{
		DB:       db,
		ConnPool: db.ConnPool,
		Context:  context.Background(),
		Clauses:  map[string]clause.Clause{},
	}

	if err == nil && !config.DisableAutomaticPing {
		if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
			err = pinger.Ping()
		}
	}

	if err != nil {
		config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
	}

	return
}
// Session create new db session
func (db *DB) Session(config *Session) *DB {
var (
txConfig = *db.Config
tx = &DB{
Config: &txConfig,
Statement: db.Statement,
Error: db.Error,
clone: 1,
if config.CreateBatchSize > 0 {
tx.Config.CreateBatchSize = config.CreateBatchSize
if config.SkipDefaultTransaction {
tx.Config.SkipDefaultTransaction = true
if config.AllowGlobalUpdate {
txConfig.AllowGlobalUpdate = true
if config.FullSaveAssociations {
txConfig.FullSaveAssociations = true
if config.PropagateUnscoped {
txConfig.PropagateUnscoped = true
if config.Context != nil || config.PrepareStmt || config.SkipHooks {
tx.Statement = tx.Statement.clone()
tx.Statement.DB = tx
if config.Context != nil {
tx.Statement.Context = config.Context
if config.PrepareStmt {
var preparedStmt *PreparedStmtDB
if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
preparedStmt = v.(*PreparedStmtDB)
} else {
preparedStmt = NewPreparedStmtDB(db.ConnPool)
db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
switch t := tx.Statement.ConnPool.(type) {
case Tx:
tx.Statement.ConnPool = &PreparedStmtTX{
Tx: t,
PreparedStmtDB: preparedStmt,
tx.Statement.ConnPool = &PreparedStmtDB{
ConnPool: db.Config.ConnPool,
Mux: preparedStmt.Mux,
Stmts: preparedStmt.Stmts,
txConfig.ConnPool = tx.Statement.ConnPool
txConfig.PrepareStmt = true
if config.SkipHooks {
tx.Statement.SkipHooks = true
if config.DisableNestedTransaction {
txConfig.DisableNestedTransaction = true
if !config.NewDB {
tx.clone = 2
if config.DryRun {
tx.Config.DryRun = true
if config.QueryFields {
tx.Config.QueryFields = true
if config.Logger != nil {
tx.Config.Logger = config.Logger
if config.NowFunc != nil {
tx.Config.NowFunc = config.NowFunc
if config.Initialized {
tx = tx.getInstance()
return tx
// WithContext returns a session whose statement uses ctx. It is shorthand for
// db.Session(&Session{Context: ctx}).
func (db *DB) WithContext(ctx context.Context) *DB {
	opts := Session{Context: ctx}
	return db.Session(&opts)
}
// Debug start debug mode
func (db *DB) Debug() (tx *DB) {
tx = db.getInstance()
return tx.Session(&Session{
Logger: db.Logger.LogMode(logger.Info),
// Set store value with key into current db instance's context
func (db *DB) Set(key string, value interface{}) *DB {
tx := db.getInstance()
tx.Statement.Settings.Store(key, value)
return tx
// Get get value with key from current db instance's context
func (db *DB) Get(key string) (interface{}, bool) {
return db.Statement.Settings.Load(key)
// instanceSettingKey prefixes key with the address of stmt, so that values
// stored by InstanceSet are only found through the same Statement.
func instanceSettingKey(stmt *Statement, key string) string {
	return fmt.Sprintf("%p", stmt) + key
}

// InstanceSet stores value under key in the statement Settings of the handle
// returned by getInstance, scoped to that handle's Statement. It returns the
// handle for chaining.
func (db *DB) InstanceSet(key string, value interface{}) *DB {
	instance := db.getInstance()
	instance.Statement.Settings.Store(instanceSettingKey(instance.Statement, key), value)
	return instance
}

// InstanceGet looks up a value stored by InstanceSet for this handle's
// current Statement. ok reports whether it was present.
func (db *DB) InstanceGet(key string) (interface{}, bool) {
	return db.Statement.Settings.Load(instanceSettingKey(db.Statement, key))
}
// Callback returns callback manager
func (db *DB) Callback() *callbacks {
return db.callbacks
// AddError records err on db and returns the accumulated db.Error.
//
// A nil err leaves db.Error unchanged. When TranslateError is enabled and the
// Dialector implements ErrorTranslator, err is translated first. A second and
// later error is appended as "previous; new", with the new error wrapped (%w).
func (db *DB) AddError(err error) error {
	if err == nil {
		return db.Error
	}

	if db.Config.TranslateError {
		if translator, ok := db.Dialector.(ErrorTranslator); ok {
			err = translator.Translate(err)
		}
	}

	switch {
	case db.Error == nil:
		db.Error = err
	default:
		db.Error = fmt.Errorf("%v; %w", db.Error, err)
	}
	return db.Error
}
// DB returns `*sql.DB`
func (db *DB) DB() (*sql.DB, error) {
connPool := db.ConnPool
if db.Statement != nil && db.Statement.ConnPool != nil {
connPool = db.Statement.ConnPool
if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
return sqldb, err
if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
return sqldb, nil
return nil, ErrInvalidDB
func (db *DB) getInstance() *DB {
if db.clone > 0 {
tx := &DB{Config: db.Config, Error: db.Error}
if db.clone == 1 {
// clone with new statement
tx.Statement = &Statement{
DB: tx,
ConnPool: db.Statement.ConnPool,
Context: db.Statement.Context,
Clauses: map[string]clause.Clause{},
Vars: make([]interface{}, 0, 8),
SkipHooks: db.Statement.SkipHooks,
if db.Config.PropagateUnscoped {
tx.Statement.Unscoped = db.Statement.Unscoped
} else {
// with clone statement
tx.Statement = db.Statement.clone()
tx.Statement.DB = tx
return tx
return db
// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr {
return clause.Expr{SQL: expr, Vars: args}
// SetupJoinTable replaces the join table of model's relationship named field
// with the schema parsed from joinTable.
//
// For each reference of the relationship:
//   - The foreign-key column is looked up in joinTable by its DB name.
//   - Its DataType and GORMDataType are copied from the original foreign key,
//     and Size is copied too when unset.
//   - The reference is re-pointed at that column.
//
// Relationships declared on the old join table that joinTable lacks are
// copied over.
//
// It returns an error if either value fails to parse, if field does not name
// a relationship with a join table, or if joinTable lacks a required
// foreign-key column. All columns are resolved before anything is modified,
// so a missing column leaves the relationship untouched.
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
	var (
		tx                      = db.getInstance()
		stmt                    = tx.Statement
		modelSchema, joinSchema *schema.Schema
	)

	if err := stmt.Parse(model); err != nil {
		return err
	}
	modelSchema = stmt.Schema

	if err := stmt.Parse(joinTable); err != nil {
		return err
	}
	joinSchema = stmt.Schema

	relation, ok := modelSchema.Relationships.Relations[field]
	if !ok || relation.JoinTable == nil {
		return fmt.Errorf("failed to find relation: %s", field)
	}

	// Resolve every join-table column first. Otherwise a missing column on a
	// later reference would leave earlier references already pointing into a
	// join schema that never gets attached.
	joinFields := make([]*schema.Field, len(relation.References))
	for i, ref := range relation.References {
		f := joinSchema.LookUpField(ref.ForeignKey.DBName)
		if f == nil {
			return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
		}
		joinFields[i] = f
	}

	for i, ref := range relation.References {
		f := joinFields[i]
		f.DataType = ref.ForeignKey.DataType
		f.GORMDataType = ref.ForeignKey.GORMDataType
		if f.Size == 0 {
			f.Size = ref.ForeignKey.Size
		}
		ref.ForeignKey = f
	}

	for name, rel := range relation.JoinTable.Relationships.Relations {
		if _, exists := joinSchema.Relationships.Relations[name]; !exists {
			rel.Schema = joinSchema
			joinSchema.Relationships.Relations[name] = rel
		}
	}

	relation.JoinTable = joinSchema

	return nil
}
// Use use plugin
func (db *DB) Use(plugin Plugin) error {
name := plugin.Name()
if _, ok := db.Plugins[name]; ok {
return ErrRegistered
if err := plugin.Initialize(db); err != nil {
return err
db.Plugins[name] = plugin
return nil
// ToSQL for generate SQL string.
// db.ToSQL(func(tx *gorm.DB) *gorm.DB {
// return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
// .Limit(10).Offset(5)
// .Order("name ASC")
// .First(&User{})
// })
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}))
stmt := tx.Statement
return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)