internal/repository/generate_repository.go
package repository
import (
"fmt"
"io"
"strings"
"github.com/yoyo-project/yoyo/internal/repository/template"
"github.com/yoyo-project/yoyo/internal/schema"
)
// NewEntityRepositoryGenerator returns an EntityGenerator that renders the
// repository source for a single table into w by filling the placeholders of
// template.RepositoryFile.
//
// Parameters:
//   - packageName: Go package name substituted for template.PackageName.
//   - adapter: supplies the prepared-statement placeholders used in the queries.
//   - reposPath: base path; the query package imported by the generated file is
//     resolved from "<reposPath>/query/<t.QueryPackageName()>".
//   - packagePath: resolves that path to an import path; a failure is wrapped and
//     returned from the generator.
//   - db: the whole schema, used to look up tables referenced by t and tables
//     that declare has-many references to t.
func NewEntityRepositoryGenerator(packageName string, adapter Adapter, reposPath string, packagePath Finder, db schema.Database) EntityGenerator {
	return func(t schema.Table, w io.StringWriter) (err error) {
		var insertCNames, selectCNames, scanFields, inFields, pkFields []string

		// The table's own columns.
		for _, col := range t.Columns {
			goName := col.ExportedGoName()
			if col.PrimaryKey {
				pkFields = append(pkFields, strings.ReplaceAll(template.PKFieldTemplate, template.FieldName, goName))
			}
			// Auto-increment columns are left out of the explicit insert column list.
			if !col.AutoIncrement {
				insertCNames = append(insertCNames, col.Name)
			}
			selectCNames = append(selectCNames, col.Name)
			scanFields = append(scanFields, "&ent."+goName)
			inFields = append(inFields, "in."+goName)
		}

		// Has-one references declared on this table: their FK columns are selected
		// and inserted, and the entity exposes them as <ForeignTable><PKColumn> fields.
		for _, r := range t.References {
			if !r.HasOne {
				continue
			}
			// NOTE(review): the second result of GetTable/GetColumn is discarded, so an
			// unknown table or column silently yields a zero value here — confirm intent.
			ft, _ := db.GetTable(r.TableName)
			for _, cn := range r.ColNames(ft) {
				selectCNames = append(selectCNames, cn)
				insertCNames = append(insertCNames, cn)
			}
			for _, cn := range ft.PKColNames() {
				c, _ := ft.GetColumn(cn)
				goName := ft.ExportedGoName() + c.ExportedGoName()
				scanFields = append(scanFields, "&ent."+goName)
				inFields = append(inFields, "in."+goName)
			}
		}

		// Reverse references: any table t2 declaring a has-many reference to this
		// table contributes its PK columns as <T2><PKColumn> fields on this entity.
		for _, t2 := range db.Tables {
			for _, r := range t2.References {
				if !r.HasMany || r.TableName != t.Name {
					continue
				}
				for _, col := range t2.PKColumns() {
					// The scan target and the insert/update argument must name the same
					// entity field, exactly as in the has-one branch above.
					goName := t2.ExportedGoName() + col.ExportedGoName()
					// NOTE(review): col.Name is t2's PK column name, yet it is selected from
					// this table; the has-one branch uses r.ColNames(...) for local FK
					// column names instead — confirm which name the schema expects.
					selectCNames = append(selectCNames, col.Name)
					insertCNames = append(insertCNames, col.Name)
					scanFields = append(scanFields, "&ent."+goName)
					inFields = append(inFields, "in."+goName)
				}
			}
		}

		queryImportPath, err := packagePath(fmt.Sprintf("%s/query/%s", reposPath, t.QueryPackageName()))
		if err != nil {
			return fmt.Errorf("unable to generate repository: %w", err)
		}

		// Code that captures the primary key after an insert, chosen by PK shape.
		pkCols := t.PKColumns()
		var pkCapture string
		switch len(pkCols) {
		case 0:
			// No primary key: nothing to capture.
		case 1:
			col := pkCols[0]
			var capTemplate string
			if col.AutoIncrement {
				capTemplate = template.SinglePKCaptureTemplate
			} else {
				capTemplate = template.NoPKCapture
			}
			pkCapture = strings.NewReplacer(
				template.FieldName, col.ExportedGoName(),
				template.Type, col.GoTypeString(),
			).Replace(capTemplate)
		default:
			pkCapture = template.MultiPKCaptureTemplate
		}

		pkQuery := strings.NewReplacer(
			template.QueryPackageName, t.QueryPackageName(),
			template.PKFields, strings.Join(pkFields, "\n "),
		).Replace(template.PKQueryTemplate)

		// NOTE(review): one placeholder is produced per selected column, and the same
		// list fills template.StatementPlaceholders. insertCNames omits auto-increment
		// columns, so if the template pairs these placeholders with the insert column
		// list the counts will differ — confirm against template.RepositoryFile.
		placeholders := adapter.PreparedStatementPlaceholders(len(selectCNames))
		colAssignments := make([]string, 0, len(selectCNames))
		for i, colName := range selectCNames {
			colAssignments = append(colAssignments, fmt.Sprintf("%s = %s", colName, placeholders[i]))
		}

		var saveFuncs string
		if len(pkCols) > 0 {
			saveFuncs = template.SaveWithPK
		} else {
			saveFuncs = template.SaveWithoutPK
		}

		replacer := strings.NewReplacer(
			template.PackageName, packageName,
			template.Imports, fmt.Sprintf(`"%s"`, queryImportPath),
			template.ScanFields, strings.Join(scanFields, ", "),
			template.InFields, strings.Join(inFields, ", "),
			template.EntityName, t.ExportedGoName(),
			template.TableName, t.Name,
			template.InsertColumnNames, strings.Join(insertCNames, ", "),
			template.SelectColumnNames, strings.Join(selectCNames, ", "),
			template.StatementPlaceholders, strings.Join(placeholders, ", "),
			template.PKCapture, pkCapture,
			template.PKQuery, pkQuery,
			template.ColumnAssignments, strings.Join(colAssignments, ", "),
			template.QueryPackageName, t.QueryPackageName(),
		)
		// SaveFuncs is expanded before the replacer runs, so any placeholders inside
		// the chosen save functions are substituted as well.
		_, err = w.WriteString(replacer.Replace(strings.ReplaceAll(template.RepositoryFile, template.SaveFuncs, saveFuncs)))
		return err
	}
}