vorteil/direktiv

View on GitHub
pkg/datastore/datastoresql/runtime_variables.go

Summary

Maintainability
B
6 hrs
Test Coverage
package datastoresql

import (
    "context"
    "errors"
    "fmt"
    "path"
    "regexp"
    "strings"

    "github.com/direktiv/direktiv/pkg/datastore"
    "github.com/google/uuid"
    "gorm.io/gorm"
)

type sqlRuntimeVariablesStore struct {
    db *gorm.DB
}

func (s *sqlRuntimeVariablesStore) GetForNamespace(ctx context.Context, namespace string, name string) (*datastore.RuntimeVariable, error) {
    variable := &datastore.RuntimeVariable{}

    if name == "" || namespace == "" {
        return nil, datastore.ErrNotFound
    }

    res := s.db.WithContext(ctx).Raw(`
                            SELECT 
                                id, namespace, workflow_path, instance_id, 
                                name, length(data) AS size, mime_type,
                                created_at, updated_at
                            FROM runtime_variables WHERE name = ? AND namespace = ? AND workflow_path IS NULL AND instance_id IS NULL`,
        name, namespace).First(variable)
    if errors.Is(res.Error, gorm.ErrRecordNotFound) {
        return nil, datastore.ErrNotFound
    }
    if res.Error != nil {
        return nil, res.Error
    }

    return variable, nil
}

func (s *sqlRuntimeVariablesStore) GetForInstance(ctx context.Context, instanceID uuid.UUID, name string) (*datastore.RuntimeVariable, error) {
    variable := &datastore.RuntimeVariable{}

    if name == "" || instanceID.String() == (uuid.UUID{}).String() {
        return nil, datastore.ErrNotFound
    }

    res := s.db.WithContext(ctx).Raw(`
                            SELECT 
                                id, namespace, workflow_path, instance_id, 
                                name, length(data) AS size, mime_type,
                                created_at, updated_at
                            FROM runtime_variables WHERE name = ? AND (instance_id=?)`,
        name, instanceID.String()).First(variable)
    if errors.Is(res.Error, gorm.ErrRecordNotFound) {
        return nil, datastore.ErrNotFound
    }
    if res.Error != nil {
        return nil, res.Error
    }

    return variable, nil
}

func (s *sqlRuntimeVariablesStore) GetForWorkflow(ctx context.Context, namespace string, workflowPath string, name string) (*datastore.RuntimeVariable, error) {
    variable := &datastore.RuntimeVariable{}

    if name == "" || workflowPath == "" || namespace == "" {
        return nil, datastore.ErrNotFound
    }

    workflowPath = path.Clean("/" + workflowPath)

    res := s.db.WithContext(ctx).Raw(`
                            SELECT 
                                id, namespace, workflow_path, instance_id, 
                                name, length(data) AS size, mime_type,
                                created_at, updated_at
                            FROM runtime_variables WHERE namespace = ? AND name = ? AND (workflow_path=?)`,
        namespace, name, workflowPath).First(variable)
    if errors.Is(res.Error, gorm.ErrRecordNotFound) {
        return nil, datastore.ErrNotFound
    }
    if res.Error != nil {
        return nil, res.Error
    }

    return variable, nil
}

func (s *sqlRuntimeVariablesStore) GetByID(ctx context.Context, id uuid.UUID) (*datastore.RuntimeVariable, error) {
    variable := &datastore.RuntimeVariable{}
    res := s.db.WithContext(ctx).Raw(`
                            SELECT 
                                id, namespace, workflow_path, instance_id, 
                                name, length(data) AS size, mime_type,
                                created_at, updated_at
                            FROM runtime_variables WHERE "id" = ?`,
        id).First(variable)
    if errors.Is(res.Error, gorm.ErrRecordNotFound) {
        return nil, datastore.ErrNotFound
    }
    if res.Error != nil {
        return nil, res.Error
    }

    return variable, nil
}

