callback_create.go

Summary

Maintainability
D
2 days
Test Coverage
package gorm

import (
    "fmt"
    "strings"
)

// Define callbacks for creating
func init() {
    DefaultCallback.Create().Register("gorm:begin_transaction", beginTransactionCallback)
    DefaultCallback.Create().Register("gorm:before_create", beforeCreateCallback)
    DefaultCallback.Create().Register("gorm:save_before_associations", saveBeforeAssociationsCallback)
    DefaultCallback.Create().Register("gorm:update_time_stamp", updateTimeStampForCreateCallback)
    DefaultCallback.Create().Register("gorm:create", createCallback)
    DefaultCallback.Create().Register("gorm:force_reload_after_create", forceReloadAfterCreateCallback)
    DefaultCallback.Create().Register("gorm:save_after_associations", saveAfterAssociationsCallback)
    DefaultCallback.Create().Register("gorm:after_create", afterCreateCallback)
    DefaultCallback.Create().Register("gorm:commit_or_rollback_transaction", commitOrRollbackTransactionCallback)
}

// beforeCreateCallback will invoke `BeforeSave`, `BeforeCreate` method before creating
func beforeCreateCallback(scope *Scope) {
    if !scope.HasError() {
        scope.CallMethod("BeforeSave")
    }
    if !scope.HasError() {
        scope.CallMethod("BeforeCreate")
    }
}

// updateTimeStampForCreateCallback will set `CreatedAt`, `UpdatedAt` when creating
func updateTimeStampForCreateCallback(scope *Scope) {
    if !scope.HasError() {
        now := scope.db.nowFunc()

        if createdAtField, ok := scope.FieldByName("CreatedAt"); ok {
            if createdAtField.IsBlank {
                createdAtField.Set(now)
            }
        }

        if updatedAtField, ok := scope.FieldByName("UpdatedAt"); ok {
            if updatedAtField.IsBlank {
                updatedAtField.Set(now)
            }
        }
    }
}

