ekristen/aws-nuke

View on GitHub
pkg/awsmod/batch.go

Summary

Maintainability
A
1 hr
Test Coverage
package awsmod

import (
    "bytes"
    "context"
    "fmt"

    "github.com/aws/aws-sdk-go-v2/aws"
    "github.com/aws/aws-sdk-go-v2/service/s3"
    s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
    "github.com/aws/aws-sdk-go/aws/awserr"
)

const (
    // DefaultBatchSize is the batch size we initialize when constructing a batch delete client.
    // This value is used when calling DeleteObjects. This represents how many objects to delete
    // per DeleteObjects call.
    DefaultBatchSize = 100
)

// BatchError will contain the key and bucket of the object that failed to
// either upload or download.
type BatchError struct {
    Errors  Errors
    code    string
    message string
}

// Errors is a typed alias for a slice of errors to satisfy the error
// interface.
type Errors []Error

func (errs Errors) Error() string {
    buf := bytes.NewBuffer(nil)
    for i, err := range errs {
        buf.WriteString(err.Error())
        if i+1 < len(errs) {
            buf.WriteString("\n")
        }
    }
    return buf.String()
}

// Error will contain the original error, bucket, and key of the operation that failed
// during batch operations.
type Error struct {
    OrigErr error
    Bucket  *string
    Key     *string
}

func newError(err error, bucket, key *string) Error {
    return Error{
        err,
        bucket,
        key,
    }
}

func (err *Error) Error() string {
    origErr := ""
    if err.OrigErr != nil {
        origErr = ":\n" + err.OrigErr.Error()
    }
    return fmt.Sprintf("failed to perform batch operation on %q to %q%s",
        aws.ToString(err.Key),
        aws.ToString(err.Bucket),
        origErr,
    )
}

// NewBatchError will return a BatchError that satisfies the awserr.Error interface.
func NewBatchError(code, message string, err []Error) awserr.Error {
    return &BatchError{
        Errors:  err,
        code:    code,
        message: message,
    }
}

// Code will return the code associated with the batch error.
func (err *BatchError) Code() string {
    return err.code
}

// Message will return the message associated with the batch error.
func (err *BatchError) Message() string {
    return err.message
}

func (err *BatchError) Error() string {
    return awserr.SprintError(err.Code(), err.Message(), "", err.Errors)
}

// OrigErr will return the original error. Which, in this case, will always be nil
// for batched operations.
func (err *BatchError) OrigErr() error {
    return err.Errors
}

// BatchDeleteIterator is an interface that uses the scanner pattern to
// iterate through what needs to be deleted.
type BatchDeleteIterator interface {
    Next() bool
    Err() error
    DeleteObject() BatchDeleteObject
}

// DeleteListIterator is an alternative iterator for the BatchDelete client. This will
// iterate through a list of objects and delete the objects.
//
// Example:
//
//    iter := &s3manager.DeleteListIterator{
//        Client: svc,
//        Input: &s3.ListObjectsInput{
//            Bucket:  aws.String("bucket"),
//            MaxKeys: aws.Int64(5),
//        },
//        Paginator: request.Pagination{
//            NewRequest: func() (*request.Request, error) {
//                var inCpy *ListObjectsInput
//                if input != nil {
//                    tmp := *input
//                    inCpy = &tmp
//                }
//                req, _ := c.ListObjectsRequest(inCpy)
//                return req, nil
//            },
//        },
//    }
//
//    batcher := s3manager.NewBatchDeleteWithClient(svc)
//    if err := batcher.Delete(aws.BackgroundContext(), iter); err != nil {
//        return err
//    }
type DeleteListIterator struct {
    Bucket    *string
    Paginator *s3.ListObjectsV2Paginator
    objects   []s3types.Object
    err       error
}

// NewDeleteListIterator will return a new DeleteListIterator.
func NewDeleteListIterator(
    svc s3.ListObjectsV2APIClient, input *s3.ListObjectsV2Input, opts ...func(*DeleteListIterator),
) BatchDeleteIterator {
    iter := &DeleteListIterator{
        Bucket:    input.Bucket,
        Paginator: s3.NewListObjectsV2Paginator(svc, input),
    }

    for _, opt := range opts {
        opt(iter)
    }
    return iter
}

