main.go
/*
* The Clammit application intercepts HTTP POST/PATCH/PUT requests, forwards any
* "file" form-data elements to ClamAV and only forwards the request to the
* application if ClamAV passes all of these elements as virus-free.
*/
package main
import (
"bytes"
"clammit/forwarder"
"clammit/scanner"
"encoding/json"
"flag"
"fmt"
"log"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"gopkg.in/gcfg.v1"
)
// version is the Clammit build version, reported as "clammit_version" by the
// /clammit info endpoint. It defaults to "master"; release builds made with
// Go Releaser presumably override it at link time (-X main.version=...).
// See https://github.com/goreleaser/goreleaser#a-note-about-mainversion
var version = "master"
// Config is the top-level configuration structure, designed for gcfg: the
// [application] section of the configuration file is decoded into App.
type Config struct {
	App ApplicationConfig `gcfg:"application"`
}
// ApplicationConfig holds the settings of the [application] section of the
// configuration file. Built-in defaults live in DefaultApplicationConfig, and
// every field can be overridden by a CLAMMIT_* environment variable (see
// constructConfig).
type ApplicationConfig struct {
	// The address to listen on. This can be one of:
	//  * tcp:host:port
	//  * tcp:port
	//  * unix:filename
	//  * host:port
	//  * :port
	//  * port (no colon at all: that port on all interfaces)
	// The tcp4, tcp6 and unixpacket prefixes are recognised as well.
	//
	// For example:
	//  Listen: tcp:0.0.0.0:8438
	//  Listen: unix:/tmp/clammit.sock
	//  Listen: :8438
	Listen string `gcfg:"listen"`
	// Socket file permissions (only used if listening on a unix socket), in
	// octal form. A value that does not parse as octal is fatal at startup.
	//
	// For example:
	//  SocketPerms: 0766
	SocketPerms string `gcfg:"unix-socket-perms"`
	// The URL of the application that Clammit is proxying. Generally, this will
	// be the base URL (http://host:port/), but you can also add a path prefix
	// if needed (http://host:port/prefix)
	ApplicationURL string `gcfg:"application-url"`
	// The URL of clamd, which will either be TCP or Unix:
	//
	// For example:
	//  ClamdURL: tcp://localhost:3310
	//  ClamdURL: unix:/tmp/clamd.sock
	ClamdURL string `gcfg:"clamd-url"`
	// The HTTP status code to return when a virus is found
	VirusStatusCode int `gcfg:"virus-status-code"`
	// If the body content-length exceeds this value, it will be written to
	// disk. Below it, we'll hold the whole body in memory to improve speed.
	// (The value is handed to forwarder.NewForwarder, which does the buffering.)
	ContentMemoryThreshold int64 `gcfg:"content-memory-threshold"`
	// Log file name (default is to log to stdout). The file is opened in
	// append mode and created with mode 0660 (before umask) if missing.
	Logfile string `gcfg:"log-file"`
	// If true, clammit will expose a small test HTML page: the contents of
	// the local "testfiles" directory are served under /clammit/test/.
	TestPages bool `gcfg:"test-pages"`
	// If true, will log the progression of each request through the forwarder
	// (the same flag is also passed to the scanner's logger).
	Debug bool `gcfg:"debug"`
	// Number of CPU threads to use; passed to runtime.GOMAXPROCS at startup.
	NumThreads int `gcfg:"num-threads"`
}
// DefaultApplicationConfig holds the built-in defaults. constructConfig
// copies it into ctx.Config.App before applying the configuration file and
// the CLAMMIT_* environment overrides. ApplicationURL and ClamdURL default
// to empty. NumThreads is evaluated once, when the package is initialised.
var DefaultApplicationConfig = ApplicationConfig{
	Listen:                 ":8438",
	SocketPerms:            "0777",
	ApplicationURL:         "",
	ClamdURL:               "",
	VirusStatusCode:        418,
	ContentMemoryThreshold: 1024 * 1024,
	Logfile:                "",
	TestPages:              true,
	Debug:                  false,
	NumThreads:             runtime.NumCPU(),
}
// Ctx is the application context shared by main, the HTTP handlers and the
// graceful-shutdown goroutine started by beGraceful. A single instance is
// kept in the package variable ctx.
type Ctx struct {
	// Effective configuration (defaults, then config file, then env vars).
	Config Config
	// Parsed URL of the proxied application.
	ApplicationURL *url.URL
	// Interceptor built in main from the Scanner and VirusStatusCode.
	ScanInterceptor *ScanInterceptor
	// clamd client; main installs a *scanner.Clamav.
	Scanner scanner.Scanner
	// Application logger (log file or stdout, see startLogging).
	Logger *log.Logger
	// HTTP listener; closing it makes http.Serve in main return.
	Listener net.Listener
	// Handlers send +1 when they start and -1 when they finish, so that
	// shutdown can wait for in-flight requests.
	ActivityChan chan int
	// Set once SIGINT/SIGTERM is received; handlers then stop taking work.
	// NOTE(review): written by the shutdown goroutine and read by handler
	// goroutines without synchronisation (a data race) — consider atomic.Bool.
	ShuttingDown bool
}
// Info is the JSON body returned by infoHandler (the /clammit endpoint).
// Field order determines key order in the JSON output. The scan-related
// fields are only filled in when the scanner ping succeeds; any check that
// fails carries the error text instead of a result.
type Info struct {
	Version             string `json:"clammit_version"`     // Clammit build version
	Address             string `json:"scan_server_url"`     // scanner address
	PingResult          string `json:"ping_result"`         // ping outcome or error
	ScannerVersion      string `json:"scan_server_version"` // scanner version or error
	TestScanVirusResult string `json:"test_scan_virus"`     // result of scanning EICAR
	TestScanCleanResult string `json:"test_scan_clean"`     // result of scanning a harmless string
}
// Global variables and config

