ory-am/hydra

View on GitHub
cmd/cli/handler_migrate.go

Summary

Maintainability
C
7 hrs
Test Coverage
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package cli

import (
    "bytes"
    "context"
    "fmt"
    "io"
    "io/fs"
    "os"
    "path/filepath"
    "regexp"
    "strings"
    "time"

    "github.com/ory/x/popx"
    "github.com/ory/x/servicelocatorx"

    "github.com/pkg/errors"

    "github.com/ory/x/configx"

    "github.com/ory/x/errorsx"

    "github.com/ory/x/cmdx"

    "github.com/spf13/cobra"

    "github.com/ory/hydra/v2/driver"
    "github.com/ory/hydra/v2/driver/config"
    "github.com/ory/hydra/v2/persistence"
    "github.com/ory/x/flagx"
)

// MigrateHandler carries the options used by the migration command handlers
// (MigrateGen, MigrateSQL, MigrateStatus).
type MigrateHandler struct {
    // slOpts are service locator options.
    // NOTE(review): not read by any method in this file — makePersister calls
    // servicelocatorx.NewOptions() without them; confirm whether intended.
    slOpts []servicelocatorx.Option
    // dOpts are appended after the built-in driver options whenever
    // makePersister creates a registry.
    dOpts  []driver.OptionsModifier
    // cOpts are config option modifiers.
    // NOTE(review): not read by any method in this file — confirm usage.
    cOpts  []configx.OptionModifier
}

// newMigrateHandler returns a MigrateHandler that stores the given service
// locator, driver, and config option lists as-is (the slices are not copied).
func newMigrateHandler(slOpts []servicelocatorx.Option, dOpts []driver.OptionsModifier, cOpts []configx.OptionModifier) *MigrateHandler {
    h := new(MigrateHandler)
    h.slOpts = slOpts
    h.dOpts = dOpts
    h.cOpts = cOpts
    return h
}

const (
    // genericDialectKey is the dialect assigned to migrations whose filename
    // has no dialect segment; readMigrations stores those as the group's
    // up/down fallback instead of as a dialect-specific child.
    genericDialectKey = "any"
)

// fragmentHeader is prepended to every generated migration fragment.
// It is a package-level slice shared by all callers, so treat it as
// read-only.
var fragmentHeader = []byte(strings.TrimLeft(`
-- Migration generated by the command below; DO NOT EDIT.
-- hydra:generate hydra migrate gen
`, "\n"))

// blankFragment is the content of the placeholder fragments that generateSQL
// writes to pad migrations up to the group's maximum fragment count.
var blankFragment = []byte(strings.TrimLeft(`
-- This blank migration was generated to meet ory/x/popx validation criteria, see https://github.com/ory/x/pull/509; DO NOT EDIT.
-- hydra:generate hydra migrate gen
`, "\n"))

// mrx matches migration source filenames of the form
//
//	<14-digit ID>000000_<name>[.<dialect>].<up|down>.sql
//
// Capture groups: 1 = ID, 2 = name, 3 = optional ".<dialect>" (with the
// leading dot), 4 = direction. The literal "000000" is replaced by the
// fragment index in generated filenames (see fragmentName).
var mrx = regexp.MustCompile(`^(\d{14})000000_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.sql$`)

// migration describes one migration source file, as parsed from its filename
// by parseMigration.
type migration struct {
    // Path is the filename inside the source fs.FS; ReadSource opens it.
    Path      string
    // ID is the 14-digit prefix of the filename.
    ID        string
    // Name is the part between "000000_" and the dialect/direction suffix.
    Name      string
    // Dialect is the dialect segment without its leading dot, or
    // genericDialectKey when the filename has none.
    Dialect   string
    // Direction is either "up" or "down".
    Direction string
}

// migrationGroup collects all migration files that share the same ID.
type migrationGroup struct {
    // ID is the shared 14-digit migration ID.
    ID                    string
    // Name is taken from the first migration of the group that readMigrations
    // encountered; generated fragment filenames use it.
    Name                  string
    // Children are the dialect-specific migrations (any direction).
    Children              []*migration
    // fallbackUpMigration is the generic (dialect-less) up migration, if any.
    fallbackUpMigration   *migration
    // fallbackDownMigration is the generic (dialect-less) down migration, if any.
    fallbackDownMigration *migration
}

