WesternFriend/WF-website

View on GitHub
cart/cart.py

Summary

Maintainability
A
0 mins
Test Coverage
A
100%
from collections.abc import Generator, Iterable
from decimal import Decimal

from django.conf import settings
from django.http import HttpRequest

from shipping.calculator import get_book_shipping_cost
from store.models import Product


class Cart:
    """Shopping cart stored in the Django session.

    The cart data lives in ``request.session[settings.CART_SESSION_ID]`` as a
    dict keyed by the product id (converted to ``str``). Each value holds the
    product title, product id, quantity and price. The price is stored as a
    string, presumably so the session data stays serializable.
    """

    def __init__(self, request: HttpRequest) -> None:
        """Load the cart from the session, creating an empty one if needed."""
        self.session = request.session

        cart = self.session.get(settings.CART_SESSION_ID)

        if not cart:
            # Store an empty cart in the session. ``self.cart`` and the
            # session entry are the same dict object, so later in-place
            # changes to ``self.cart`` are part of the session data.
            cart = self.session[settings.CART_SESSION_ID] = {}

        self.cart = cart

    def add(
        self,
        product: Product,
        quantity: int = 1,
    ) -> None:
        """Add a product to the cart or update its quantity.

        Any existing entry for the product is overwritten. The stored quantity
        becomes ``quantity`` (it is not added to the previous value), and the
        title and price are re-read from ``product``.
        """
        product_id = str(product.id)  # type: ignore

        self.cart[product_id] = {
            "product_title": product.title,
            "product_id": product_id,
            "quantity": quantity,
            "price": str(product.price),
        }

        self.save()

    def save(self) -> None:
        """Mark the session as modified so that it gets saved.

        Mutating the nested cart dict does not, by itself, flag the session
        as changed.
        """
        self.session.modified = True

    def remove(self, product: Product) -> None:
        """Remove a product from the cart; do nothing if it is not present."""
        product_id = str(product.id)  # type: ignore

        if product_id in self.cart:
            del self.cart[product_id]

            self.save()

    def get_cart_products(self) -> Iterable[Product]:
        """Return the ``Product`` objects whose ids are in the cart.

        The result comes from ``Product.objects.filter`` (a lazy queryset that
        runs when iterated), not a ``list``.
        """
        return Product.objects.filter(id__in=self.cart.keys())

    def get_total_cost(self) -> Decimal:
        """Return subtotal plus shipping, rounded to two decimal places.

        NOTE(review): assumes ``get_book_shipping_cost`` returns a ``Decimal``
        or ``int``. A ``float`` would raise ``TypeError`` when added to the
        ``Decimal`` subtotal. Confirm against the shipping calculator.
        """
        total = sum(
            [
                self.get_subtotal_cost(),
                self.get_shipping_cost(),
            ],
        )
        return Decimal(total).quantize(Decimal("0.01"))

    def get_subtotal_cost(self) -> Decimal:
        """Return the sum of ``price * quantity`` over all items, to 0.01."""
        subtotal = sum(
            Decimal(item["price"]) * item["quantity"] for item in self.cart.values()
        )
        # ``sum`` of an empty cart is the int 0; ``Decimal`` normalizes it.
        return Decimal(subtotal).quantize(Decimal("0.01"))

    def get_shipping_cost(self) -> Decimal:
        """Return the shipping cost for the total item quantity in the cart.

        Every item's quantity is passed to the book-shipping calculator, so
        this treats all cart items as books.
        """
        book_quantity = sum(item["quantity"] for item in self.cart.values())

        return get_book_shipping_cost(book_quantity)

    def clear(self) -> None:
        """Remove the cart from the session and reset ``self.cart``.

        Safe to call more than once: a missing session entry is ignored
        instead of raising ``KeyError``.
        """
        self.session.pop(settings.CART_SESSION_ID, None)

        self.cart = {}

        self.save()

    def __iter__(self) -> Generator:
        """Yield cart items enriched with their products and totals.

        Each yielded item is a copy of the stored entry, with ``price``
        converted to ``Decimal``, a ``total_price`` key added, and a
        ``product`` key added when a matching ``Product`` exists. The session
        data itself is never modified.
        """
        yield from self._build_items(self.cart, self.get_cart_products())

    @staticmethod
    def _build_items(cart: dict, products: Iterable[Product]) -> Generator:
        """Yield enriched copies of ``cart`` values, attaching ``products``."""
        # Copy each item dict: a shallow copy of ``cart`` alone would share
        # the inner dicts with the session and leak the Product/Decimal
        # values below back into the stored session data.
        items = {product_id: dict(item) for product_id, item in cart.items()}

        for product in products:
            items[str(product.id)]["product"] = product

        for item in items.values():
            item["price"] = Decimal(item["price"])
            item["total_price"] = item["price"] * item["quantity"]

            yield item

    def __len__(self) -> int:
        """Return the total quantity of all items in the cart."""
        # TODO: determine whether this should count the number of products
        # or the total quantity of products
        return sum(item["quantity"] for item in self.cart.values())