uktrade/directory-api

View on GitHub
dataservices/management/commands/import_comtrade_data.py

Summary

Maintainability
B
6 hrs
Test Coverage
import csv
import io

from django.conf import settings
from django.core.management import BaseCommand
from django.db import connection

from core.helpers import get_s3_file_stream
from dataservices.models import ComtradeReport


class Command(BaseCommand):
    """Import Comtrade trade data into the ``ComtradeReport`` table.

    Exactly one action runs per invocation, chosen in ``handle`` in this order:

    * ``--wipe``             delete every ``ComtradeReport`` row;
    * ``--link_countries``   set ``country_id`` by matching ``country_iso3`` to ``dataservices_country.iso3``;
    * ``--unlink_countries`` set ``country_id`` to NULL on every row;
    * ``--raw FILE...``      append rows parsed from raw Comtrade CSV files on local disk;
    * otherwise              load a pre-processed CSV via ``get_s3_file_stream``, then link countries.
    """

    help = 'Import Comtrade data'

    def add_arguments(self, parser):
        """Register the positional filenames and the action / test flags."""
        # Positional arguments: local CSV paths for --raw, otherwise (only the first)
        # file name handed to get_s3_file_stream.
        parser.add_argument('filenames', nargs='*', type=str)

        parser.add_argument(
            '--wipe',
            action='store_true',
            help='Wipe table only',
        )

        parser.add_argument(
            '--raw',
            action='store_true',
            help='load raw data files',
        )

        parser.add_argument(
            '--link_countries',
            action='store_true',
            help='Link existing data to countries',
        )

        parser.add_argument(
            '--unlink_countries',
            action='store_true',
            help='Unlink existing countries so that country data can be deleted',
        )

        parser.add_argument(
            '--test',
            action='store_true',
            help='limit rowcount to 1000 for testing',
        )

    @staticmethod
    def _classify_row(row):
        """Return ``(uk_or_world, country_iso3)`` for a raw Comtrade row, or ``(None, None)``.

        Two kinds of row are kept:

        * reporter ``GBR`` with flow ``Export`` -> ``('GBR', <partner ISO>)``;
        * partner ``WLD`` with flow ``Import``  -> ``('WLD', <reporter ISO>)``.

        The flows differ, so at most one case can match.
        """
        reporter_iso3 = row.get('Reporter ISO')
        partner_iso3 = row.get('Partner ISO')
        flow = row.get('Trade Flow')
        if reporter_iso3 == 'GBR' and flow == 'Export':
            return reporter_iso3, partner_iso3
        if partner_iso3 == 'WLD' and flow == 'Import':
            return partner_iso3, reporter_iso3
        return None, None

    def load_raw_files(self, filenames):
        """Append rows from raw Comtrade CSV downloads on top of the existing table data.

        Only rows with ``Is Leaf Code`` == '1' that ``_classify_row`` recognises (and whose
        resulting ISO codes are non-empty) are saved, one ``ComtradeReport`` per row.

        :param filenames: iterable of local CSV file paths.
        """
        for filename in filenames:
            self.stdout.write(self.style.SUCCESS(f'********  Loading: {filename}'))
            read = 0
            written = 0
            # utf-8-sig strips a leading BOM; newline='' is what the csv module
            # requires so quoted fields containing line breaks parse correctly.
            with open(filename, 'r', encoding='utf-8-sig', newline='') as f:
                for row in csv.DictReader(f):
                    read += 1
                    if row.get('Is Leaf Code') != '1':
                        continue
                    uk_or_world, country_iso3 = self._classify_row(row)
                    if not (country_iso3 and uk_or_world):
                        continue
                    written += 1
                    ComtradeReport(
                        country_iso3=country_iso3,
                        year=row.get('Year'),
                        classification=row.get('Classification'),
                        commodity_code=row.get('Commodity Code'),
                        # An empty trade value is stored as 0.0
                        trade_value=float(row.get('Trade Value (US$)') or '0'),
                        uk_or_world=uk_or_world,
                    ).save()
                    if written % 100 == 0:
                        # Carriage return: overwrite the same progress line
                        print(f'{read} read, {written} written', end='\r', flush=True)
            self.stdout.write(self.style.SUCCESS(f'{read} read, {written} written'))

    def link_countries(self):
        """Set ``country_id`` on every report whose ``country_iso3`` matches a country's ``iso3``."""
        self.stdout.write('Linking countries')
        with connection.cursor() as cursor:
            # UPDATE ... FROM join syntax (PostgreSQL-style)
            cursor.execute(
                "UPDATE dataservices_comtradereport as d "
                "set country_id=c.id "
                "from dataservices_country as c where d.country_iso3=c.iso3;"
            )

    def unlink_countries(self):
        """Clear ``country_id`` on every report so the referenced country rows can be deleted."""
        self.stdout.write('Un-linking countries')
        with connection.cursor() as cursor:
            cursor.execute("UPDATE dataservices_comtradereport set country_id=null;")

    def populate_db_from_s3(self, filename, test=False):
        """Insert a pre-processed Comtrade CSV fetched from S3, then link countries.

        Rows are inserted with raw SQL, including their ``id`` column, on top of whatever
        the table already holds.

        :param filename: name passed to ``get_s3_file_stream``; when falsy,
            ``settings.COMTRADE_DATA_FILE_NAME`` is used instead.
        :param test: when true, stop after 1000 rows have been inserted.
        """
        filestream = get_s3_file_stream(filename or settings.COMTRADE_DATA_FILE_NAME)
        # Parse the text as a file: str.split() would break on *any* whitespace and
        # tear fields containing spaces across CSV records.
        file_reader = csv.DictReader(io.StringIO(filestream, newline=''))
        self.stdout.write('*********   Loading comtrade data')
        written = 0
        with connection.cursor() as cursor:
            for row in file_reader:
                cursor.execute(
                    "INSERT INTO dataservices_comtradereport "
                    "(id, year, classification, commodity_code, trade_value, uk_or_world, country_iso3) "
                    "VALUES (%s, %s, %s, %s, %s, %s, %s)",
                    [
                        row.get('id'),
                        row.get('year'),
                        row.get('classification'),
                        row.get('commodity_code'),
                        row.get('trade_value'),
                        row.get('uk_or_world'),
                        row.get('country_iso3'),
                    ],
                )
                written += 1
                if written % 1000 == 0:
                    print(f'  {written} rows written', end='\r', flush=True)
                if test and written >= 1000:
                    break
        self.stdout.write(self.style.SUCCESS(f'Loaded table - {written} rows written'))
        self.link_countries()

    def handle(self, *args, **options):
        """Run the single action selected by the command-line flags (see class docstring)."""
        filenames = options['filenames']
        if options['wipe']:
            ComtradeReport.objects.all().delete()
        elif options['link_countries']:
            self.link_countries()
        elif options['unlink_countries']:
            self.unlink_countries()
        elif filenames and options['raw']:
            self.load_raw_files(filenames)
        else:
            # NOTE(review): `--raw` without filenames falls through to the S3 load, and only
            # the first positional filename is used here; any others are ignored. Confirm intended.
            self.populate_db_from_s3(filenames[0] if filenames else None, test=options['test'])

        self.stdout.write(self.style.SUCCESS('All done, bye!'))