// ReadSource returns the complete contents of the migration file at m.Path
// inside fsys. Open and read failures are both wrapped with a stack trace.
//
// The parameter is named fsys rather than fs so it does not shadow the io/fs
// package inside this method.
func (m *migration) ReadSource(fsys fs.FS) ([]byte, error) {
    f, err := fsys.Open(m.Path)
    if err != nil {
        return nil, errors.WithStack(err)
    }
    // Read-only handle: a Close error cannot affect the data already read.
    defer f.Close()

    source, err := io.ReadAll(f)
    if err != nil {
        return nil, errors.WithStack(err)
    }
    return source, nil
}

// generateMigrationFragments splits source on every "--split" marker and
// returns the chunks, each prefixed with fragmentHeader.
//
// bytes.Split always returns at least one element, so a source without any
// marker yields a single fragment containing the whole file; the length check
// below is purely defensive.
func (m migration) generateMigrationFragments(source []byte) ([][]byte, error) {
    chunks := bytes.Split(source, []byte("--split"))
    if len(chunks) < 1 {
        return nil, errors.New("no migration chunks found")
    }

    fragments := make([][]byte, len(chunks))
    for i, chunk := range chunks {
        // Give every fragment its own buffer. Appending directly to the
        // shared package-level fragmentHeader may write into its spare
        // capacity, so small chunks would alias each other (and the header),
        // and a later chunk could silently rewrite an earlier fragment.
        fragment := make([]byte, 0, len(fragmentHeader)+len(chunk))
        fragment = append(fragment, fragmentHeader...)
        fragments[i] = append(fragment, chunk...)
    }
    return fragments, nil
}

// fragmentName returns the output filename of fragment i of migration m:
// the group ID followed by i zero-padded to six digits, then the group name,
// the dialect (omitted for generic migrations), and the direction.
func (mg migrationGroup) fragmentName(m *migration, i int) string {
    version := fmt.Sprintf("%s%06d", mg.ID, i)
    if m.Dialect != genericDialectKey {
        return fmt.Sprintf("%s_%s.%s.%s.sql", version, mg.Name, m.Dialect, m.Direction)
    }
    return fmt.Sprintf("%s_%s.%s.sql", version, mg.Name, m.Direction)
}

// generateSQL splits every migration file of the group into its "--split"
// fragments and writes them into the target directory.
//
// Afterwards each migration file (dialect-specific children as well as the
// generic up/down fallbacks) is padded with blankFragment files up to the
// largest fragment count found in the group, so every fragment index exists
// for every file of the group.
func (mg migrationGroup) generateSQL(sourceFS fs.FS, target string) error {
    // Build a fresh slice: appending the fallbacks directly to mg.Children
    // could write into spare capacity of the slice stored in the group.
    ms := make([]*migration, 0, len(mg.Children)+2)
    ms = append(ms, mg.Children...)
    if mg.fallbackDownMigration != nil {
        ms = append(ms, mg.fallbackDownMigration)
    }
    if mg.fallbackUpMigration != nil {
        ms = append(ms, mg.fallbackUpMigration)
    }

    // fragmentCounts[i] is the number of real fragments written for ms[i].
    // The count is tracked per file, not per dialect: the up and down files
    // of a dialect share a dialect key, so a per-dialect count could make
    // padding for one direction start inside the other direction's real
    // fragments and overwrite them with blanks.
    fragmentCounts := make([]int, len(ms))
    maxFragmentCount := 0
    for i, m := range ms {
        source, err := m.ReadSource(sourceFS)
        if err != nil {
            return errors.WithStack(err)
        }

        fragments, err := m.generateMigrationFragments(source)
        if err != nil {
            return errors.Wrapf(err, "failed to process %s", m.Path)
        }
        fragmentCounts[i] = len(fragments)
        if len(fragments) > maxFragmentCount {
            maxFragmentCount = len(fragments)
        }

        for j, fragment := range fragments {
            dst := filepath.Join(target, mg.fragmentName(m, j))
            if err := os.WriteFile(dst, fragment, 0600); err != nil {
                return errors.Wrapf(err, "failed to write file %s", dst)
            }
        }
    }

    // Pad every file with blank fragments up to the group's maximum.
    for i, m := range ms {
        for j := fragmentCounts[i]; j < maxFragmentCount; j++ {
            dst := filepath.Join(target, mg.fragmentName(m, j))
            if err := os.WriteFile(dst, blankFragment, 0600); err != nil {
                return errors.Wrapf(err, "failed to write file %s", dst)
            }
        }
    }
    return nil
}