// ctx is the single application context, created by constructConfig.
var ctx *Ctx

// configFile is the path passed with the -config flag (empty when absent).
var configFile string

// EICAR is the standard EICAR anti-virus test string. infoHandler scans it
// to check that the scanner reports it as a virus.
var EICAR = []byte(`X5O!P%@AP[4\PZX54(P^)7CC)7}$EICAR-STANDARD-ANTIVIRUS-TEST-FILE!$H+H*`)
// init registers the -config command-line flag, which names an optional gcfg
// configuration file. The flags are parsed later, in constructConfig.
func init() {
	flag.StringVar(&configFile, "config", "", "Configuration file")
}
// main wires Clammit together and serves until shut down:
//
//  1. builds the configuration (defaults, -config file, CLAMMIT_* env vars);
//  2. parses the unix socket permissions and sets GOMAXPROCS;
//  3. starts logging and validates the application and clamd URLs;
//  4. creates the clamd scanner and the scan interceptor;
//  5. registers the HTTP routes and serves on the configured listener.
//
// Configuration and listen failures are fatal. http.Serve returns once the
// shutdown goroutine started by beGraceful closes the listener.
func main() {
	constructConfig()

	// Socket perms are octal!
	socketPerms := 0777
	if ctx.Config.App.SocketPerms != "" {
		sp, err := strconv.ParseInt(ctx.Config.App.SocketPerms, 8, 0)
		if err != nil {
			log.Fatalf("SocketPerms invalid (expected 4-digit octal): %s", err.Error())
		}
		socketPerms = int(sp)
	}

	// Allow multi-proc
	runtime.GOMAXPROCS(ctx.Config.App.NumThreads)

	startLogging()

	// Construct objects and validate the URLs. The clamd URL is only
	// validated here; the scanner is given it as a plain string.
	ctx.ApplicationURL = checkURL(ctx.Config.App.ApplicationURL)
	checkURL(ctx.Config.App.ClamdURL)
	ctx.Scanner = new(scanner.Clamav)
	ctx.Scanner.SetLogger(ctx.Logger, ctx.Config.App.Debug)
	ctx.Scanner.SetAddress(ctx.Config.App.ClamdURL)
	ctx.ScanInterceptor = &ScanInterceptor{
		VirusStatusCode: ctx.Config.App.VirusStatusCode,
		Scanner:         ctx.Scanner,
	}

	// Set up the HTTP server. Anything not matched by a /clammit route is
	// scanned and forwarded to the application.
	router := http.NewServeMux()
	router.HandleFunc("/clammit", infoHandler)
	router.HandleFunc("/clammit/scan", scanHandler)
	router.HandleFunc("/clammit/readyz", readyzHandler)
	if ctx.Config.App.TestPages {
		fs := http.FileServer(http.Dir("testfiles"))
		router.Handle("/clammit/test/", http.StripPrefix("/clammit/test/", fs))
	}
	router.HandleFunc("/", scanForwardHandler)

	listener, err := getListener(ctx.Config.App.Listen, socketPerms)
	if err != nil {
		ctx.Logger.Fatal("Unable to listen on: ", ctx.Config.App.Listen, ", reason: ", err)
	}
	ctx.Listener = listener
	beGraceful() // graceful shutdown from here on in
	ctx.Logger.Println("Listening on", ctx.Config.App.Listen)

	// http.Serve always returns a non-nil error, including when beGraceful
	// closes the listener during shutdown; log it instead of dropping it.
	err = http.Serve(listener, router)
	ctx.Logger.Println("HTTP server stopped:", err)
}
// getEnv returns the value of the environment variable key, or fallback when
// the variable is not set at all. A variable that is set to the empty string
// counts as set, so "" is returned in that case.
func getEnv(key, fallback string) string {
	value, present := os.LookupEnv(key)
	if !present {
		return fallback
	}
	return value
}
// getIntEnv returns the environment variable key parsed as a decimal int.
// It returns fallback when the variable is unset or does not parse.
func getIntEnv(key string, fallback int) int {
	result := fallback
	if raw, present := os.LookupEnv(key); present {
		if parsed, err := strconv.Atoi(raw); err == nil {
			result = parsed
		}
	}
	return result
}
// getInt64Env returns the environment variable key parsed as a base-10
// int64. It returns fallback when the variable is unset, malformed or out of
// range.
func getInt64Env(key string, fallback int64) int64 {
	raw, present := os.LookupEnv(key)
	if !present {
		return fallback
	}
	parsed, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		return fallback
	}
	return parsed
}
// getBoolEnv returns the environment variable key parsed with
// strconv.ParseBool ("1", "t", "true", "0", "f", "false", ...). It returns
// fallback when the variable is unset or not a recognised boolean.
func getBoolEnv(key string, fallback bool) bool {
	raw, present := os.LookupEnv(key)
	if !present {
		return fallback
	}
	if parsed, err := strconv.ParseBool(raw); err == nil {
		return parsed
	}
	return fallback
}
// constructConfig parses the command-line flags, creates the global ctx and
// fills ctx.Config.App in three layers: DefaultApplicationConfig, then the
// optional -config gcfg file (a read failure is fatal), then any CLAMMIT_*
// environment variables. Malformed numeric/boolean env values are ignored
// and the previous layer's value is kept.
func constructConfig() {
	flag.Parse()

	ctx = &Ctx{ActivityChan: make(chan int), ShuttingDown: false}
	ctx.Config.App = DefaultApplicationConfig

	if configFile != "" {
		err := gcfg.ReadFileInto(&ctx.Config, configFile)
		if err != nil {
			log.Fatalf("Configuration read failure: %s", err.Error())
		}
	}

	// Environment variables take precedence over the file and the defaults.
	app := &ctx.Config.App
	app.Listen = getEnv("CLAMMIT_LISTEN", app.Listen)
	app.SocketPerms = getEnv("CLAMMIT_SOCKET_PERMS", app.SocketPerms)
	app.ApplicationURL = getEnv("CLAMMIT_APPLICATION_URL", app.ApplicationURL)
	app.ClamdURL = getEnv("CLAMMIT_CLAMD_URL", app.ClamdURL)
	app.VirusStatusCode = getIntEnv("CLAMMIT_VIRUS_STATUS_CODE", app.VirusStatusCode)
	app.ContentMemoryThreshold = getInt64Env("CLAMMIT_CONTENT_MEMORY_THRESHOLD", app.ContentMemoryThreshold)
	app.Logfile = getEnv("CLAMMIT_LOGFILE", app.Logfile)
	app.TestPages = getBoolEnv("CLAMMIT_TEST_PAGES", app.TestPages)
	app.Debug = getBoolEnv("CLAMMIT_DEBUG", app.Debug)
	app.NumThreads = getIntEnv("CLAMMIT_NUM_THREADS", app.NumThreads)
}
// startLogging sets ctx.Logger. With no Logfile configured it logs to
// stdout. Otherwise it appends to Logfile, creating it with mode 0660
// (before umask) if needed. Failing to open the file is fatal. The file
// stays open for the life of the process.
func startLogging() {
	if ctx.Config.App.Logfile == "" {
		ctx.Logger = log.New(os.Stdout, "", log.LstdFlags)
		ctx.Logger.Println("No log file configured - using stdout")
		return
	}

	w, err := os.OpenFile(ctx.Config.App.Logfile, os.O_CREATE|os.O_WRONLY|os.O_APPEND, 0660)
	if err != nil {
		// Fatalf rather than Fatal: Fatal's Sprint semantics insert no
		// spaces between string operands, which garbled this message.
		log.Fatalf("Failed to open log file %s: %v", ctx.Config.App.Logfile, err)
	}
	ctx.Logger = log.New(w, "", log.LstdFlags)
}
// beGraceful installs SIGINT/SIGTERM handling. It also starts a goroutine
// that counts in-flight requests from ctx.ActivityChan: handlers send +1
// when they start and -1 when they finish.
//
// On a termination signal the goroutine:
//   - sets ctx.ShuttingDown so handlers stop taking new work;
//   - waits until the in-flight count reaches zero;
//   - closes ctx.Listener, which makes http.Serve in main return.
//
// The goroutine runs for the life of the process.
//
// NOTE(review): ctx.ShuttingDown is written here and read by handler
// goroutines without synchronisation; consider an atomic flag.
func beGraceful() {
	// signal.Notify never blocks when delivering, so the channel must be
	// buffered. Otherwise a signal that arrives while the goroutine is busy
	// (e.g. handling an ActivityChan message) is lost.
	sigchan := make(chan os.Signal, 1)
	signal.Notify(sigchan, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		activity := 0
		for {
			select {
			case <-sigchan:
				ctx.Logger.Println("Received termination signal")
				ctx.ShuttingDown = true
				for activity > 0 {
					ctx.Logger.Printf("There are %d active requests, waiting", activity)
					activity += <-ctx.ActivityChan
				}
				// This will cause main() to continue from http.Serve()
				// it will also clean up the unix socket (if relevant)
				ctx.Listener.Close()
			case delta := <-ctx.ActivityChan:
				activity += delta
			}
		}
	}()
}
// checkURL parses urlString and returns the result. If the string cannot be
// parsed, it terminates the process with a message that includes the
// rejected value and the parse error. Note that url.Parse is lenient: for
// example, the empty string parses successfully to an empty URL.
func checkURL(urlString string) *url.URL {
	parsedURL, err := url.Parse(urlString)
	if err != nil {
		log.Fatalf("Invalid URL %q: %v", urlString, err)
	}
	return parsedURL
}
// getListener returns a TCP or Unix socket listener for address. A scheme
// prefix is separated from the rest of the address by the first colon:
//
//	unix:/tmp/foo.sock        Unix stream socket, chmod'ed to socketPerms
//	unixpacket:/tmp/foo.sock  Unix seqpacket socket, chmod'ed to socketPerms
//	tcp:host:port, tcp:port   TCP listener ("tcp4" works the same way)
//	tcp6:[host]:port, tcp6:port
//	host:port, :port          any other prefix: a plain "tcp" address
//	8438                      no colon at all: that port on all interfaces
//
// socketPerms is only used by the unix schemes. An error is returned if the
// address is empty, the listen fails, or the socket permissions cannot be
// applied. In that last case the listener is closed first, so it does not
// leak.
func getListener(address string, socketPerms int) (listener net.Listener, err error) {
	if address == "" {
		return nil, fmt.Errorf("No listen address specified")
	}

	idx := strings.Index(address, ":")
	if idx < 0 {
		// No scheme, port only specified.
		return net.Listen("tcp", ":"+address)
	}

	scheme, rest := address[:idx], address[idx+1:]
	switch scheme {
	case "tcp", "tcp4":
		if !strings.Contains(rest, ":") { // port only
			rest = ":" + rest
		}
		return net.Listen(scheme, rest)
	case "tcp6": // general form: [host]:port
		if !strings.HasPrefix(rest, "[") && !strings.HasPrefix(rest, ":") { // port only
			rest = ":" + rest
		}
		return net.Listen(scheme, rest)
	case "unix", "unixpacket":
		if listener, err = net.Listen(scheme, rest); err != nil {
			return nil, err
		}
		// The permissions control who may connect, so a failure here must not
		// be ignored: the socket would be left with unintended permissions.
		if err = os.Chmod(rest, os.FileMode(socketPerms)); err != nil {
			listener.Close()
			return nil, fmt.Errorf("setting permissions on socket %s: %w", rest, err)
		}
		return listener, nil
	default:
		// Not a known scheme, so the colon belongs to a host:port or :port
		// address. "tcp" listens on IPv4 and/or IPv6 as available.
		return net.Listen("tcp", address)
	}
}
// scanHandler serves /clammit/scan. It scans the request body and, if the
// interceptor has not already responded, replies "No virus found".
// ScanInterceptor.Handle presumably returns true when it has written its own
// response (e.g. virus found or scan error) — confirm in ScanInterceptor.
//
// While shutting down it responds 503 Service Unavailable, matching
// readyzHandler. Writing nothing would send an empty 200 OK, which a client
// could take for a successful scan.
func scanHandler(w http.ResponseWriter, req *http.Request) {
	if ctx.ShuttingDown {
		w.WriteHeader(http.StatusServiceUnavailable)
		return
	}
	// Register as in-flight so graceful shutdown waits for this request.
	ctx.ActivityChan <- 1
	defer func() { ctx.ActivityChan <- -1 }()

	if !ctx.ScanInterceptor.Handle(w, req, req.Body) {
		w.Write([]byte("No virus found"))
	}
}
// scanForwardHandler handles every request not matched by a /clammit route.
// For each request it builds a forwarder from the application URL, the
// in-memory body threshold and the scan interceptor, then lets the
// forwarder handle the request (scan, then forward to the application).
//
// While shutting down it responds 503 Service Unavailable, matching
// readyzHandler. Writing nothing would send an empty 200 OK, so the client
// would see success for a request that was never forwarded.
func scanForwardHandler(w http.ResponseWriter, req *http.Request) {
	if ctx.ShuttingDown {
		w.WriteHeader(http.StatusServiceUnavailable)
		return
	}
	// Register as in-flight so graceful shutdown waits for this request.
	ctx.ActivityChan <- 1
	defer func() { ctx.ActivityChan <- -1 }()

	fw := forwarder.NewForwarder(ctx.ApplicationURL, ctx.Config.App.ContentMemoryThreshold, ctx.ScanInterceptor)
	fw.SetLogger(ctx.Logger, ctx.Config.App.Debug)
	fw.HandleRequest(w, req)
}
// infoHandler serves /clammit with a JSON Info document. The document
// always holds the Clammit version and the scanner address, then the
// scanner is pinged. Only if the ping succeeds does the handler also:
//   - fetch the scanner version;
//   - scan the EICAR test string (expected to be reported as a virus);
//   - scan a harmless string (expected to be clean).
//
// A failing check puts its error text in the corresponding field.
//
// While shutting down it responds 503 Service Unavailable, matching
// readyzHandler, instead of an empty 200 OK.
func infoHandler(w http.ResponseWriter, req *http.Request) {
	if ctx.ShuttingDown {
		w.WriteHeader(http.StatusServiceUnavailable)
		return
	}
	// Register as in-flight so graceful shutdown waits for this request.
	ctx.ActivityChan <- 1
	defer func() { ctx.ActivityChan <- -1 }()

	info := &Info{
		Address: ctx.Scanner.Address(),
		Version: version,
	}
	if err := ctx.Scanner.Ping(); err != nil {
		info.PingResult = err.Error()
	} else {
		info.PingResult = "Connected to server OK"
		if response, err := ctx.Scanner.Version(); err != nil {
			info.ScannerVersion = err.Error()
		} else {
			info.ScannerVersion = response
		}

		// Validate the Clamd response for a viral string.
		if result, err := ctx.Scanner.Scan(bytes.NewReader(EICAR)); err != nil {
			info.TestScanVirusResult = err.Error()
		} else {
			info.TestScanVirusResult = result.String()
		}

		// Validate the Clamd response for a non-viral string.
		if result, err := ctx.Scanner.Scan(bytes.NewReader([]byte("foo bar mcgrew"))); err != nil {
			info.TestScanCleanResult = err.Error()
		} else {
			info.TestScanCleanResult = result.String()
		}
	}

	body, err := json.Marshal(info)
	if err != nil {
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.Write(body)
}
// readyzHandler serves /clammit/readyz, a readiness probe (used in k8s; see
// https://github.com/ifad/clammit/issues/23). It answers 200 OK normally and
// 503 Service Unavailable once shutdown has begun. No body is written.
func readyzHandler(w http.ResponseWriter, req *http.Request) {
	status := http.StatusOK
	if ctx.ShuttingDown {
		status = http.StatusServiceUnavailable
	}
	w.WriteHeader(status)
}