Hackman238/legion

View on GitHub
tests/db/repositories/test_PortRepository.py

Summary

Maintainability
A
0 mins
Test Coverage
"""
LEGION (https://shanewilliamscott.com)
Copyright (c) 2018-2019 Hackman238

    This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public
    License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later
    version.

    This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied
    warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
    details.

    You should have received a copy of the GNU General Public License along with this program.
    If not, see <http://www.gnu.org/licenses/>.

Author(s): Shane Scott (sscott@shanewilliamscott.com), Dmitriy Dubson (d.dubson@gmail.com)
"""
import unittest
from unittest import mock
from unittest.mock import patch, MagicMock

from tests.db.helpers.db_helpers import mockFirstByReturnValue, mockExecuteFetchAll


class PortRepositoryTest(unittest.TestCase):
    """Unit tests for ``PortRepository`` driven by a mocked database adapter.

    Raw-SQL lookups are verified against the exact query text and bind values passed to
    ``dbAdapter.metadata.bind.execute``; ORM-style deletions are verified against the
    mocked session returned by ``dbAdapter.session()``.
    """

    # SELECT/JOIN prefix shared by every getPortsAndServicesByHostIP expectation; each
    # test appends the WHERE clause it expects for its particular filter combination.
    _PORTS_AND_SERVICES_QUERY_PREFIX = (
        "SELECT hosts.ip, ports.portId, ports.protocol, ports.state, ports.hostId, ports.serviceId, "
        "services.name, services.product, services.version, services.extrainfo, services.fingerprint "
        "FROM portObj AS ports INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId "
        "LEFT OUTER JOIN serviceObj AS services ON services.id = ports.serviceId "
    )

    def setUp(self) -> None:
        """Build a repository whose adapter and session are both MagicMocks."""
        from db.repositories.PortRepository import PortRepository
        self.mockDbAdapter = MagicMock()
        self.mockDbSession = MagicMock()
        # Every dbAdapter.session() call hands back the same session mock so tests can
        # inspect query/delete/commit calls on it.
        self.mockDbAdapter.session.return_value = self.mockDbSession
        self.repository = PortRepository(self.mockDbAdapter)

    def test_getPortsByIPAndProtocol_ReturnsPorts(self):
        """The host IP and protocol are bound as parameters and the execute result is returned."""
        expected_query = ("SELECT ports.portId FROM portObj AS ports "
                          "INNER JOIN hostObj AS hosts ON hosts.id = ports.hostId "
                          "WHERE hosts.ip = ? and ports.protocol = ?")
        self.mockDbAdapter.metadata.bind.execute.return_value = mockFirstByReturnValue(
            [['port-id1'], ['port-id2']])
        ports = self.repository.getPortsByIPAndProtocol("some_host_ip", "tcp")

        self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip", "tcp")
        self.assertEqual([['port-id1'], ['port-id2']], ports)

    def test_getPortStatesByHostId_ReturnsPortsStates(self):
        """Port states are selected by host id via a single parameterised query."""
        expected_query = 'SELECT port.state FROM portObj as port WHERE port.hostId = ?'
        self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll(
            [['port-state1'], ['port-state2']])
        port_states = self.repository.getPortStatesByHostId("some_host_id")

        self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_id")
        self.assertEqual([['port-state1'], ['port-state2']], port_states)

    def test_getPortsAndServicesByHostIP_InvokedWithNoFilters_ReturnsPortsAndServices(self):
        """With every filter enabled, only the host-IP condition appears in the WHERE clause."""
        from app.auxiliary import Filters

        expected_query = self._PORTS_AND_SERVICES_QUERY_PREFIX + "WHERE hosts.ip = ?"
        self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['ip1'], ['ip2']])

        filters: Filters = Filters()
        filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True,
                      tcp=True, udp=True)
        results = self.repository.getPortsAndServicesByHostIP("some_host_ip", filters)

        self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip")
        self.assertEqual([['ip1'], ['ip2']], results)

    def test_getPortsAndServicesByHostIP_InvokedWithFewFilters_ReturnsPortsAndServices(self):
        """Disabling tcp and udp adds a protocol exclusion for each to the WHERE clause."""
        from app.auxiliary import Filters

        expected_query = (self._PORTS_AND_SERVICES_QUERY_PREFIX +
                          "WHERE hosts.ip = ? AND ports.protocol != 'tcp' AND ports.protocol != 'udp'")
        self.mockDbAdapter.metadata.bind.execute.return_value = mockExecuteFetchAll([['ip1'], ['ip2']])

        filters: Filters = Filters()
        filters.apply(up=True, down=True, checked=True, portopen=True, portfiltered=True, portclosed=True,
                      tcp=False, udp=False)
        results = self.repository.getPortsAndServicesByHostIP("some_host_ip", filters)

        self.mockDbAdapter.metadata.bind.execute.assert_called_once_with(expected_query, "some_host_ip")
        self.assertEqual([['ip1'], ['ip2']], results)

    def test_deleteAllPortsAndScriptsByHostId_WhenProvidedByHostIDAndProtocol_DeletesAllPortsAndScripts(self):
        """Scripts for every matching port are deleted, then the ports, then the session commits."""
        # Every mock below must be its own object. A chained assignment such as
        # ``a = b = MagicMock()`` aliases them: the two ports become indistinguishable
        # and the stubbed query chain turns self-referential, accepting any number of
        # .filter() calls.
        mockPort1 = MagicMock(name="port1")
        mockPort2 = MagicMock(name="port2")

        # Port lookup stub: query(...).filter(<host>).filter(<protocol>).all() -> both ports.
        mockReturnAll = MagicMock()
        mockReturnAll.all.return_value = [mockPort1, mockPort2]
        mockProtocolFilter = MagicMock()
        mockProtocolFilter.filter.return_value = mockReturnAll
        mockFilterHost = MagicMock()
        mockFilterHost.filter.return_value = mockProtocolFilter

        # Script lookup stub: query(...).filter(<port>).all() -> scripts; reused for both ports.
        mockReturnAllScripts = MagicMock()
        mockReturnAllScripts.all.return_value = ['some-script1', 'some-script2']
        mockFilterScript = MagicMock()
        mockFilterScript.filter.return_value = mockReturnAllScripts

        # First query() is the port lookup, then one script lookup per port.
        self.mockDbSession.query.side_effect = [mockFilterHost, mockFilterScript, mockFilterScript]

        self.repository.deleteAllPortsAndScriptsByHostId("some-host-id", "some-protocol")

        expected_deletes = [
            mock.call('some-script1'), mock.call('some-script2'),
            mock.call('some-script1'), mock.call('some-script2'),
            mock.call(mockPort1), mock.call(mockPort2),
        ]
        self.mockDbSession.delete.assert_has_calls(expected_deletes)
        # assert_has_calls tolerates extra calls; pin the total so stray deletes fail the test.
        self.assertEqual(len(expected_deletes), self.mockDbSession.delete.call_count)
        self.mockDbSession.commit.assert_called_once()