// parseMigration parses a migration source filename of the form
// <14-digit ID>000000_<name>[.<dialect>].<up|down>.sql (see mrx).
// Filenames without a dialect segment get genericDialectKey as dialect.
func parseMigration(filename string) (*migration, error) {
    // mrx is anchored with ^...$, so there is at most one match.
    match := mrx.FindStringSubmatch(filename)
    if match == nil {
        return nil, errors.Errorf("failed to parse migration filename %s; it does not match pattern %s", filename, mrx.String())
    }
    // Defensive: mrx has four capture groups, so a match has five entries
    // (the full match plus the groups).
    if len(match) != 5 {
        return nil, errors.Errorf("invalid migration %s; expected %s", filename, mrx.String())
    }

    // Group 3 is either empty or ".<dialect>"; drop the leading dot.
    dialect := strings.TrimPrefix(match[3], ".")
    if dialect == "" {
        dialect = genericDialectKey
    }
    return &migration{
        Path:      filename,
        ID:        match[1],
        Name:      match[2],
        Dialect:   dialect,
        Direction: match[4],
    }, nil
}

// readMigrations walks migrationSourceFS, parses every top-level file name
// with parseMigration, and groups the migrations by ID.
//
// Generic (dialect-less) migrations become the group's up/down fallback; all
// others are appended to Children. Directories are skipped. Walk errors and
// files inside nested directories only print a warning to stdout. A filename
// that does not parse aborts the walk with that error.
//
// When expectedDialects is non-empty, every dialect-specific migration must
// use one of those dialects. Every expected dialect must also have an up and
// a down migration in each group, unless the group has a generic fallback
// for that direction.
func readMigrations(migrationSourceFS fs.FS, expectedDialects []string) (map[string]*migrationGroup, error) {
    groups := make(map[string]*migrationGroup)

    err := fs.WalkDir(migrationSourceFS, ".", func(p string, d fs.DirEntry, visitErr error) error {
        switch {
        case visitErr != nil:
            fmt.Println("Warning: unexpected error " + visitErr.Error())
            return nil
        case d.IsDir():
            return nil
        case filepath.Base(p) != p:
            fmt.Println("Warning: ignoring nested file " + p)
            return nil
        }

        m, parseErr := parseMigration(p)
        if parseErr != nil {
            return parseErr
        }

        group, found := groups[m.ID]
        if !found {
            group = &migrationGroup{ID: m.ID, Name: m.Name}
            groups[m.ID] = group
        }

        switch {
        case m.Dialect == genericDialectKey && m.Direction == "up":
            group.fallbackUpMigration = m
        case m.Dialect == genericDialectKey && m.Direction == "down":
            group.fallbackDownMigration = m
        default:
            group.Children = append(group.Children, m)
        }
        return nil
    })
    if err != nil {
        return nil, err
    }

    if len(expectedDialects) == 0 {
        return groups, nil
    }

    allowed := make(map[string]struct{}, len(expectedDialects))
    for _, dialect := range expectedDialects {
        allowed[dialect] = struct{}{}
    }

    for _, group := range groups {
        // present holds "<dialect>.<direction>" for every child of the group.
        present := make(map[string]struct{})
        for _, m := range group.Children {
            if _, ok := allowed[m.Dialect]; !ok {
                return nil, errors.Errorf("unexpected dialect %s in filename %s", m.Dialect, m.Path)
            }
            present[m.Dialect+"."+m.Direction] = struct{}{}
        }

        for _, dialect := range expectedDialects {
            if _, ok := present[dialect+".up"]; !ok && group.fallbackUpMigration == nil {
                return nil, errors.Errorf("dialect %s not found for up migration %s; use --dialects=\"\" to disable dialect validation", dialect, group.ID)
            }
            if _, ok := present[dialect+".down"]; !ok && group.fallbackDownMigration == nil {
                return nil, errors.Errorf("dialect %s not found for down migration %s; use --dialects=\"\" to disable dialect validation", dialect, group.ID)
            }
        }
    }

    return groups, nil
}