// createCallback the callback used to insert data into database
func createCallback(scope *Scope) {
    if !scope.HasError() {
        defer scope.trace(NowFunc())

        var (
            columns, placeholders        []string
            blankColumnsWithDefaultValue []string
        )

        for _, field := range scope.Fields() {
            if scope.changeableField(field) {
                if field.IsNormal && !field.IsIgnored {
                    if field.IsBlank && field.HasDefaultValue {
                        blankColumnsWithDefaultValue = append(blankColumnsWithDefaultValue, scope.Quote(field.DBName))
                        scope.InstanceSet("gorm:blank_columns_with_default_value", blankColumnsWithDefaultValue)
                    } else if !field.IsPrimaryKey || !field.IsBlank {
                        columns = append(columns, scope.Quote(field.DBName))
                        placeholders = append(placeholders, scope.AddToVars(field.Field.Interface()))
                    }
                } else if field.Relationship != nil && field.Relationship.Kind == "belongs_to" {
                    for _, foreignKey := range field.Relationship.ForeignDBNames {
                        if foreignField, ok := scope.FieldByName(foreignKey); ok && !scope.changeableField(foreignField) {
                            columns = append(columns, scope.Quote(foreignField.DBName))
                            placeholders = append(placeholders, scope.AddToVars(foreignField.Field.Interface()))
                        }
                    }
                }
            }
        }

        var (
            returningColumn = "*"
            quotedTableName = scope.QuotedTableName()
            primaryField    = scope.PrimaryField()
            extraOption     string
            insertModifier  string
        )

        if str, ok := scope.Get("gorm:insert_option"); ok {
            extraOption = fmt.Sprint(str)
        }
        if str, ok := scope.Get("gorm:insert_modifier"); ok {
            insertModifier = strings.ToUpper(fmt.Sprint(str))
            if insertModifier == "INTO" {
                insertModifier = ""
            }
        }

        if primaryField != nil {
            returningColumn = scope.Quote(primaryField.DBName)
        }

        lastInsertIDOutputInterstitial := scope.Dialect().LastInsertIDOutputInterstitial(quotedTableName, returningColumn, columns)
        var lastInsertIDReturningSuffix string
        if lastInsertIDOutputInterstitial == "" {
            lastInsertIDReturningSuffix = scope.Dialect().LastInsertIDReturningSuffix(quotedTableName, returningColumn)
        }

        if len(columns) == 0 {
            scope.Raw(fmt.Sprintf(
                "INSERT%v INTO %v %v%v%v",
                addExtraSpaceIfExist(insertModifier),
                quotedTableName,
                scope.Dialect().DefaultValueStr(),
                addExtraSpaceIfExist(extraOption),
                addExtraSpaceIfExist(lastInsertIDReturningSuffix),
            ))
        } else {
            scope.Raw(fmt.Sprintf(
                "INSERT%v INTO %v (%v)%v VALUES (%v)%v%v",
                addExtraSpaceIfExist(insertModifier),
                scope.QuotedTableName(),
                strings.Join(columns, ","),
                addExtraSpaceIfExist(lastInsertIDOutputInterstitial),
                strings.Join(placeholders, ","),
                addExtraSpaceIfExist(extraOption),
                addExtraSpaceIfExist(lastInsertIDReturningSuffix),
            ))
        }

        // execute create sql: no primaryField
        if primaryField == nil {
            if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
                // set rows affected count
                scope.db.RowsAffected, _ = result.RowsAffected()

                // set primary value to primary field
                if primaryField != nil && primaryField.IsBlank {
                    if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
                        scope.Err(primaryField.Set(primaryValue))
                    }
                }
            }
            return
        }

        // execute create sql: lastInsertID implemention for majority of dialects
        if lastInsertIDReturningSuffix == "" && lastInsertIDOutputInterstitial == "" {
            if result, err := scope.SQLDB().Exec(scope.SQL, scope.SQLVars...); scope.Err(err) == nil {
                // set rows affected count
                scope.db.RowsAffected, _ = result.RowsAffected()

                // set primary value to primary field
                if primaryField != nil && primaryField.IsBlank {
                    if primaryValue, err := result.LastInsertId(); scope.Err(err) == nil {
                        scope.Err(primaryField.Set(primaryValue))
                    }
                }
            }
            return
        }

        // execute create sql: dialects with additional lastInsertID requirements (currently postgres & mssql)
        if primaryField.Field.CanAddr() {
            if err := scope.SQLDB().QueryRow(scope.SQL, scope.SQLVars...).Scan(primaryField.Field.Addr().Interface()); scope.Err(err) == nil {
                primaryField.IsBlank = false
                scope.db.RowsAffected = 1
            }
        } else {
            scope.Err(ErrUnaddressable)
        }
        return
    }
}

// forceReloadAfterCreateCallback will reload columns that having default value, and set it back to current object
func forceReloadAfterCreateCallback(scope *Scope) {
    if blankColumnsWithDefaultValue, ok := scope.InstanceGet("gorm:blank_columns_with_default_value"); ok {
        var shouldScan bool
        db := scope.DB().New().Table(scope.TableName()).Select(blankColumnsWithDefaultValue.([]string))
        for _, field := range scope.Fields() {
            if field.IsPrimaryKey && !field.IsBlank {
                db = db.Where(fmt.Sprintf("%v = ?", field.DBName), field.Field.Interface())
                shouldScan = true
            }
        }

        if !shouldScan {
            return
        }

        db.Scan(scope.Value)
    }
}

// afterCreateCallback will invoke `AfterCreate`, `AfterSave` method after creating
func afterCreateCallback(scope *Scope) {
    if !scope.HasError() {
        scope.CallMethod("AfterCreate")
    }
    if !scope.HasError() {
        scope.CallMethod("AfterSave")
    }
}