zidarsk8/simple_wbd

View on GitHub
tests/test_indicators.py

Summary

Maintainability
A
0 mins
Test Coverage
"""Unit tests for world bank data indicator API."""

import mock
import pycountry
from requests.exceptions import RequestException

import tests
import simple_wbd


class TestIndicators(tests.TestCase):
    """Tests for the simple_wbd.IndicatorAPI class.

    Tests decorated with ``tests.MY_VCR.use_cassette`` run against the named
    cassette; the remaining tests use mocks or call private helpers directly.
    """

    # Number of entries the country fixtures are expected to contain.
    NUM_OF_COUNTRIES = 304

    def setUp(self):
        """Create a fresh IndicatorAPI instance for every test."""
        super().setUp()
        self.api = simple_wbd.IndicatorAPI()

    def test_init(self):
        """Test Indicator api init function.

        Test if the default response class gets set correctly if the given
        class is a subclass of IndicatorDataset, and that anything else falls
        back to IndicatorDataset itself.
        """
        # pylint: disable=protected-access

        class Dummy:
            """Class that is not an IndicatorDataset subclass."""
            # pylint: disable=too-few-public-methods

        class DummySubclass(simple_wbd.IndicatorDataset):
            """Valid custom dataset class."""
            # pylint: disable=too-few-public-methods

        default_class = simple_wbd.IndicatorDataset

        # No argument: the default dataset class is used.
        api = simple_wbd.IndicatorAPI()
        self.assertEqual(api._dataset_class, default_class)

        # An unrelated class is rejected in favour of the default.
        api = simple_wbd.IndicatorAPI(Dummy)
        self.assertEqual(api._dataset_class, default_class)

        # A non-class value is rejected in favour of the default.
        api = simple_wbd.IndicatorAPI("bad data")
        self.assertEqual(api._dataset_class, default_class)

        # A proper subclass is accepted as is.
        api = simple_wbd.IndicatorAPI(DummySubclass)
        self.assertEqual(api._dataset_class, DummySubclass)

    def test_bad_request(self):
        """Test that a failed data fetch yields an empty response dict."""
        # pylint: disable=protected-access
        self.api._get_indicator_data = mock.Mock(
            side_effect=RequestException("Bam!"))
        self.api.get_countries = mock.Mock(return_value=[])

        res = self.api.get_dataset("SL.MNF.0714.FE.ZS")

        self.assertEqual(res.api_responses, {})

    @tests.MY_VCR.use_cassette("country_list.json")
    def test_get_country_list(self):
        """Test fetching a 2D list of country and region codes."""
        country_list = self.api.get_country_list()
        # number of countries + one line for header
        self.assertEqual(self.NUM_OF_COUNTRIES + 1, len(country_list))

    @tests.MY_VCR.use_cassette("countries.json")
    def test_get_countries(self):
        """Test fetching all country and region codes."""
        all_codes = self.api.get_countries()
        self.assertEqual(self.NUM_OF_COUNTRIES, len(all_codes))

        # Keep only entries with a capital city (presumably real countries
        # rather than aggregate regions).
        countries = {c["id"] for c in all_codes if c["capitalCity"]}
        # NOTE(review): newer pycountry releases appear to expose this
        # attribute as ``alpha_3`` - confirm against the pinned version.
        alpha3_codes = {c.alpha3 for c in pycountry.countries}
        alpha3_codes.add("KSV")  # Non standard code for Kosovo.

        # assertLess on sets checks for a proper (strict) subset.
        self.assertLess(countries, alpha3_codes)

    def test_filter_indicators(self):
        """Test that the "common" filter drops malformed indicator ids."""
        # pylint: disable=protected-access

        bad_indicators = [
            {"id": "bad indicator"},
            {"id": "A7"},
            {"id": "A7ivi"},
            {"id": "BASIC"},
            {"id": "."},
            {"id": "2.0.hoi."},
        ]
        good_indicators = [
            {"id": "ag.lnd.frst.k2"},
            {"id": "ag.LND.FRst.zs"},
            {"id": "AG.LND.FRST.ZS"},
            {"id": "ag.lnd.irig.ag.zs"},
            {"id": "ag.lnd.totl.k2"},
        ]

        all_indicators = bad_indicators + good_indicators
        filtered = self.api._filter_indicators(all_indicators, "common")

        # Ids are compared case-insensitively.
        expected_ids = {i.get("id").lower() for i in good_indicators}
        filtered_ids = {i.get("id").lower() for i in filtered}
        self.assertEqual(expected_ids, filtered_ids)

    def test_all_filter_indicators(self):
        """Test common, featured and bad indicator filters."""
        # pylint: disable=protected-access
        indicators = [
            {"id": "A7ivi"},  # none
            {"id": "ag.srf.totl.k2dt.dod.mdri.cd"},  # common
            {"id": "ag.LND.FRst.zs"},  # featured and common
        ]
        self.assertEqual(
            len(self.api._filter_indicators(indicators, "common")), 2)
        self.assertEqual(
            len(self.api._filter_indicators(indicators, "featured")), 1)
        # No filter keeps every indicator.
        self.assertEqual(
            len(self.api._filter_indicators(indicators, None)), 3)
        # An unrecognised filter name is expected to keep every indicator too.
        self.assertEqual(
            len(self.api._filter_indicators(indicators, "bad")), 3)

    @tests.MY_VCR.use_cassette("indicators_list.json")
    def test_get_indicator_list(self):
        """Test fetching the indicator list with and without filters."""
        all_indicators = self.api.get_indicator_list(filter_=None)
        common_indicators = self.api.get_indicator_list()
        featured_indicators = self.api.get_indicator_list(filter_="featured")

        # The first row must be the same for every filter (presumably the
        # header row - confirm against get_indicator_list).
        self.assertEqual(all_indicators[0], common_indicators[0])
        self.assertEqual(all_indicators[0], featured_indicators[0])
        # Each stricter filter must return strictly fewer rows.
        self.assertLess(len(featured_indicators), len(common_indicators))
        self.assertLess(len(common_indicators), len(all_indicators))

    @tests.MY_VCR.use_cassette("indicators.json")
    def test_get_indicators(self):
        """Test fetching indicator records with and without filters."""
        all_indicators = self.api.get_indicators(filter_=None)
        common_indicators = self.api.get_indicators()
        featured_indicators = self.api.get_indicators(filter_="featured")

        all_codes = {i.get("id").lower() for i in all_indicators}
        featured_codes = {i.get("id").lower() for i in featured_indicators}
        common_codes = {i.get("id").lower() for i in common_indicators}

        # Both filtered sets must be proper subsets of the unfiltered one.
        self.assertLess(featured_codes, all_codes, "Wrong featured subset.")
        self.assertLess(common_codes, all_codes, "Wrong common subset.")
        self.assertIn("AG.LND.ARBL.HA.PC".lower(), featured_codes)
        self.assertNotIn("ag.agr.trac.no", featured_codes)
        self.assertIn("ag.agr.trac.no", common_codes)

    @tests.MY_VCR.use_cassette("datasets.json")
    def test_get_dataset(self):
        """Test fetching datasets."""
        indicator = "SL.MNF.0714.FE.ZS"
        # api_responses is keyed by the lower-cased indicator code.
        indicator_key = indicator.lower()

        res = self.api.get_dataset(indicator)
        self.assertIn(indicator_key, res.api_responses)

        res = self.api.get_dataset(
            ["dp.dod.decx.cr.bc.z1", indicator, "2.0.HOI.cEL"],
            countries=["Svn", "us", "World"]
        )
        self.assertIn(indicator_key, res.api_responses)
        self.assertIn("dp.dod.decx.cr.bc.z1".lower(), res.api_responses)
        self.assertIn("2.0.HOI.cEL".lower(), res.api_responses)

        ind_data = res.api_responses[indicator_key]
        response_countries = {i["country"]["value"] for i in ind_data}
        self.assertEqual(
            {"United States", "Slovenia", "World"},
            response_countries
        )

    @tests.MY_VCR.use_cassette("datasets2.json")
    def test_get_dataset_with_errors(self):
        """Test fetching datasets with bad parameters."""
        indicator = "SL.MNF.0714.FE.ZS"
        indicator_key = indicator.lower()
        bad_indicator_key = "BAD Indicator".lower()

        res = self.api.get_dataset(
            [indicator, "BAD Indicator"],
            countries=["Svn", "us", "bad country"]
        )
        self.assertIn(indicator_key, res.api_responses)
        # A bad indicator still gets an entry, but with no data.
        self.assertIn(bad_indicator_key, res.api_responses)
        self.assertEqual(res.api_responses[bad_indicator_key], [])

        # The bad country is ignored; the valid ones are still returned.
        ind_data = res.api_responses[indicator_key]
        response_countries = {i["country"]["value"] for i in ind_data}
        self.assertEqual(
            {"United States", "Slovenia"},
            response_countries
        )

        # If there is no good country parameter, we default to all
        res = self.api.get_dataset(indicator, countries="badd")
        ind_data = res.api_responses[indicator_key]
        response_countries = {i["country"]["value"] for i in ind_data}
        self.assertLess(
            {"United States", "Slovenia"},
            response_countries
        )