// MigrateGen reads the migration sources in args[0], optionally validates
// them against the dialects given via the "dialects" flag, and writes the
// split fragments for every migration group into args[1].
//
// The process exits with status 1 on the first error, after printing it to
// the command's error stream, and with status 0 on success.
func (h *MigrateHandler) MigrateGen(cmd *cobra.Command, args []string) {
    cmdx.ExactArgs(cmd, args, 2)
    expectedDialects := flagx.MustGetStringSlice(cmd, "dialects")

    sourceDir, targetDir := args[0], args[1]
    sourceFS := os.DirFS(sourceDir)

    mgs, err := readMigrations(sourceFS, expectedDialects)
    if err != nil {
        // Errors go to stderr, like in the other migrate handlers.
        _, _ = fmt.Fprintln(cmd.ErrOrStderr(), err.Error())
        os.Exit(1)
    }
    for _, mg := range mgs {
        if err := mg.generateSQL(sourceFS, targetDir); err != nil {
            _, _ = fmt.Fprintln(cmd.ErrOrStderr(), err.Error())
            os.Exit(1)
        }
    }

    os.Exit(0)
}

// makePersister creates a driver registry for migration commands and returns
// its persister.
//
// With the "read-from-env" flag the DSN is not taken from args, and an empty
// DSN in the resulting config is reported as an error. Otherwise args must
// hold exactly one element, the database URL, which is injected as the DSN.
// In both modes config validation, preloading and network initialization are
// disabled, and h.dOpts are appended after these built-in driver options.
func (h *MigrateHandler) makePersister(cmd *cobra.Command, args []string) (persistence.Persister, error) {
    readFromEnv := flagx.MustGetBool(cmd, "read-from-env")

    // The option order of each mode is kept exactly as before.
    var configOpts []configx.OptionModifier
    if readFromEnv {
        configOpts = []configx.OptionModifier{
            configx.SkipValidation(),
            configx.WithFlags(cmd.Flags()),
        }
    } else {
        if len(args) != 1 {
            _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Please provide the database URL.")
            return nil, cmdx.FailSilently(cmd)
        }
        configOpts = []configx.OptionModifier{
            configx.WithFlags(cmd.Flags()),
            configx.SkipValidation(),
            configx.WithValue(config.KeyDSN, args[0]),
        }
    }

    driverOpts := append([]driver.OptionsModifier{
        driver.WithOptions(configOpts...),
        driver.DisableValidation(),
        driver.DisablePreloading(),
        driver.SkipNetworkInit(),
    }, h.dOpts...)

    // NOTE(review): h.slOpts is not passed here; confirm whether intended.
    reg, err := driver.New(cmd.Context(), servicelocatorx.NewOptions(), driverOpts)
    if err != nil {
        return nil, err
    }

    if readFromEnv && len(reg.Config().DSN()) == 0 {
        _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "When using flag -e, environment variable DSN must be set.")
        return nil, cmdx.FailSilently(cmd)
    }
    return reg.Persister(), nil
}