func (s *sqlRuntimeVariablesStore) listByFieldValue(ctx context.Context, fieldNames []string, fieldValues []interface{}) ([]*datastore.RuntimeVariable, error) {
    var variables []*datastore.RuntimeVariable

    conditions := make([]string, 0)
    vals := make([]interface{}, 0)

    for idx, fieldName := range fieldNames {
        if fieldValues[idx] != nil {
            conditions = append(conditions, fmt.Sprintf(`"%s" = ?`, fieldName))
            vals = append(vals, fieldValues[idx])
        } else {
            conditions = append(conditions, fmt.Sprintf(`"%s" IS NULL`, fieldName))
        }
    }

    aggregateConditions := strings.Join(conditions, " AND ")

    res := s.db.WithContext(ctx).Raw(fmt.Sprintf(`
                            SELECT 
                                id, namespace, workflow_path, instance_id, 
                                name, length(data) AS size, mime_type, 
                                created_at, updated_at
                            FROM runtime_variables WHERE %s ORDER BY created_at ASC `, aggregateConditions),
        vals...).Find(&variables)
    if res.Error != nil {
        return nil, res.Error
    }

    return variables, nil
}

func (s *sqlRuntimeVariablesStore) ListForInstance(ctx context.Context, instanceID uuid.UUID) ([]*datastore.RuntimeVariable, error) {
    return s.listByFieldValue(ctx, []string{"instance_id"}, []interface{}{instanceID.String()})
}

func (s *sqlRuntimeVariablesStore) ListForWorkflow(ctx context.Context, namespace string, workflowPath string) ([]*datastore.RuntimeVariable, error) {
    workflowPath = path.Clean("/" + workflowPath)

    return s.listByFieldValue(ctx, []string{"namespace", "workflow_path"}, []interface{}{namespace, workflowPath})
}

func (s *sqlRuntimeVariablesStore) ListForNamespace(ctx context.Context, namespace string) ([]*datastore.RuntimeVariable, error) {
    return s.listByFieldValue(ctx, []string{"namespace", "workflow_path", "instance_id"}, []interface{}{namespace, nil, nil})
}

func (s *sqlRuntimeVariablesStore) get(ctx context.Context, variable *datastore.RuntimeVariable) (*datastore.RuntimeVariable, error) {
    if variable.WorkflowPath != "" {
        return s.GetForWorkflow(ctx, variable.Namespace, variable.WorkflowPath, variable.Name)
    }

    zeroUUID := (uuid.UUID{}).String()

    if variable.InstanceID.String() != zeroUUID {
        return s.GetForInstance(ctx, variable.InstanceID, variable.Name)
    }

    return s.GetForNamespace(ctx, variable.Namespace, variable.Name)
}

