
View on GitHub


0 mins
Test Coverage
/*
 * The Clammit application intercepts HTTP POST/PATCH/PUT requests, forwards any
 * "file" form-data elements to ClamAV and only forwards the request to the
 * application if ClamAV passes all of these elements as virus-free.
 */
package main

import (


// version is the version string reported in the "clammit_version" field of
// the /clammit info response. The default is "master".
//
// This is for GoReleaser:
// https://github.com/goreleaser/goreleaser#a-note-about-mainversion
var version = "master"

// Configuration structure, designed for gcfg
type Config struct {
    App ApplicationConfig `gcfg:"application"`

// ApplicationConfig holds the settings of the "application" configuration
// section. Each field can also be overridden by a CLAMMIT_* environment
// variable (see constructConfig).
type ApplicationConfig struct {
	// The address to listen on. This can be one of:
	// * tcp:host:port
	// * tcp:port
	// * unix:filename
	// * host:port
	// * :port
	// For example:
	//   Listen: tcp:8438
	//   Listen: unix:/tmp/clammit.sock
	//   Listen: :8438
	Listen string `gcfg:"listen"`
	// Socket file permissions (only used if listening on a unix socket), in
	// octal form.
	// For example:
	//   SocketPerms: 0766
	SocketPerms string `gcfg:"unix-socket-perms"`
	// The URL of the application that Clammit is proxying. Generally, this
	// will be the base URL (http://host:port/), but you can also add a path
	// prefix if needed (http://host:port/prefix).
	ApplicationURL string `gcfg:"application-url"`
	// The URL of clamd, which will either be TCP or Unix.
	// For example:
	//   ClamdURL: tcp://localhost:3310
	//   ClamdURL: unix:/tmp/clamd.sock
	ClamdURL string `gcfg:"clamd-url"`
	// The HTTP status code to return when a virus is found.
	VirusStatusCode int `gcfg:"virus-status-code"`
	// If the body content-length exceeds this value, it will be written to
	// disk. Below it, the whole body is held in memory to improve speed.
	ContentMemoryThreshold int64 `gcfg:"content-memory-threshold"`
	// Log file name (default is to log to stdout).
	Logfile string `gcfg:"log-file"`
	// If true, clammit will expose a small test HTML page.
	TestPages bool `gcfg:"test-pages"`
	// If true, will log the progression of each request through the
	// forwarder.
	Debug bool `gcfg:"debug"`
	// Number of CPU threads to use.
	NumThreads int `gcfg:"num-threads"`
}

// Default configuration
var DefaultApplicationConfig = ApplicationConfig{
    Listen:                 ":8438",
    SocketPerms:            "0777",
    ApplicationURL:         "",
    ClamdURL:               "",
    VirusStatusCode:        418,
    ContentMemoryThreshold: 1024 * 1024,
    Logfile:                "",
    TestPages:              true,
    Debug:                  false,
    NumThreads:             runtime.NumCPU(),

// Application context
type Ctx struct {
    Config          Config
    ApplicationURL  *url.URL
    ScanInterceptor *ScanInterceptor
    Scanner         scanner.Scanner
    Logger          *log.Logger
    Listener        net.Listener
    ActivityChan    chan int
    ShuttingDown    bool

// Info is the JSON server information response emitted by the /clammit
// endpoint.
type Info struct {
	// Version is the Clammit version.
	Version string `json:"clammit_version"`
	// Address is the address of the scan server.
	Address string `json:"scan_server_url"`
	// PingResult is the outcome of pinging the scan server (an error
	// message, or a confirmation that the connection works).
	PingResult string `json:"ping_result"`
	// ScannerVersion is the version reported by the scan server, or the
	// error met while asking for it.
	ScannerVersion string `json:"scan_server_version"`
	// TestScanVirusResult is the outcome of scanning a known viral payload.
	TestScanVirusResult string `json:"test_scan_virus"`
	// TestScanCleanResult is the outcome of scanning a harmless payload.
	TestScanCleanResult string `json:"test_scan_clean"`
}

// Package-level state shared by main() and the HTTP handlers.
var (
	// ctx is the application context; constructConfig() allocates it.
	ctx *Ctx
	// configFile is the path given with the -config flag ("" when absent).
	configFile string
)

func init() {
    flag.StringVar(&configFile, "config", "", "Configuration file")

func main() {
     * Construct configuration, set up logging

    // Socket perms are octal!
    socketPerms := 0777
    if ctx.Config.App.SocketPerms != "" {
        if sp, err := strconv.ParseInt(ctx.Config.App.SocketPerms, 8, 0); err == nil {
            socketPerms = int(sp)
        } else {
            log.Fatalf("SocketPerms invalid (expected 4-digit octal: %s", err.Error())

    // Allow multi-proc


     * Construct objects, validate the URLs
    ctx.ApplicationURL = checkURL(ctx.Config.App.ApplicationURL)

    ctx.Scanner = new(scanner.Clamav)
    ctx.Scanner.SetLogger(ctx.Logger, ctx.Config.App.Debug)

    ctx.ScanInterceptor = &ScanInterceptor{
        VirusStatusCode: ctx.Config.App.VirusStatusCode,
        Scanner:         ctx.Scanner,

     * Set up the HTTP server
    router := http.NewServeMux()

    router.HandleFunc("/clammit", infoHandler)
    router.HandleFunc("/clammit/scan", scanHandler)
    router.HandleFunc("/clammit/readyz", readyzHandler)

    if ctx.Config.App.TestPages {
        fs := http.FileServer(http.Dir("testfiles"))
        router.Handle("/clammit/test/", http.StripPrefix("/clammit/test/", fs))
    router.HandleFunc("/", scanForwardHandler)

    if listener, err := getListener(ctx.Config.App.Listen, socketPerms); err != nil {
        ctx.Logger.Fatal("Unable to listen on: ", ctx.Config.App.Listen, ", reason: ", err)
    } else {
        ctx.Listener = listener
        beGraceful() // graceful shutdown from here on in
        ctx.Logger.Println("Listening on", ctx.Config.App.Listen)
        http.Serve(listener, router)

/*
 * Returns the value of an environment variable, or a default value
 */
// getEnv returns the value of the environment variable key, or fallback
// when the variable is not set. A variable that is set to the empty string
// counts as set and yields "".
func getEnv(key, fallback string) string {
	if value, ok := os.LookupEnv(key); ok {
		return value
	}
	return fallback
}

/*
 * Returns the value of an environment variable cast to int, or a default value
 */
// getIntEnv returns the environment variable key parsed as a decimal int.
// It returns fallback when the variable is not set or does not parse.
func getIntEnv(key string, fallback int) int {
	if value, ok := os.LookupEnv(key); ok {
		if i, err := strconv.Atoi(value); err == nil {
			return i
		}
	}
	return fallback
}

/*
 * Returns the value of an environment variable cast to int64, or a default value
 */
// getInt64Env returns the environment variable key parsed as a base-10
// int64. It returns fallback when the variable is not set or does not parse.
func getInt64Env(key string, fallback int64) int64 {
	if value, ok := os.LookupEnv(key); ok {
		if i, err := strconv.ParseInt(value, 10, 64); err == nil {
			return i
		}
	}
	return fallback
}

/*
 * Returns the value of an environment variable cast to boolean, or a default value
 */
// getBoolEnv returns the environment variable key parsed with
// strconv.ParseBool. It returns fallback when the variable is not set or
// does not parse.
func getBoolEnv(key string, fallback bool) bool {
	if value, ok := os.LookupEnv(key); ok {
		if b, err := strconv.ParseBool(value); err == nil {
			return b
		}
	}
	return fallback
}

/*
 * Sets the configuration from the file and environment variables
 */
func constructConfig() {
    ctx = &Ctx{
        ActivityChan: make(chan int),
        ShuttingDown: false,

    ctx.Config.App = DefaultApplicationConfig

    // Read the configuration file if configfile is set
    if configFile != "" {
        if err := gcfg.ReadFileInto(&ctx.Config, configFile); err != nil {
            log.Fatalf("Configuration read failure: %s", err.Error())

    // Check for environmant variables to overwrite config
    ctx.Config.App.Listen = getEnv("CLAMMIT_LISTEN", ctx.Config.App.Listen)
    ctx.Config.App.SocketPerms = getEnv("CLAMMIT_SOCKET_PERMS", ctx.Config.App.SocketPerms)
    ctx.Config.App.ApplicationURL = getEnv("CLAMMIT_APPLICATION_URL", ctx.Config.App.ApplicationURL)
    ctx.Config.App.ClamdURL = getEnv("CLAMMIT_CLAMD_URL", ctx.Config.App.ClamdURL)
    ctx.Config.App.VirusStatusCode = getIntEnv("CLAMMIT_VIRUS_STATUS_CODE", ctx.Config.App.VirusStatusCode)
    ctx.Config.App.ContentMemoryThreshold = getInt64Env("CLAMMIT_CONTENT_MEMORY_THRESHOLD", ctx.Config.App.ContentMemoryThreshold)
    ctx.Config.App.Logfile = getEnv("CLAMMIT_LOGFILE", ctx.Config.App.Logfile)
    ctx.Config.App.TestPages = getBoolEnv("CLAMMIT_TEST_PAGES", ctx.Config.App.TestPages)
    ctx.Config.App.Debug = getBoolEnv("CLAMMIT_DEBUG", ctx.Config.App.Debug)
    ctx.Config.App.NumThreads = getIntEnv("CLAMMIT_NUM_THREADS", ctx.Config.App.NumThreads)

/*
 * Starts logging
 */
func startLogging() {
    if ctx.Config.App.Logfile != "" {
        w, err := os.OpenFile(ctx.Config.App.Logfile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0660)
        if err == nil {
            ctx.Logger = log.New(w, "", log.LstdFlags)
        } else {
            log.Fatal("Failed to open log file", ctx.Config.App.Logfile, ":", err)
    } else {
        ctx.Logger = log.New(os.Stdout, "", log.LstdFlags)
        ctx.Logger.Println("No log file configured - using stdout")

/*
 * Handles graceful shutdown. Sets ctx.ShuttingDown = true to stop any new
 * requests, then waits for active requests to complete before closing the
 * HTTP listener.
 */
func beGraceful() {
    sigchan := make(chan os.Signal)
    signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM)
    go func() {
        activity := 0
        for {
            select {
            case _ = <-sigchan:
                ctx.Logger.Println("Received termination signal")
                ctx.ShuttingDown = true
                for activity > 0 {
                    ctx.Logger.Printf("There are %d active requests, waiting", activity)
                    i := <-ctx.ActivityChan
                    activity += i
                // This will cause main() to continue from http.Serve()
                // it will also clean up the unix socket (if relevant)
            case i := <-ctx.ActivityChan:
                activity += i

/*
 * Validates the URL is OK (fatal error if not) and returns it
 */
// checkURL parses urlString and returns the result. A string that
// url.Parse rejects is a fatal error. Note that the empty string parses
// successfully, to an empty URL.
func checkURL(urlString string) *url.URL {
	parsedURL, err := url.Parse(urlString)
	if err != nil {
		// Formatted explicitly: log.Fatal adds no space between two string
		// operands, and the parse error itself was not reported.
		log.Fatalf("Invalid URL %q: %v", urlString, err)
	}
	return parsedURL
}

/*
 * Returns a TCP or Unix socket listener, according to the scheme prefix:
 *   unix:/tmp/foo.sock
 *   tcp::8438
 *   :8438                 - tcp listener
 */
// getListener returns a TCP or Unix socket listener for address. The text
// before the first ':' selects the kind of listener:
//
//	tcp:host:port, tcp:port     - TCP listener (likewise tcp4 and tcp6)
//	unix:/tmp/foo.sock          - Unix socket (likewise unixpacket)
//	host:port, :port            - no recognised scheme: TCP listener
//	port                        - no ':' at all: TCP listener on that port
//
// For Unix sockets the socket file is chmod'ed to socketPerms. If that
// fails, the listener is closed and the error returned rather than leaving
// a socket with unintended permissions. An empty address is an error.
func getListener(address string, socketPerms int) (listener net.Listener, err error) {
	if address == "" {
		return nil, fmt.Errorf("No listen address specified")
	}

	idx := strings.Index(address, ":")
	if idx < 0 {
		// no scheme, port only specified
		return net.Listen("tcp", ":"+address)
	}

	scheme, path := address[:idx], address[idx+1:]
	switch scheme {
	case "tcp", "tcp4":
		if !strings.Contains(path, ":") { // port only
			path = ":" + path
		}
		return net.Listen(scheme, path)
	case "tcp6": // general form: [host]:port
		// Anything not starting with '[' (a host) or ':' is a bare port.
		if !strings.HasPrefix(path, "[") && !strings.HasPrefix(path, ":") {
			path = ":" + path
		}
		return net.Listen(scheme, path)
	case "unix", "unixpacket":
		listener, err = net.Listen(scheme, path)
		if err != nil {
			return nil, err
		}
		if strings.HasPrefix(path, "@") {
			// Abstract socket: there is no file on disk to chmod.
			return listener, nil
		}
		if err = os.Chmod(path, os.FileMode(socketPerms)); err != nil {
			listener.Close()
			return nil, fmt.Errorf("setting permissions on socket %q: %w", path, err)
		}
		return listener, nil
	default: // assume TCP4 address
		return net.Listen("tcp", address)
	}
}

/*
 * Handler for /clammit/scan
 * Virus checks file and sends response
 */
func scanHandler(w http.ResponseWriter, req *http.Request) {
    if ctx.ShuttingDown {
    ctx.ActivityChan <- 1
    defer func() { ctx.ActivityChan <- -1 }()

    if !ctx.ScanInterceptor.Handle(w, req, req.Body) {
        w.Write([]byte("No virus found"))

/*
 * Handler for scan & forward
 * Constructs a forwarder and calls it
 */
func scanForwardHandler(w http.ResponseWriter, req *http.Request) {
    if ctx.ShuttingDown {
    ctx.ActivityChan <- 1
    defer func() { ctx.ActivityChan <- -1 }()

    fw := forwarder.NewForwarder(ctx.ApplicationURL, ctx.Config.App.ContentMemoryThreshold, ctx.ScanInterceptor)
    fw.SetLogger(ctx.Logger, ctx.Config.App.Debug)
    fw.HandleRequest(w, req)

/*
 * Handler for /clammit (server information)
 * Validates the Scanner connection
 * Emits the information as a JSON response
 */
func infoHandler(w http.ResponseWriter, req *http.Request) {
    if ctx.ShuttingDown {
    ctx.ActivityChan <- 1
    defer func() { ctx.ActivityChan <- -1 }()

    info := &Info{
        Address: ctx.Scanner.Address(),
        Version: version,
    if err := ctx.Scanner.Ping(); err != nil {
        info.PingResult = err.Error()
    } else {
        info.PingResult = "Connected to server OK"
        if response, err := ctx.Scanner.Version(); err != nil {
            info.ScannerVersion = err.Error()
        } else {
            info.ScannerVersion = response
         * Validate the Clamd response for a viral string
        reader := bytes.NewReader(EICAR)
        if result, err := ctx.Scanner.Scan(reader); err != nil {
            info.TestScanVirusResult = err.Error()
        } else {
            info.TestScanVirusResult = result.String()
         * Validate the Clamd response for a non-viral string
        reader = bytes.NewReader([]byte("foo bar mcgrew"))
        if result, err := ctx.Scanner.Scan(reader); err != nil {
            info.TestScanCleanResult = err.Error()
        } else {
            info.TestScanCleanResult = result.String()
    // Aaaand return
    s, _ := json.Marshal(info)
    w.Header().Set("Content-Type", "application/json")

/*
 * Handler for /clammit/readyz
 * Returns 200 OK unless we are shutting down. Used in k8s.
 * See https://github.com/ifad/clammit/issues/23
 */
func readyzHandler(w http.ResponseWriter, req *http.Request) {
    if ctx.ShuttingDown {
    } else {