piotrkowalczuk/mnemosyne

View on GitHub
mnemosyned/middleware.go

Summary

Maintainability
A
45 mins
Test Coverage
package mnemosyned

import (
    "database/sql"
    "time"

    "github.com/lib/pq"
    "github.com/piotrkowalczuk/mnemosyne"
    "github.com/piotrkowalczuk/mnemosyne/internal/service/logger"
    "github.com/piotrkowalczuk/mnemosyne/internal/storage"
    "go.uber.org/zap"
    "go.uber.org/zap/zapcore"
    "golang.org/x/net/context"
    "google.golang.org/grpc"
    "google.golang.org/grpc/codes"
    "google.golang.org/grpc/metadata"
    "google.golang.org/grpc/status"
)

// unaryServerInterceptors composes interceptors into a single
// grpc.UnaryServerInterceptor.
//
// For every call, the handler is wrapped by each interceptor in argument
// order. As a result the LAST interceptor in the list is the outermost one: it
// runs first and sees the final result last. The first interceptor is the one
// that finally invokes handler. With no interceptors the handler is called
// directly. All interceptors receive the same info value.
func unaryServerInterceptors(interceptors ...grpc.UnaryServerInterceptor) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		// link turns one interceptor plus the handler chain below it into a
		// new handler.
		link := func(interceptor grpc.UnaryServerInterceptor, inner grpc.UnaryHandler) grpc.UnaryHandler {
			return func(c context.Context, r interface{}) (interface{}, error) {
				return interceptor(c, r, info, inner)
			}
		}

		h := handler
		for idx := range interceptors {
			h = link(interceptors[idx], h)
		}
		return h(ctx, req)
	}
}

// unaryClientInterceptors composes interceptors into a single
// grpc.UnaryClientInterceptor.
//
// For every call, the invoker is wrapped by each interceptor in argument
// order. As a result the LAST interceptor in the list is the outermost one and
// is invoked first. The first interceptor is the one that finally calls
// invoker. With no interceptors the invoker is called directly.
func unaryClientInterceptors(interceptors ...grpc.UnaryClientInterceptor) grpc.UnaryClientInterceptor {
	return func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
		// link turns one interceptor plus the invoker chain below it into a
		// new invoker.
		link := func(interceptor grpc.UnaryClientInterceptor, inner grpc.UnaryInvoker) grpc.UnaryInvoker {
			return func(c context.Context, m string, rq, rp interface{}, conn *grpc.ClientConn, o ...grpc.CallOption) error {
				return interceptor(c, m, rq, rp, conn, inner, o...)
			}
		}

		call := invoker
		for idx := range interceptors {
			call = link(interceptors[idx], call)
		}
		return call(ctx, method, req, reply, cc, opts...)
	}
}

// errorInterceptor returns a unary server interceptor that logs the outcome of
// every request and normalizes handler errors into gRPC status errors.
//
// Before calling the handler it copies the access token and request id from
// the incoming metadata into outgoing metadata on ctx. Any gRPC client calls
// made by the handler with that context then carry them.
//
// Error translation (the logged code is always the code returned to the
// client):
//   - errMissingAccessToken, errMissingSession, errMissingSubjectID are
//     returned unchanged;
//   - storage.ErrSessionNotFound becomes codes.NotFound;
//   - storage.ErrMissingAccessToken, storage.ErrMissingSession and
//     storage.ErrMissingSubjectID become codes.InvalidArgument;
//   - errors that already carry a gRPC status keep their code;
//   - any other plain error is mapped: sql.ErrNoRows -> NotFound, PostgreSQL
//     error 57014 (query_canceled) -> Canceled, any other *pq.Error ->
//     Unknown, everything else -> Internal.
//
// All translated errors get a "mnemosyned: " message prefix.
func errorInterceptor(log *zap.Logger) func(context.Context, interface{}, *grpc.UnaryServerInfo, grpc.UnaryHandler) (interface{}, error) {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		start := time.Now()

		if md, ok := metadata.FromIncomingContext(ctx); ok {
			ctx = metadata.NewOutgoingContext(ctx, metadata.MD{
				mnemosyne.AccessTokenMetadataKey: md[mnemosyne.AccessTokenMetadataKey],
				"request_id":                     md["request_id"],
			})
		}

		res, err := handler(ctx, req)

		code := status.Code(err)
		// Treat an error that reports codes.OK as success, as the
		// original check did.
		if err == nil || code == codes.OK {
			loggerBackground(ctx, log).Debug("request handled successfully",
				logger.Ctx(ctx, info, codes.OK),
				zap.Duration("elapsed", time.Since(start)),
			)
			return res, err
		}

		// Decide the final code and error first, so the log below reports
		// exactly what the client receives.
		var out error
		switch err {
		case errMissingAccessToken, errMissingSession, errMissingSubjectID:
			out = err
		case storage.ErrSessionNotFound:
			code = codes.NotFound
			out = status.Errorf(code, "mnemosyned: %s", err.Error())
		case storage.ErrMissingAccessToken, storage.ErrMissingSession, storage.ErrMissingSubjectID:
			code = codes.InvalidArgument
			out = status.Errorf(code, "mnemosyned: %s", err.Error())
		default:
			// Only translate plain errors. Errors that already carry a gRPC
			// status (including an explicit codes.Unknown) keep their code.
			if _, isStatus := status.FromError(err); !isStatus {
				switch pqErr, isPQ := err.(*pq.Error); {
				case err == sql.ErrNoRows:
					code = codes.NotFound
				case isPQ && pqErr.Code == pq.ErrorCode("57014"):
					code = codes.Canceled
				case isPQ:
					code = codes.Unknown
				default:
					code = codes.Internal
				}
			}
			out = status.Errorf(code, "mnemosyned: %s", status.Convert(err).Message())
		}

		loggerBackground(ctx, log).Error("request failure",
			zap.String("error", status.Convert(err).Message()),
			logger.Ctx(ctx, info, code),
			zap.Duration("elapsed", time.Since(start)),
		)
		return nil, out
	}
}

// loggerBackground derives a logger from log with the given fields attached.
// If the incoming gRPC metadata on ctx has at least one "request_id" value,
// the first such value is also attached as a "request_id" field. Otherwise
// only the given fields are added.
func loggerBackground(ctx context.Context, log *zap.Logger, fields ...zapcore.Field) *zap.Logger {
	base := log.With(fields...)

	md, found := metadata.FromIncomingContext(ctx)
	if !found {
		return base
	}

	ids := md["request_id"]
	if len(ids) == 0 {
		return base
	}
	return base.With(zap.String("request_id", ids[0]))
}