// Next will use the S3API client to iterate through a list of objects.
func (iter *DeleteListIterator) Next() bool {
    if len(iter.objects) > 0 {
        iter.objects = iter.objects[1:]
        if len(iter.objects) > 0 {
            return true
        }
    }

    if !iter.Paginator.HasMorePages() {
        return false
    }

    page, err := iter.Paginator.NextPage(context.TODO())
    if err != nil {
        iter.err = err
        return false
    }

    iter.objects = page.Contents
    return len(iter.objects) > 0
}

// Err will return the last known error from Next.
func (iter *DeleteListIterator) Err() error {
    return iter.err
}

// DeleteObject will return the current object to be deleted.
func (iter *DeleteListIterator) DeleteObject() BatchDeleteObject {
    return BatchDeleteObject{
        Object: &s3.DeleteObjectInput{
            Bucket: iter.Bucket,
            Key:    iter.objects[0].Key,
        },
    }
}

// DeleteObjectsAPIClient implements the S3.DeleteObjects operation.
type DeleteObjectsAPIClient interface {
    DeleteObjects(context.Context, *s3.DeleteObjectsInput, ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error)
}

// BatchDelete will use the s3 package's service client to perform a batch
// delete.
type BatchDelete struct {
    Client    DeleteObjectsAPIClient
    BatchSize int
}

// NewBatchDeleteWithClient will return a new delete client that can delete a batched amount of
// objects.
//
// Example:
//
//    batcher := s3manager.NewBatchDeleteWithClient(client, size)
//
//    objects := []BatchDeleteObject{
//        {
//            Object:    &s3.DeleteObjectInput {
//                Key: aws.String("key"),
//                Bucket: aws.String("bucket"),
//            },
//        },
//    }
//
//    if err := batcher.Delete(aws.BackgroundContext(), &s3manager.DeleteObjectsIterator{
//        Objects: objects,
//    }); err != nil {
//        return err
//    }
func NewBatchDeleteWithClient(s3client DeleteObjectsAPIClient, options ...func(*BatchDelete)) *BatchDelete {
    svc := &BatchDelete{
        Client:    s3client,
        BatchSize: DefaultBatchSize,
    }

    for _, opt := range options {
        opt(svc)
    }

    return svc
}

// NewBatchDelete will return a new delete client that can delete a batched amount of
// objects.
//
// Example:
//
//    batcher := s3manager.NewBatchDelete(sess, size)
//
//    objects := []BatchDeleteObject{
//        {
//            Object:    &s3.DeleteObjectInput {
//                Key: aws.String("key"),
//                Bucket: aws.String("bucket"),
//            },
//        },
//    }
//
//    if err := batcher.Delete(aws.BackgroundContext(), &s3manager.DeleteObjectsIterator{
//        Objects: objects,
//    }); err != nil {
//        return err
//    }
func NewBatchDelete(c *aws.Config, options ...func(*BatchDelete)) *BatchDelete {
    s3client := s3.NewFromConfig(*c)
    return NewBatchDeleteWithClient(s3client, options...)
}

// BatchDeleteObject is a wrapper object for calling the batch delete operation.
type BatchDeleteObject struct {
    Object *s3.DeleteObjectInput
    // After will run after each iteration during the batch process. This function will
    // be executed whether the request was successful.
    After func() error
}

// DeleteObjectsIterator is an interface that uses the scanner pattern to iterate
// through a series of objects to be deleted.
type DeleteObjectsIterator struct {
    Objects []BatchDeleteObject
    index   int
    inc     bool
}

// Next will increment the default iterators index and ensure that there
// is another object to iterator to.
func (iter *DeleteObjectsIterator) Next() bool {
    if iter.inc {
        iter.index++
    } else {
        iter.inc = true
    }
    return iter.index < len(iter.Objects)
}

// Err will return an error. Since this is just used to satisfy the BatchDeleteIterator interface
// this will only return nil.
func (iter *DeleteObjectsIterator) Err() error {
    return nil
}

// DeleteObject will return the BatchDeleteObject at the current batched index.
func (iter *DeleteObjectsIterator) DeleteObject() BatchDeleteObject {
    object := iter.Objects[iter.index]
    return object
}