// MigrateSQL applies all pending SQL migrations to the configured database.
//
// It opens the connection, prepares (converts) the migration table, and
// prints the planned migration status. Unless the "yes" flag is set, it asks
// for confirmation and then runs the up migrations. Failures are printed to
// the command's error stream and reported via cmdx.FailSilently.
func (h *MigrateHandler) MigrateSQL(cmd *cobra.Command, args []string) (err error) {
    p, err := h.makePersister(cmd, args)
    if err != nil {
        return err
    }

    // Use the command context (as MigrateStatus and makePersister do) so
    // that cancelling the command reaches the database calls.
    // cmd.Context() may be nil if the command was not started through
    // Execute/ExecuteContext, e.g. when called directly in tests.
    ctx := cmd.Context()
    if ctx == nil {
        ctx = context.Background()
    }

    conn := p.Connection(ctx)
    if conn == nil {
        _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Migrations can only be executed against a SQL-compatible driver but DSN is not a SQL source.")
        return cmdx.FailSilently(cmd)
    }

    if err := conn.Open(); err != nil {
        _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Could not open the database connection:\n%+v\n", err)
        return cmdx.FailSilently(cmd)
    }

    // convert migration tables
    if err := p.PrepareMigration(ctx); err != nil {
        _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Could not convert the migration table:\n%+v\n", err)
        return cmdx.FailSilently(cmd)
    }

    // print migration status
    _, _ = fmt.Fprintln(cmd.OutOrStdout(), "The following migration is planned:")

    status, err := p.MigrationStatus(ctx)
    if err != nil {
        _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Could not get the migration status:\n%+v\n", errorsx.WithStack(err))
        return cmdx.FailSilently(cmd)
    }
    // Write to the command's output stream like every other message here,
    // so redirected output (e.g. in tests) is captured consistently.
    _ = status.Write(cmd.OutOrStdout())

    if !flagx.MustGetBool(cmd, "yes") {
        _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "To skip the next question use flag --yes (at your own risk).")
        if !cmdx.AskForConfirmation("Do you wish to execute this migration plan?", nil, nil) {
            _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Migration aborted.")
            return nil
        }
    }

    // apply migrations
    if err := p.MigrateUp(ctx); err != nil {
        _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Could not apply migrations:\n%+v\n", errorsx.WithStack(err))
        return cmdx.FailSilently(cmd)
    }

    _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Successfully applied migrations!")
    return nil
}

// MigrateStatus prints the migration status table for the configured
// database.
//
// With the "block" flag it polls once per second while migrations are
// pending, listing the pending ones each time. Polling ends when none are
// pending or when the command context is cancelled; cancellation is reported
// as a failure.
func (h *MigrateHandler) MigrateStatus(cmd *cobra.Command, args []string) error {
    p, err := h.makePersister(cmd, args)
    if err != nil {
        return err
    }

    // cmd.Context() may be nil if the command was not started through
    // Execute/ExecuteContext; the select below needs a real context.
    ctx := cmd.Context()
    if ctx == nil {
        ctx = context.Background()
    }

    conn := p.Connection(ctx)
    if conn == nil {
        _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Migrations can only be checked against a SQL-compatible driver but DSN is not a SQL source.")
        return cmdx.FailSilently(cmd)
    }

    if err := conn.Open(); err != nil {
        _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Could not open the database connection:\n%+v\n", err)
        return cmdx.FailSilently(cmd)
    }

    block := flagx.MustGetBool(cmd, "block")
    s, err := p.MigrationStatus(ctx)
    if err != nil {
        _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Could not get migration status: %+v\n", err)
        return cmdx.FailSilently(cmd)
    }

    for block && s.HasPending() {
        _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Waiting for migrations to finish...\n")
        for _, m := range s {
            if m.State == popx.Pending {
                _, _ = fmt.Fprintf(cmd.OutOrStdout(), " - %s\n", m.Name)
            }
        }

        // Wait a second before polling again, but stop promptly when the
        // command context is cancelled instead of sleeping unconditionally.
        select {
        case <-ctx.Done():
            _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Stopped waiting for migrations: %+v\n", ctx.Err())
            return cmdx.FailSilently(cmd)
        case <-time.After(time.Second):
        }

        s, err = p.MigrationStatus(ctx)
        if err != nil {
            _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Could not get migration status: %+v\n", err)
            return cmdx.FailSilently(cmd)
        }
    }

    cmdx.PrintTable(cmd, s)
    return nil
}