Boostport/mjml-go

View on GitHub
mjml.go

Summary

Maintainability
A
0 mins
Test Coverage
C
71%
package mjml

import (
    "bytes"
    "context"
    _ "embed"
    "encoding/json"
    "errors"
    "fmt"
    "io"
    "sync"
    "time"

    "github.com/andybalholm/brotli"
    "github.com/jackc/puddle/v2"
    "github.com/tetratelabs/wazero"
    "github.com/tetratelabs/wazero/api"
    "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1"
)

//go:embed wasm/mjml.wasm.br
var wasm []byte

var (
    runtime      wazero.Runtime
    compiled     wazero.CompiledModule
    results      *sync.Map
    resourcePool *puddle.Pool[api.Module]
)

func init() {
    ctx := context.Background()

    results = &sync.Map{}

    br := brotli.NewReader(bytes.NewReader(wasm))
    decompressed, err := io.ReadAll(br)

    if err != nil {
        panic(fmt.Sprintf("Error decompressing wasm file: %s", err))
    }

    runtime = wazero.NewRuntime(ctx) // TODO: this should be closed

    if _, err := wasi_snapshot_preview1.Instantiate(ctx, runtime); err != nil {
        panic(fmt.Sprintf("Error instantiating wasi snapshot preview 1: %s", err))
    }

    err = registerHostFunctions(ctx, runtime)

    if err != nil {
        panic(fmt.Sprintf("Error registering host functions: %s", err))
    }

    compiled, err = runtime.CompileModule(ctx, decompressed)

    if err != nil {
        panic(fmt.Sprintf("Error compiling wasm module: %s", err))
    }

    resourcePool, err = newResourcePool(10)

    if err != nil {
        panic(fmt.Sprintf("Error creating resource pool: %s", err))
    }

    go periodicallyRemoveIdleResources(resourcePool)
}

func SetMaxWorkers(maxSize int32) error {
    oldPool := resourcePool

    newPool, err := newResourcePool(maxSize)

    if err != nil {
        return fmt.Errorf("error creating new resource pool: %w", err)
    }

    resourcePool = newPool
    oldPool.Close()

    return nil
}

type jsonResult struct {
    HTML  string `json:"html"`
    Error *Error `json:"error,omitempty"`
}

// ToHTML converts a string containing mjml to HTML while using any of the optionally provided options
func ToHTML(ctx context.Context, mjml string, toHTMLOptions ...ToHTMLOption) (string, error) {
    data := map[string]interface{}{
        "mjml": mjml,
    }

    o := options{
        data: map[string]interface{}{},
    }

    for _, opt := range toHTMLOptions {
        opt(o)
    }

    if len(o.data) > 0 {
        data["options"] = o.data
    }

    inputBytes := bytes.NewBuffer([]byte{})

    encoder := json.NewEncoder(inputBytes)
    encoder.SetEscapeHTML(false)

    err := encoder.Encode(data)

    if err != nil {
        return "", fmt.Errorf("error encoding input data: %w", err)
    }

    jsonInput := inputBytes.String()
    jsonInputLen := uint64(len(jsonInput))

    var (
        module *puddle.Resource[api.Module]
        tries  int
    )

    for {
        tries++

        var err error

        module, err = resourcePool.Acquire(ctx)

        if err != nil {

            if tries >= 30 {
                return "", fmt.Errorf("unable to accquire wasm module after 30 tries: %w", err)
            }

            if err == puddle.ErrClosedPool {
                time.Sleep(1 * time.Millisecond)
                continue
            }

            return "", fmt.Errorf("error accquiring wasm module: %w", err)
        }

        break
    }

    defer module.Release()

    mod, ok := module.Value().(api.Module)

    if !ok {
        return "", errors.New("pool resource is not an api.Module")
    }

    deallocate := mod.ExportedFunction("deallocate")
    allocate := mod.ExportedFunction("allocate")
    run := mod.ExportedFunction("run_e")
    memory := mod.Memory()

    allocation, err := allocate.Call(ctx, jsonInputLen)

    if err != nil {
        return "", fmt.Errorf("error allocating memory: %w", err)
    }

    if len(allocation) != 1 {
        return "", errors.New("invalid input pointer allocated")
    }

    inputPtr := allocation[0]

    defer deallocate.Call(ctx, inputPtr)

    if !memory.Write(uint32(inputPtr), []byte(jsonInput)) {
        return "", fmt.Errorf("error writing input to memory: %w", err)
    }

    ident, err := randomIdentifier()

    if err != nil {
        return "", fmt.Errorf("error generating identifier: %w", err)
    }

    resultCh := make(chan []byte, 1)

    results.Store(ident, resultCh)

    defer results.Delete(ident)

    _, err = run.Call(ctx, inputPtr, jsonInputLen, uint64(ident))

    if err != nil {
        return "", fmt.Errorf("error calling run: %w", err)
    }

    result := <-resultCh

    res := jsonResult{}

    err = json.Unmarshal(result, &res)

    if err != nil {
        return "", fmt.Errorf("error decoding result json: %w", err)
    }

    if res.Error != nil {
        return "", *res.Error
    }

    return res.HTML, nil
}

func registerHostFunctions(ctx context.Context, r wazero.Runtime) error {
    _, err := r.NewHostModuleBuilder("env").
        NewFunctionBuilder().
        WithFunc(returnResult).
        WithParameterNames("ptr", "len", "ident").
        Export("return_result").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32) uint32 {
            panic("get_static_file is unimplemented")
        }).
        Export("get_static_file").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32, _ uint32, _ uint32) uint32 {
            panic("request_set_field is unimplemented")
        }).
        Export("request_set_field").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32, _ uint32) {
            panic("resp_set_header is unimplemented")
        }).
        Export("resp_set_header").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32) uint32 {
            panic("cache_get is unimplemented")
        }).
        Export("cache_get").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32, _ uint32) uint32 {
            panic("add_ffi_var is unimplemented")
        }).
        Export("add_ffi_var").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32) uint32 {
            panic("get_ffi_result is unimplemented")
        }).
        Export("get_ffi_result").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32) {
            panic("return_error is unimplemented")
        }).
        Export("return_error").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32, _ uint32, _ uint32) uint32 {
            panic("fetch_url is unimplemented")
        }).
        Export("fetch_url").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32, _ uint32) uint32 {
            panic("graphql_query is unimplemented")
        }).
        Export("graphql_query").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32) uint32 {
            panic("db_exec is unimplemented")
        }).
        Export("db_exec").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32, _ uint32, _ uint32) uint32 {
            panic("cache_set is unimplemented")
        }).
        Export("cache_set").
        NewFunctionBuilder().
        WithFunc(func(_ uint32, _ uint32, _ uint32, _ uint32) uint32 {
            panic("request_get_field is unimplemented")
        }).
        Export("request_get_field").
        NewFunctionBuilder().
        WithFunc(func(ctx context.Context, m api.Module, ptr uint32, size uint32, level uint32, ident uint32) {
            panic("log_msg is unimplemented")
        }).
        Export("log_msg").
        Instantiate(ctx)

    return err
}

// returnResult is defined with a reflective signature instead of
// api.GoModuleFunc because it isn't called frequently.
func returnResult(ctx context.Context, m api.Module, ptr uint32, len uint32, ident uint32) {
    if ch, ok := results.Load(int32(ident)); ok {
        result, ok := m.Memory().Read(ptr, len)

        resultCh, isResultCh := ch.(chan []byte)

        if ok && isResultCh {
            resultCh <- result
        }
    }
}