// nolint:goconst
func (s *sqlRuntimeVariablesStore) Set(ctx context.Context, variable *datastore.RuntimeVariable) (*datastore.RuntimeVariable, error) {
    if variable.Name == "" {
        return nil, datastore.ErrInvalidRuntimeVariableName
    }
    if matched, _ := regexp.MatchString(datastore.RuntimeVariableNameRegexPattern, variable.Name); !matched {
        return nil, datastore.ErrInvalidRuntimeVariableName
    }
    zeroUUID := (uuid.UUID{}).String()
    if variable.InstanceID.String() == zeroUUID && variable.Namespace == "" && variable.WorkflowPath == "" {
        return nil, datastore.ErrInvalidRuntimeVariableName
    }

    if variable.WorkflowPath != "" {
        variable.WorkflowPath = path.Clean("/" + variable.WorkflowPath)
    }

    selectorField := ""

    var extra string
    args := []interface{}{
        variable.MimeType, variable.Data, variable.Namespace, variable.Name,
    }

    if variable.InstanceID.String() != zeroUUID {
        selectorField = "instance_id"
        extra = fmt.Sprintf("AND %s = ?", selectorField)
        args = append(args, variable.InstanceID.String())
    } else if variable.WorkflowPath != "" {
        selectorField = "workflow_path"
        extra = fmt.Sprintf("AND %s = ?", selectorField)
        args = append(args, variable.WorkflowPath)
    } else {
        extra = "AND workflow_path IS NULL AND instance_id IS NULL"
        // args = append(args, nil, nil)
    }

    queryString := fmt.Sprintf(
        `UPDATE runtime_variables SET
                        mime_type=?,
                        data=?,
                        updated_at=CURRENT_TIMESTAMP
                    WHERE namespace = ? AND name = ? %s;`, extra)

    res := s.db.WithContext(ctx).Exec(queryString, args...)

    // checks for duplicate key value violates unique constraint (SQLSTATE 23505)
    if res.Error != nil && strings.Contains(res.Error.Error(), "23505") {
        return nil, fmt.Errorf("%w + %w", res.Error, datastore.ErrDuplication)
    }
    if res.Error != nil {
        return nil, res.Error
    }
    if res.RowsAffected > 1 {
        return nil, fmt.Errorf("unexpected gorm update count, got: %d, want: %d", res.RowsAffected, 1)
    }
    if res.RowsAffected == 1 {
        return s.get(ctx, variable)
    }

    extraVal := ""

    if selectorField != "" {
        selectorField = ", " + selectorField
        extraVal = ", ?"
    }

    newUUID := uuid.New()
    args = append([]interface{}{newUUID}, args...)

    res = s.db.WithContext(ctx).Exec(fmt.Sprintf(`
                            INSERT INTO runtime_variables(
                                id, mime_type, data, namespace, name%s) 
                            VALUES(?, ?, ?, ?, ?%s);`, selectorField, extraVal),
        args...)

    // checks for duplicate key value violates unique constraint (SQLSTATE 23505)
    if res.Error != nil && strings.Contains(res.Error.Error(), "23505") {
        return nil, fmt.Errorf("%w + %w", res.Error, datastore.ErrDuplication)
    }
    if res.Error != nil {
        return nil, res.Error
    }
    if res.RowsAffected != 1 {
        return nil, fmt.Errorf("unexpected gorm insert count, got: %d, want: %d", res.RowsAffected, 1)
    }

    return s.GetByID(ctx, newUUID)
}

func (s *sqlRuntimeVariablesStore) Delete(ctx context.Context, id uuid.UUID) error {
    res := s.db.WithContext(ctx).Exec(
        `DELETE FROM runtime_variables WHERE id = ?`,
        id)
    if res.Error != nil {
        return res.Error
    }
    if res.RowsAffected == 0 {
        return datastore.ErrNotFound
    }
    if res.RowsAffected > 1 {
        return fmt.Errorf("unexpected runtime_variables delete count, got: %d, want: %d", res.RowsAffected, 1)
    }

    return nil
}

func (s *sqlRuntimeVariablesStore) LoadData(ctx context.Context, id uuid.UUID) ([]byte, error) {
    variable := &datastore.RuntimeVariable{}
    res := s.db.WithContext(ctx).Raw(`
                            SELECT 
                                id, namespace, workflow_path, instance_id, 
                                name, length(data) AS size, mime_type, data,
                                created_at, updated_at
                            FROM runtime_variables WHERE "id" = ?`,
        id).First(variable)
    if errors.Is(res.Error, gorm.ErrRecordNotFound) {
        return nil, datastore.ErrNotFound
    }
    if res.Error != nil {
        return nil, res.Error
    }

    return variable.Data, nil
}

func (s *sqlRuntimeVariablesStore) DeleteForWorkflow(ctx context.Context, namespace string, workflowPath string) error {
    workflowPath = path.Clean("/" + workflowPath)
    res := s.db.WithContext(ctx).Exec(
        `DELETE FROM runtime_variables WHERE namespace=? AND workflow_path=?`,
        namespace, workflowPath)
    if res.Error != nil {
        return res.Error
    }

    return nil
}