// Delete will use the iterator to queue up objects that need to be deleted.
// Once the batch size is met, this will call the deleteBatch function.
func (d *BatchDelete) Delete(ctx context.Context, iter BatchDeleteIterator, opts ...func(input *s3.DeleteObjectsInput)) error {
    var errs []Error
    var objects []BatchDeleteObject
    var input *s3.DeleteObjectsInput

    for iter.Next() {
        o := iter.DeleteObject()

        if input == nil {
            input = initDeleteObjectsInput(o.Object)
        }

        for _, opt := range opts {
            opt(input)
        }

        parity := hasParity(input, o)
        if parity {
            input.Delete.Objects = append(input.Delete.Objects, s3types.ObjectIdentifier{
                Key:       o.Object.Key,
                VersionId: o.Object.VersionId,
            })
            objects = append(objects, o)
        }

        if len(input.Delete.Objects) == d.BatchSize || !parity {
            if err := deleteBatch(ctx, d, input, objects); err != nil {
                errs = append(errs, err...)
            }

            objects = objects[:0]
            input = nil

            if !parity {
                objects = append(objects, o)
                input = initDeleteObjectsInput(o.Object)

                for _, opt := range opts {
                    opt(input)
                }

                input.Delete.Objects = append(input.Delete.Objects, s3types.ObjectIdentifier{
                    Key:       o.Object.Key,
                    VersionId: o.Object.VersionId,
                })
            }
        }
    }

    // iter.Next() could return false (above) plus populate iter.Err()
    if iter.Err() != nil {
        errs = append(errs, newError(iter.Err(), nil, nil))
    }

    if input != nil && len(input.Delete.Objects) > 0 {
        if err := deleteBatch(ctx, d, input, objects); err != nil {
            errs = append(errs, err...)
        }
    }

    if len(errs) > 0 {
        return NewBatchError("BatchedDeleteIncomplete", "some objects have failed to be deleted.", errs)
    }
    return nil
}

func initDeleteObjectsInput(o *s3.DeleteObjectInput) *s3.DeleteObjectsInput {
    return &s3.DeleteObjectsInput{
        Bucket:       o.Bucket,
        MFA:          o.MFA,
        RequestPayer: o.RequestPayer,
        Delete:       &s3types.Delete{},
    }
}

const (
    // ErrDeleteBatchFailCode represents an error code which will be returned
    // only when DeleteObjects.Errors has an error that does not contain a code.
    ErrDeleteBatchFailCode       = "DeleteBatchError"
    errDefaultDeleteBatchMessage = "failed to delete"
)

// deleteBatch will delete a batch of items in the objects parameters.
func deleteBatch(ctx context.Context, d *BatchDelete, input *s3.DeleteObjectsInput, objects []BatchDeleteObject) []Error {
    var errs []Error

    if result, err := d.Client.DeleteObjects(ctx, input); err != nil {
        for i := 0; i < len(input.Delete.Objects); i++ {
            errs = append(errs, newError(err, input.Bucket, input.Delete.Objects[i].Key))
        }
    } else if len(result.Errors) > 0 {
        for i := 0; i < len(result.Errors); i++ {
            code := ErrDeleteBatchFailCode
            msg := errDefaultDeleteBatchMessage
            if result.Errors[i].Message != nil {
                msg = *result.Errors[i].Message
            }
            if result.Errors[i].Code != nil {
                code = *result.Errors[i].Code
            }

            errs = append(errs, newError(awserr.New(code, msg, err), input.Bucket, result.Errors[i].Key))
        }
    }
    for _, object := range objects {
        if object.After == nil {
            continue
        }
        if err := object.After(); err != nil {
            errs = append(errs, newError(err, object.Object.Bucket, object.Object.Key))
        }
    }

    return errs
}

func hasParity(o1 *s3.DeleteObjectsInput, o2 BatchDeleteObject) bool {
    if o1.Bucket != nil && o2.Object.Bucket != nil {
        if *o1.Bucket != *o2.Object.Bucket {
            return false
        }
    } else if o1.Bucket != o2.Object.Bucket {
        return false
    }

    if o1.MFA != nil && o2.Object.MFA != nil {
        if *o1.MFA != *o2.Object.MFA {
            return false
        }
    } else if o1.MFA != o2.Object.MFA {
        return false
    }

    if o1.RequestPayer != "" && o2.Object.RequestPayer != "" {
        if o1.RequestPayer != o2.Object.RequestPayer {
            return false
        }
    } else if o1.RequestPayer != o2.Object.RequestPayer {
        return false
    }

    return true
}