utils/filesystem/zip.go
package filesystem
import (
	"archive/zip"
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"unicode/utf8"

	"go.uber.org/atomic"
	"golang.org/x/text/encoding/unicode"

	"github.com/ARM-software/golang-utils/utils/charset"
	"github.com/ARM-software/golang-utils/utils/collection"
	"github.com/ARM-software/golang-utils/utils/commonerrors"
	"github.com/ARM-software/golang-utils/utils/parallelisation"
	"github.com/ARM-software/golang-utils/utils/safeio"
)
// File extensions and MIME types referenced by ZipFileExtensions and ZipMimeTypes.
const (
// Archive file extensions, each including its leading dot.
zipExt = ".zip"
sevenzipExt = ".7z"
sevenzipmacExt = ".s7z"
gzipExt = ".gz"
lzipExt = ".lz"
zipxExt = ".zipx"
targzExt = ".tar.gz"
targz2Ext = ".tgz"
xzExt = ".xz"
lzmaExt = ".lzma"
rzipExt = ".rz"
packExt = ".pack"
compressExt = ".z"
jarExt = ".jar"
// MIME types listed in ZipMimeTypes.
zipMimeType = "application/zip"
zipxMimeType = "application/x-zip"
zipCompressedMimeType = "application/x-zip-compressed"
jarMimeType = "application/jar"
epubMimeType = "application/epub+zip"
)
var (
// ZipFileExtensions is a list of commonly used extensions to describe zip archive files.
// This list was populated from [Wikipedia](https://en.wikipedia.org/wiki/List_of_archive_formats)
// NOTE(review): IsZipWithContext looks up filepath.Ext(path) in this list, and filepath.Ext only
// yields the last suffix (".gz" for "x.tar.gz"), so the ".tar.gz" entry is never matched as such.
ZipFileExtensions = []string{zipExt, zipxExt, sevenzipExt, sevenzipmacExt, gzipExt, targzExt, targz2Ext, xzExt, lzipExt, lzmaExt, rzipExt, packExt, compressExt, jarExt}
// ZipMimeTypes is a list of MIME types describing zip archive files.
ZipMimeTypes = []string{zipMimeType, zipxMimeType, zipCompressedMimeType, jarMimeType, epubMimeType}
)
// Zip zips a source directory into a destination archive.
// It delegates to the Zip method of the package-level globalFileSystem.
func Zip(source string, destination string) error {
return globalFileSystem.Zip(source, destination)
}
// Zip zips the source directory into the destination archive.
// It calls ZipWithContext with context.Background(), so the operation cannot be cancelled.
func (fs *VFS) Zip(source, destination string) error {
return fs.ZipWithContext(context.Background(), source, destination)
}
// ZipWithContext zips the source directory into the destination archive under the control of ctx.
// It calls ZipWithContextAndLimits with the limits returned by NoLimits().
func (fs *VFS) ZipWithContext(ctx context.Context, source, destination string) (err error) {
return fs.ZipWithContextAndLimits(ctx, source, destination, NoLimits())
}
// ZipWithContextAndLimits zips the source directory into the destination archive under the control
// of ctx and subject to limits.
// It calls ZipWithContextAndLimitsAndExclusionPatterns without any exclusion pattern.
func (fs *VFS) ZipWithContextAndLimits(ctx context.Context, source, destination string, limits ILimits) error {
return fs.ZipWithContextAndLimitsAndExclusionPatterns(ctx, source, destination, limits)
}
// ZipWithContextAndLimitsAndExclusionPatterns zips the source directory into the destination
// archive. The tree is traversed with fs.WalkWithContextAndExclusionPatterns, to which
// exclusionPatterns are passed unchanged.
//
// A nil limits is rejected with commonerrors.ErrUndefined. When limits.Apply() is true, a walked
// entry bigger than limits.GetMaxFileSize() aborts the walk with commonerrors.ErrTooLarge, and the
// same error is returned if the finished archive exceeds that maximum.
//
// Entry names are relative to source and always use forward slashes, as the zip format requires,
// whatever the host path separator is. The source directory itself gets no entry.
func (fs *VFS) ZipWithContextAndLimitsAndExclusionPatterns(ctx context.Context, source string, destination string, limits ILimits, exclusionPatterns ...string) (err error) {
	if limits == nil {
		err = fmt.Errorf("%w: missing file system limits", commonerrors.ErrUndefined)
		return
	}
	err = parallelisation.DetermineContextError(ctx)
	if err != nil {
		return
	}
	file, err := fs.CreateFile(destination)
	if err != nil {
		return
	}
	defer func() { _ = file.Close() }()

	// The writer is closed explicitly once the walk is over (see below) rather than in a defer,
	// so that its error can be reported and the archive is complete before its size is checked.
	w := zip.NewWriter(file)

	walker := func(path string, info os.FileInfo, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if limits.Apply() && info.Size() > limits.GetMaxFileSize() {
			return fmt.Errorf("%w: file [%v] is too big (%v B) and beyond limits (max: %v B)", commonerrors.ErrTooLarge, path, info.Size(), limits.GetMaxFileSize())
		}
		relPath, relErr := filepath.Rel(source, path)
		if relErr != nil {
			return relErr
		}
		header := &zip.FileHeader{
			// zip entry names must use forward slashes, even on hosts using another separator.
			Name:     filepath.ToSlash(relPath),
			Method:   zip.Deflate,
			Modified: info.ModTime(),
		}
		if info.IsDir() {
			if path == source {
				return nil
			}
			// A trailing slash marks the entry as a directory.
			header.Name += "/"
			_, createErr := w.CreateHeader(header)
			return createErr
		}

		src, openErr := fs.GenericOpen(path)
		if openErr != nil {
			return openErr
		}
		defer func() { _ = src.Close() }()

		dest, createErr := w.CreateHeader(header)
		if createErr != nil {
			return createErr
		}
		n, copyErr := safeio.CopyDataWithContext(ctx, src, dest)
		if copyErr != nil {
			return copyErr
		}
		// The size check is skipped for entries that IsSymLink reports as symbolic links.
		if info.Size() != n && !IsSymLink(info) {
			return fmt.Errorf("could not write the full file [%v] content (wrote %v/%v bytes)", relPath, n, info.Size())
		}
		return nil
	}

	err = fs.WalkWithContextAndExclusionPatterns(ctx, source, walker, exclusionPatterns...)

	// Closing the writer flushes buffered data and writes the central directory: without it the
	// archive is unreadable, so a failure here must not be ignored. A walk error takes precedence.
	closeErr := w.Close()
	if err == nil {
		err = closeErr
	}

	if limits.Apply() {
		stat, subErr := file.Stat()
		if subErr != nil {
			return subErr
		}
		if stat.Size() > limits.GetMaxFileSize() {
			return fmt.Errorf("%w: file [%v] is too big (%v B) and beyond limits (max: %v B)", commonerrors.ErrTooLarge, destination, stat.Size(), limits.GetMaxFileSize())
		}
	}
	return
}
// sanitiseZipExtractPath joins filePath (an entry name read from an archive) onto destination and
// rejects the result if it falls outside destination.
// This prevents any ZipSlip (files outside extraction dirPath) https://snyk.io/research/zip-slip-vulnerability#go
//
// Containment is decided from the path of destPath relative to destination rather than from a
// string prefix, so that destinations such as "." or a file system root, for which a cleaned child
// path does not start with "<destination><separator>", are handled correctly.
//
// The joined path is returned in destPath; err wraps commonerrors.ErrMalicious when the entry would
// be written outside destination. fs is no longer consulted and is only kept so that the signature
// is unchanged.
func sanitiseZipExtractPath(fs FS, filePath string, destination string) (destPath string, err error) {
	// Join cleans the result, resolving any ".." element present in filePath.
	destPath = filepath.Join(destination, filePath)

	relPath, relErr := filepath.Rel(destination, destPath)
	// A failure to relate the two paths (e.g. one absolute, the other relative) is treated as an
	// escape; otherwise the entry escapes only if the relative path climbs out through "..".
	escapes := relErr != nil ||
		relPath == ".." ||
		strings.HasPrefix(relPath, ".."+string(filepath.Separator))
	if escapes {
		err = fmt.Errorf("%w: zipslip security breach detected, file dirPath '%s' not in destination directory '%s'", commonerrors.ErrMalicious, filePath, destination)
	}
	return
}
// Unzip unzips a source archive file into destination.
// It delegates to the Unzip method of the package-level globalFileSystem and returns what that
// method returns.
func Unzip(source, destination string) ([]string, error) {
return globalFileSystem.Unzip(source, destination)
}
// Unzip unzips the source archive file into destination.
// It calls UnzipWithContext with context.Background(), so the operation cannot be cancelled.
func (fs *VFS) Unzip(source, destination string) ([]string, error) {
return fs.UnzipWithContext(context.Background(), source, destination)
}
// UnzipWithContext unzips the source archive file into destination under the control of ctx.
// Extraction is performed by unzip with the limits returned by NoLimits() and a starting depth of 0.
// Only the list of paths is returned; the count and size reported by unzip are discarded.
func (fs *VFS) UnzipWithContext(ctx context.Context, source string, destination string) (fileList []string, err error) {
fileList, _, _, err = fs.unzip(ctx, source, destination, NoLimits(), 0)
return
}
// UnzipWithContextAndLimits unzips a source archive file into destination, under the control of
// ctx and subject to limits.
// It delegates to the UnzipWithContextAndLimits method of the package-level globalFileSystem.
func UnzipWithContextAndLimits(ctx context.Context, source string, destination string, limits ILimits) ([]string, error) {
return globalFileSystem.UnzipWithContextAndLimits(ctx, source, destination, limits)
}
// UnzipWithContextAndLimits unzips the source archive file into destination, under the control of
// ctx and subject to limits.
// Extraction is performed by unzip with a starting depth of 0. Only the list of paths is returned;
// the count and size reported by unzip are discarded.
func (fs *VFS) UnzipWithContextAndLimits(ctx context.Context, source string, destination string, limits ILimits) (fileList []string, err error) {
fileList, _, _, err = fs.unzip(ctx, source, destination, limits, 0)
return
}
// newZipReader opens the archive found at source on fs and wraps it in a zip.Reader.
//
// Before anything is opened, it verifies that fs and limits are set, that currentDepth does not
// exceed limits.GetMaxDepth() (when limits apply and the maximum is not negative) and that source
// exists. Once opened, the archive is rejected if limits apply and its size, as reported by
// fs.Lstat, exceeds limits.GetMaxFileSize().
//
// The opened file is returned so that the caller can close it. It may be non-nil even when err is
// set (size limit exceeded, or invalid archive), so callers should close it in both cases.
// Errors raised by zip.NewReader are translated by convertZipError.
func newZipReader(fs FS, source string, limits ILimits, currentDepth int64) (zipReader *zip.Reader, file File, err error) {
	// Argument validation: cases are evaluated in order and the first failing one wins.
	switch {
	case fs == nil:
		err = fmt.Errorf("%w: missing file system", commonerrors.ErrUndefined)
	case limits == nil:
		err = fmt.Errorf("%w: missing file system limits", commonerrors.ErrUndefined)
	case limits.Apply() && limits.GetMaxDepth() >= 0 && currentDepth > limits.GetMaxDepth():
		err = fmt.Errorf("%w: depth [%v] of zip file [%v] is beyond allowed limits (max: %v)", commonerrors.ErrTooLarge, currentDepth, source, limits.GetMaxDepth())
	case !fs.Exists(source):
		err = fmt.Errorf("%w: could not find archive [%v]", commonerrors.ErrNotFound, source)
	}
	if err != nil {
		return
	}

	sourceInfo, err := fs.Lstat(source)
	if err != nil {
		return
	}
	if file, err = fs.GenericOpen(source); err != nil {
		return
	}

	archiveSize := sourceInfo.Size()
	if limits.Apply() && archiveSize > limits.GetMaxFileSize() {
		err = fmt.Errorf("%w: zip file [%v] is too big (%v B) and beyond limits (max: %v B)", commonerrors.ErrTooLarge, source, archiveSize, limits.GetMaxFileSize())
		return
	}

	reader, readErr := zip.NewReader(file, archiveSize)
	zipReader = reader
	err = convertZipError(readErr)
	return
}
// unzip extracts the archive `source` into the directory `destination` and, when
// limits.ApplyRecursively() is true, also extracts the zip files it finds among the entries
// (see unzipNestedZipFiles).
//
// currentDepth is the nesting depth of `source` (the exported entry points pass 0); it is added
// to the depth of each entry before comparing against limits.GetMaxDepth().
//
// It returns the list of extracted paths, the number of entries counted and the number of bytes
// accounted for. When a failure happens while iterating over the entries, the values accumulated
// so far are returned together with the error.
func (fs *VFS) unzip(ctx context.Context, source string, destination string, limits ILimits, currentDepth int64) (fileList []string, fileOnDiskCount uint64, sizeOnDisk uint64, err error) {
err = parallelisation.DetermineContextError(ctx)
if err != nil {
return
}
// Number of entries extracted so far (nested archives included).
fileCounter := atomic.NewUint64(0)
// Number of bytes accounted for so far (nested archives included).
totalSizeOnDisk := atomic.NewUint64(0)
zipReader, f, err := newZipReader(fs, source, limits, currentDepth)
// The close is registered before the error check because newZipReader may return an opened
// file together with an error.
defer func() {
if f != nil {
_ = f.Close()
}
}()
if err != nil {
return
}
// Clean the destination to find shortest dirPath
destination = filepath.Clean(destination)
err = fs.MkDir(destination)
if err != nil {
return
}
// Directory entries met during extraction; their timestamps are applied at the very end.
directoryInfo := map[string]os.FileInfo{}
// For each file in the zip file
for i := range zipReader.File {
zippedFile := zipReader.File[i]
subErr := parallelisation.DetermineContextError(ctx)
if subErr != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), subErr
}
// Calculate file dirPath (and reject entries which would land outside destination)
filePath, subErr := sanitiseZipExtractPath(fs, zippedFile.Name, destination)
if subErr != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), subErr
}
// fileDepth stays 0 when no depth limit is enforced.
var fileDepth int64
if limits.Apply() && limits.GetMaxDepth() >= 0 {
depth, subErr := FileTreeDepth(fs, destination, filePath)
// fileDepth is only used once subErr has been checked below.
fileDepth = depth + currentDepth
if subErr != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), subErr
}
if fileDepth > limits.GetMaxDepth() {
subErr = fmt.Errorf("%w: depth [%v] of file [%v] within zip [%v] is beyond allowed limits (max: %v)", commonerrors.ErrTooLarge, fileDepth, filepath.Base(filePath), filepath.Base(source), limits.GetMaxDepth())
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), subErr
}
}
// record unzipped files (except zip files if they get unzipped later)
// Directory entries are recorded here too, since this runs before the IsDir check.
if !(limits.ApplyRecursively() && fs.isZipWithContext(ctx, zippedFile.Name)) {
fileCounter.Inc()
fileList = append(fileList, filePath)
}
if zippedFile.FileInfo().IsDir() {
// Create directory
subErr = fs.MkDir(filePath)
if subErr != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), fmt.Errorf("unable to create directory [%s]: %w", filePath, subErr)
}
// recording directory dirInfo to preserve timestamps
directoryInfo[filePath] = zippedFile.FileInfo()
// Nothing more to do for a directory, move to next zip file
// (the total size / file count checks at the end of the loop body are skipped for this entry).
continue
}
// If a file create the dirPath into which to write the file
directoryPath := filepath.Dir(filePath)
subErr = fs.MkDir(directoryPath)
if subErr != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), fmt.Errorf("unable to create directory '%s': %w", directoryPath, subErr)
}
fileSizeOnDisk, subErr := fs.unzipZippedFile(ctx, filePath, zippedFile, limits, fileDepth)
if subErr != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), subErr
}
// If the copied file is a zip, unzip that zip if the action is marked as recursive
if limits.ApplyRecursively() {
// The check is now made on the extracted file, so its content is inspected and not only its name.
if fs.isZipWithContext(ctx, filePath) {
nestedUnzippedFiles, filesOnDiskCount, filesSizeOnDisk, subErr := fs.unzipNestedZipFiles(ctx, filePath, limits, fileDepth)
if subErr != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), subErr
}
// The nested archive is accounted for through its content, not through its own size.
totalSizeOnDisk.Add(filesSizeOnDisk)
fileCounter.Add(filesOnDiskCount)
fileList = append(fileList, nestedUnzippedFiles...)
} else {
if fs.isZipWithContext(ctx, zippedFile.Name) { // If not an actual zip file but with a zip name.
// Such an entry was skipped by the recording step above, so it is recorded now.
fileCounter.Inc()
fileList = append(fileList, filePath)
}
totalSizeOnDisk.Add(uint64(fileSizeOnDisk))
}
} else {
totalSizeOnDisk.Add(uint64(fileSizeOnDisk))
}
// Cumulative limits are checked after each extracted file, i.e. once the file is already on disk.
if limits.Apply() && totalSizeOnDisk.Load() > limits.GetMaxTotalSize() {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), fmt.Errorf("%w: more than %v B of disk space was used while unzipping %v (%v B used already)", commonerrors.ErrTooLarge, limits.GetMaxTotalSize(), source, totalSizeOnDisk.Load())
}
if limits.Apply() && int64(fileCounter.Load()) > limits.GetMaxFileCount() {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), fmt.Errorf("%w: more than %v files were created while unzipping %v (%v files created already)", commonerrors.ErrTooLarge, limits.GetMaxFileCount(), source, fileCounter.Load())
}
}
// Ensuring directory timestamps are preserved (this needs to be done after all the files have been created).
err = preserveDirectoriesTimestamps(ctx, fs, directoryInfo)
if err != nil {
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), err
}
return fileList, fileCounter.Load(), totalSizeOnDisk.Load(), nil
}
// unzipNestedZipFiles extracts the archive nestedZipFile, which was itself extracted from another
// archive, and then removes it.
//
// The content goes into a sibling directory named after the stem of nestedZipFile (as computed by
// FilepathStem), and the extraction runs at depth currentDepth+1. The paths, entry count and size
// reported by that extraction are returned as is, including when it fails; nestedZipFile is only
// removed after a successful extraction.
func (fs *VFS) unzipNestedZipFiles(ctx context.Context, nestedZipFile string, limits ILimits, currentDepth int64) (nestedUnzippedFiles []string, fileOnDiskCount uint64, filesSizeOnDisk uint64, err error) {
	parentDir := filepath.Dir(nestedZipFile)
	extractionDir := filepath.Join(parentDir, FilepathStem(nestedZipFile))

	var unzipErr error
	nestedUnzippedFiles, fileOnDiskCount, filesSizeOnDisk, unzipErr = fs.unzip(ctx, nestedZipFile, extractionDir, limits, currentDepth+1)
	if unzipErr != nil {
		err = fmt.Errorf("unable to unzip nested zip [%s] present at depth (%d) to [%s] : %w", filepath.Base(nestedZipFile), currentDepth, extractionDir, unzipErr)
		return
	}

	// The archive has been replaced by its content and is no longer needed on disk.
	if rmErr := fs.Rm(nestedZipFile); rmErr != nil {
		err = fmt.Errorf("unable to remove nested zip [%s] : %w", nestedZipFile, rmErr)
	}
	return
}
// preserveDirectoriesTimestamps applies to every directory listed in directoryInfo (keyed by
// directory path) the access and modification times derived from its recorded os.FileInfo.
//
// The context is checked before each directory; the first context or Chtimes error stops the
// iteration and is returned.
func preserveDirectoriesTimestamps(ctx context.Context, fs FS, directoryInfo map[string]os.FileInfo) error {
	for directory, info := range directoryInfo {
		if ctxErr := parallelisation.DetermineContextError(ctx); ctxErr != nil {
			return ctxErr
		}
		timestamps := newDefaultTimeInfo(info)
		if chErr := fs.Chtimes(directory, timestamps.AccessTime(), timestamps.ModTime()); chErr != nil {
			return fmt.Errorf("unable to set directory timestamp [%s]: %w", directory, chErr)
		}
	}
	return nil
}
// unzipZippedFile extracts a single archive entry (zippedFile) into the file dest.
//
// When limits.Apply() is true, the entry is rejected with commonerrors.ErrTooLarge if currentDepth
// exceeds limits.GetMaxDepth() (a negative maximum disables the check) or if the size declared
// for the entry exceeds limits.GetMaxFileSize(). Both checks are made before anything is created
// on disk, so a rejected entry does not leave an empty file behind.
//
// dest is converted to valid UTF-8 by determineUnzippedFilepath if necessary. At most the declared
// size is copied, and the timestamps derived from the entry are applied to the extracted file.
//
// fileSizeOnDisk is the size declared for the entry; it is only meaningful when err is nil.
func (fs *VFS) unzipZippedFile(ctx context.Context, dest string, zippedFile *zip.File, limits ILimits, currentDepth int64) (fileSizeOnDisk int64, err error) {
	err = parallelisation.DetermineContextError(ctx)
	if err != nil {
		return
	}
	// A maximum depth of 0 is a valid limit: only a negative value means "no depth limit", as in
	// newZipReader and unzip.
	if limits.Apply() && limits.GetMaxDepth() >= 0 && currentDepth > limits.GetMaxDepth() {
		err = fmt.Errorf("%w: depth [%v] of zipped file [%v] is beyond allowed limits (max: %v)", commonerrors.ErrTooLarge, currentDepth, zippedFile.Name, limits.GetMaxDepth())
		return
	}
	info := zippedFile.FileInfo()
	fileSizeOnDisk = info.Size()
	if limits.Apply() && fileSizeOnDisk > limits.GetMaxFileSize() {
		err = fmt.Errorf("%w: zipped file [%v] is too big (%v B) and above max size (%v B)", commonerrors.ErrTooLarge, info.Name(), info.Size(), limits.GetMaxFileSize())
		return
	}

	destinationPath, err := determineUnzippedFilepath(dest)
	if err != nil {
		return
	}
	destinationFile, err := fs.OpenFile(destinationPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zippedFile.Mode())
	if err != nil {
		err = fmt.Errorf("%w: unable to open file '%s': %v", commonerrors.ErrUnexpected, destinationPath, err.Error())
		return
	}
	// Safety net for the error paths below; on success the file is closed explicitly further down
	// and the error of this second close is discarded.
	defer func() { _ = destinationFile.Close() }()

	sourceFile, err := zippedFile.Open()
	if err != nil {
		err = fmt.Errorf("%w: unable to open zipped file '%s': %v", commonerrors.ErrUnsupported, zippedFile.Name, err.Error())
		return
	}
	defer func() { _ = sourceFile.Close() }()

	// The copy is bounded by the declared size so that an entry cannot write more than announced.
	_, err = safeio.CopyNWithContext(ctx, sourceFile, destinationFile, fileSizeOnDisk)
	if err != nil {
		err = fmt.Errorf("copy of zipped file to '%s' failed: %w", destinationPath, err)
		return
	}
	// The file is closed before its timestamps are set.
	err = destinationFile.Close()
	if err != nil {
		return
	}
	// Ensuring the timestamp is preserved.
	times := newDefaultTimeInfo(info)
	err = fs.Chtimes(destinationPath, times.AccessTime(), times.ModTime())
	return
}
// determineUnzippedFilepath returns a valid UTF-8 version of destinationPath.
//
// A path which is already valid UTF-8 is returned untouched. Otherwise its encoding is guessed
// with charset.DetectTextEncoding and the path is converted to UTF-8 with charset.IconvString.
// An error wrapping commonerrors.ErrInvalid is returned when the encoding cannot be detected, and
// one wrapping commonerrors.ErrUnexpected when the conversion fails.
//
// Background (see https://go-review.googlesource.com/c/go/+/75592/): character encodings other
// than CP-437 and UTF-8 are not officially supported by the ZIP specification but, pragmatically,
// the world has permitted use of them. When a non-standard encoding is used, it is the user's
// responsibility to ensure that the target system is expecting the encoding used (e.g., producing
// a ZIP file you know is used on a Chinese version of Windows).
func determineUnzippedFilepath(destinationPath string) (string, error) {
	if utf8.ValidString(destinationPath) {
		return destinationPath, nil
	}

	// Rather than failing outright on a non-UTF-8 path, guess its encoding and convert it.
	detectedEncoding, detectedCharset, detectErr := charset.DetectTextEncoding([]byte(destinationPath))
	if detectErr != nil {
		return "", fmt.Errorf("%w: file path [%s] is not a valid utf-8 string and charset could not be detected: %v", commonerrors.ErrInvalid, destinationPath, detectErr.Error())
	}

	utf8Path, convErr := charset.IconvString(destinationPath, detectedEncoding, unicode.UTF8)
	if convErr == nil {
		return utf8Path, nil
	}
	// If zip file paths must be accepted even when their encoding is unknown, or conversion to
	// utf-8 failed, then the following can be done instead of returning an error:
	// destinationPath = strings.ToValidUTF8(dest, charset.InvalidUTF8CharacterReplacement)
	return "", fmt.Errorf("%w: file path [%s] is encoded using charset [%v] but could not be converted to valid utf-8: %v", commonerrors.ErrUnexpected, destinationPath, detectedCharset, convErr.Error())
}
// IsZip reports whether path designates a zip archive.
// It calls isZipWithContext with context.Background(), so errors are not reported to the caller.
func (fs *VFS) IsZip(path string) bool {
return fs.isZipWithContext(context.Background(), path)
}
// isZipWithContext reports whether path designates a zip archive, as determined by
// IsZipWithContext. The error returned by IsZipWithContext is deliberately discarded and only
// the boolean result is kept.
func (fs *VFS) isZipWithContext(ctx context.Context, path string) bool {
found, _ := fs.IsZipWithContext(ctx, path)
return found
}
// IsZipWithContext reports whether path designates a zip archive.
//
// The decision is made in two steps:
//   - the lower-cased extension of path must be listed in ZipFileExtensions, otherwise ok is false;
//   - if path does not exist, the extension alone decides and ok is true; if it exists, the
//     beginning of its content is sniffed with http.DetectContentType and ok is true only when
//     the detected MIME type is listed in ZipMimeTypes.
//
// An error wrapping commonerrors.ErrUndefined is returned when path is empty. Context errors and
// errors raised while opening or reading the file are returned as is, with ok set to false.
func (fs *VFS) IsZipWithContext(ctx context.Context, path string) (ok bool, err error) {
	// http.DetectContentType considers at most the first 512 bytes of data, so there is no need
	// to load a potentially huge archive in memory to identify it.
	const sniffLength = 512

	if path == "" {
		err = fmt.Errorf("%w: missing path", commonerrors.ErrUndefined)
		return
	}
	err = parallelisation.DetermineContextError(ctx)
	if err != nil {
		return
	}
	_, found := collection.Find(&ZipFileExtensions, strings.ToLower(filepath.Ext(path)))
	if !found {
		return
	}
	if !fs.Exists(path) {
		// Nothing to inspect: trust the extension.
		ok = true
		return
	}
	f, err := fs.GenericOpen(path)
	if err != nil {
		return
	}
	defer func() { _ = f.Close() }()
	content, err := safeio.ReadAll(ctx, io.LimitReader(f, sniffLength))
	if err != nil {
		return
	}
	mime := http.DetectContentType(content)
	_, ok = collection.Find(&ZipMimeTypes, mime)
	return
}
// convertZipError translates errors raised by archive/zip into the error categories of
// commonerrors, after passing them through commonerrors.ConvertContextError:
//   - zip.ErrFormat and zip.ErrChecksum are reported as commonerrors.ErrInvalid;
//   - zip.ErrAlgorithm is reported as commonerrors.ErrUnsupported;
//   - anything else is returned as converted by commonerrors.ConvertContextError.
//
// A nil error yields nil. The original error is only kept as text in the translated error.
func convertZipError(err error) error {
	if err == nil {
		return nil
	}
	err = commonerrors.ConvertContextError(err)
	switch {
	case commonerrors.Any(err, zip.ErrFormat, zip.ErrChecksum):
		return fmt.Errorf("%w: %v", commonerrors.ErrInvalid, err.Error())
	case commonerrors.Any(err, zip.ErrAlgorithm):
		// zip.ErrFormat is not listed here: it is already handled by the previous case.
		return fmt.Errorf("%w: %v", commonerrors.ErrUnsupported, err.Error())
	default:
		return err
	}
}