func (s *sqlRuntimeVariablesStore) SetWorkflowPath(ctx context.Context, namespace string, oldWorkflowPath string, newWorkflowPath string) error {
    oldWorkflowPath = path.Clean("/" + oldWorkflowPath)
    newWorkflowPath = path.Clean("/" + newWorkflowPath)

    res := s.db.WithContext(ctx).Exec(
        `UPDATE runtime_variables SET workflow_path=?, updated_at=CURRENT_TIMESTAMP WHERE namespace=? AND workflow_path=?`,
        newWorkflowPath, namespace, oldWorkflowPath)
    if res.Error != nil {
        return res.Error
    }

    return nil
}

func (s *sqlRuntimeVariablesStore) Create(ctx context.Context, variable *datastore.RuntimeVariable) (*datastore.RuntimeVariable, error) {
    if variable.Name == "" {
        return nil, datastore.ErrInvalidRuntimeVariableName
    }
    if matched, _ := regexp.MatchString(datastore.RuntimeVariableNameRegexPattern, variable.Name); !matched {
        return nil, datastore.ErrInvalidRuntimeVariableName
    }

    fields := "id, namespace, name, mime_type, data"
    holders := "?, ?, ?, ?, ?"
    newUUID := uuid.New()
    args := []any{newUUID, variable.Namespace, variable.Name, variable.MimeType, variable.Data}

    if variable.WorkflowPath != "" {
        variable.WorkflowPath = path.Clean("/" + variable.WorkflowPath)
    }
    if variable.WorkflowPath != "" {
        fields += ", workflow_path"
        holders += ", ?"
        args = append(args, variable.WorkflowPath)
    }
    if variable.InstanceID.String() != (uuid.UUID{}).String() && variable.WorkflowPath == "" {
        fields += ", instance_id"
        holders += ", ?"
        args = append(args, variable.InstanceID)
    }

    query := fmt.Sprintf(`
                INSERT INTO runtime_variables(%s)
                VALUES(%s)`, fields, holders)

    res := s.db.WithContext(ctx).Exec(query, args...)

    // checks for duplicate key value violates unique constraint (SQLSTATE 23505)
    if res.Error != nil && strings.Contains(res.Error.Error(), "23505") {
        return nil, fmt.Errorf("%w + %w", res.Error, datastore.ErrDuplication)
    }
    if res.Error != nil {
        return nil, res.Error
    }
    if res.RowsAffected != 1 {
        return nil, fmt.Errorf("unexpected gorm insert count, got: %d, want: %d", res.RowsAffected, 1)
    }

    return s.GetByID(ctx, newUUID)
}

// nolint:goconst
func (s *sqlRuntimeVariablesStore) Patch(ctx context.Context, id uuid.UUID, patch *datastore.RuntimeVariablePatch) (*datastore.RuntimeVariable, error) {
    if patch.Name != nil {
        if *patch.Name == "" {
            return nil, datastore.ErrInvalidRuntimeVariableName
        }
        if matched, _ := regexp.MatchString(datastore.RuntimeVariableNameRegexPattern, *patch.Name); !matched {
            return nil, datastore.ErrInvalidRuntimeVariableName
        }
    }

    fields := ""
    args := []any{}

    if patch.Name != nil {
        fields += ", name=?"
        args = append(args, *patch.Name)
    }
    if patch.MimeType != nil {
        fields += ", mime_type=?"
        args = append(args, *patch.MimeType)
    }
    if patch.Data != nil {
        fields += ", data=?"
        args = append(args, patch.Data)
    }
    args = append(args, id)
    fields = strings.Trim(fields, ",")

    query := fmt.Sprintf(`UPDATE runtime_variables SET
                        %s
                    WHERE id=?`, fields)

    res := s.db.WithContext(ctx).Exec(query, args...)

    // checks for duplicate key value violates unique constraint (SQLSTATE 23505)
    if res.Error != nil && strings.Contains(res.Error.Error(), "23505") {
        return nil, fmt.Errorf("%w + %w", res.Error, datastore.ErrDuplication)
    }
    if res.Error != nil {
        return nil, res.Error
    }
    if res.RowsAffected != 1 {
        return nil, fmt.Errorf("unexpected gorm update count, got: %d, want: %d", res.RowsAffected, 1)
    }

    return s.GetByID(ctx, id)
}