extensions/classifiers/svm-prediction.service.spec.ts

Summary

Maintainability
F
2 wks
Test Coverage
// Copyright 2016 The Oppia Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS-IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

/**
 * @fileoverview Unit tests for the SVM prediction functions.
 */

import {TestBed} from '@angular/core/testing';

import {SVMPredictionService} from 'classifiers/svm-prediction.service';

describe('SVM prediction functions', () => {
  describe('Test SVM prediction functions', () => {
    let service: SVMPredictionService;
    beforeEach(() => {
      TestBed.configureTestingModule({
        providers: [SVMPredictionService],
      });

      service = TestBed.get(SVMPredictionService);
    });

    it('should calculate correct kernel values', () => {
      var kernelParams = {
        kernel: 'rbf',
        coef0: 0.0,
        degree: 3,
        gamma: 0.5,
      };

      var supportVectors = [
        [0, 0],
        [1, 1],
      ];
      var input = [1, 0];
      var kvalues = service.kernel(kernelParams, supportVectors, input);
      var expectedKvalues = [0.6065306597126334, 0.6065306597126334];
      expect(kvalues.length).toEqual(2);
      expect(kvalues).toEqual(expectedKvalues);
    });

    it('should calculate correct kernel values', () => {
      var kernelParams = {
        kernel: 'linear',
        coef0: 0.0,
        degree: 3,
        gamma: 0.5,
      };

      var supportVectors = [
        [0, 0],
        [1, 1],
      ];
      var input = [1, 0];
      var kvalues = service.kernel(kernelParams, supportVectors, input);
      var expectedKvalues = [0, 1];
      expect(kvalues.length).toEqual(2);
      expect(kvalues).toEqual(expectedKvalues);
    });

    it('should give correct labels and confidence estimations', () => {
      // This is classifier data of a pretrained SVM classifier trained
      // on a synthetic dataset created for the purpose of testing SVM
      // prediction service. The classes in the classifier data are equivalent
      // to two non-default answer groups of the training data of some
      // exploration.
      var svmData = {
        classes: [0, 1],
        kernel_params: {
          kernel: 'rbf',
          coef0: 0.0,
          degree: 3,
          gamma: 0.5,
        },
        intercept: [0.04554340162799716],
        n_support: [80, 66],
        probA: [-4.76812258346006],
        support_vectors: [
          [5.0, 0.0],
          [4.0, -2.0],
          [3.0, -4.0],
          [-5.0, 0.0],
          [4.0, -2.0],
          [0.0, 0.0],
          [2.0, -1.0],
          [3.0, 4.0],
          [-4.0, -2.0],
          [-3.0, 4.0],
          [-5.0, 0.0],
          [-4.0, -2.0],
          [0.0, -5.0],
          [-1.0, 1.0],
          [-1.0, -1.0],
          [-1.0, 0.0],
          [4.0, -2.0],
          [-2.0, -4.0],
          [-5.0, 0.0],
          [4.0, -3.0],
          [5.0, 0.0],
          [3.0, -4.0],
          [-5.0, 0.0],
          [5.0, 0.0],
          [0.0, -5.0],
          [-4.0, 3.0],
          [4.0, 1.0],
          [-1.0, 4.0],
          [4.0, 2.0],
          [1.0, 4.0],
          [4.0, 3.0],
          [-1.0, 4.0],
          [5.0, 0.0],
          [0.0, -1.0],
          [2.0, 0.0],
          [-4.0, -3.0],
          [-2.0, -4.0],
          [4.0, 3.0],
          [-4.0, -1.0],
          [0.0, 5.0],
          [4.0, 3.0],
          [-2.0, -1.0],
          [3.0, -4.0],
          [5.0, 0.0],
          [4.0, 1.0],
          [5.0, 0.0],
          [0.0, -2.0],
          [2.0, 1.0],
          [0.0, -5.0],
          [1.0, -4.0],
          [0.0, 2.0],
          [-4.0, 3.0],
          [-2.0, 0.0],
          [4.0, 3.0],
          [-3.0, 4.0],
          [-1.0, -2.0],
          [1.0, -2.0],
          [1.0, -4.0],
          [1.0, -1.0],
          [1.0, 2.0],
          [-4.0, 1.0],
          [-3.0, 4.0],
          [-3.0, -4.0],
          [3.0, -4.0],
          [-1.0, 2.0],
          [3.0, 4.0],
          [-4.0, 1.0],
          [-1.0, -4.0],
          [1.0, 1.0],
          [-4.0, 3.0],
          [0.0, 5.0],
          [2.0, 4.0],
          [-2.0, 1.0],
          [0.0, -5.0],
          [-3.0, -4.0],
          [-4.0, 3.0],
          [1.0, 4.0],
          [-4.0, -3.0],
          [0.0, -5.0],
          [-3.0, 4.0],
          [1.0, 5.0],
          [5.0, 1.0],
          [1.0, -5.0],
          [1.0, 5.0],
          [1.0, 5.0],
          [-5.0, 1.0],
          [-4.0, -4.0],
          [-1.0, -5.0],
          [-1.0, -5.0],
          [4.0, 4.0],
          [-5.0, 1.0],
          [1.0, -5.0],
          [-5.0, 3.0],
          [1.0, -5.0],
          [-5.0, 3.0],
          [5.0, 1.0],
          [4.0, -4.0],
          [-4.0, 4.0],
          [-3.0, 5.0],
          [5.0, 1.0],
          [-5.0, 1.0],
          [-4.0, -4.0],
          [-4.0, 4.0],
          [1.0, -5.0],
          [-1.0, 5.0],
          [4.0, -4.0],
          [4.0, -4.0],
          [5.0, -3.0],
          [-1.0, 5.0],
          [5.0, -1.0],
          [5.0, -1.0],
          [4.0, 4.0],
          [-4.0, 4.0],
          [-4.0, -4.0],
          [-2.0, 5.0],
          [1.0, -5.0],
          [5.0, 1.0],
          [4.0, 4.0],
          [4.0, 4.0],
          [-5.0, 1.0],
          [-1.0, 5.0],
          [3.0, 5.0],
          [5.0, -1.0],
          [-5.0, -1.0],
          [5.0, 1.0],
          [3.0, 5.0],
          [5.0, 1.0],
          [3.0, -5.0],
          [-5.0, -2.0],
          [5.0, 3.0],
          [5.0, -1.0],
          [-5.0, -3.0],
          [-5.0, -1.0],
          [3.0, -5.0],
          [-5.0, 1.0],
          [5.0, 3.0],
          [-3.0, -5.0],
          [-4.0, 4.0],
          [5.0, -1.0],
          [2.0, 5.0],
          [-1.0, 5.0],
          [-5.0, -5.0],
          [-3.0, 5.0],
          [-5.0, -2.0],
          [-2.0, -5.0],
          [-2.0, -5.0],
        ],
        probB: [-0.26830931608536374],
        dual_coef: [
          [
            1.0, 0.17963792804729697, 0.403550660516519, 1.0, 1.0,
            0.2174320339900639, 0.32237125746964795, 1.0, 0.23406746659599886,
            0.13107690381219206, 1.0, 1.0, 0.20357365261524915,
            0.3806808376092491, 0.07231536087203701, 0.052444785344018065,
            0.9373454934508193, 0.2887075426898694, 1.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 1.0, 0.09794130691741577, 0.9371448493098987,
            0.8020377139435809, 0.10523314152848777, 0.7208368125926214,
            0.2745796264317118, 1.0, 1.0, 0.1341948623940462, 0.184427006132661,
            1.0, 1.0, 1.0, 0.23220617815321498, 1.0, 1.0, 0.34100990964941563,
            1.0, 1.0, 1.0, 1.0, 0.19085781720747444, 0.10622693983611159, 1.0,
            0.9733190570902237, 0.07976440321906088, 1.0, 0.18011727003205402,
            1.0, 1.0, 0.32629540304776156, 0.10535962914306607, 1.0,
            0.23429959940904435, 0.2026193359451537, 1.0, 1.0, 1.0, 1.0,
            0.20070585346407077, 1.0, 0.999293299134111, 0.1464060764902667,
            0.3890080385472037, 1.0, 1.0, 0.2940723495632226,
            0.10157585440791363, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            0.9896220351629882, -1.0, -0.3138343321080787, -0.5646722753216991,
            -1.0, -1.0, -0.5049737412040376, -1.0, -1.0, -1.0,
            -0.3604481739923152, -1.0, -1.0, -0.3764883860788214, -1.0, -1.0,
            -1.0, -0.9796036195304532, -0.7935513623790754, -0.4880899543555701,
            -1.0, -1.0, -1.0, -1.0, -1.0, -0.06955956862215716, -1.0, -1.0,
            -0.9943549382257916, -1.0, -0.9093105655834473, -1.0, -1.0, -1.0,
            -0.959538326485444, -0.09018015160096478, -1.0, -1.0, -1.0,
            -0.9834004233163612, -1.0, -1.0, -0.025769464156420306, -1.0, -1.0,
            -1.0, -0.9975520403673716, -0.9930682257463449, -0.6822294518807557,
            -0.19520832795700704, -0.9197741907209602, -1.0,
            -0.8410687469551232, -1.0, -1.0, -0.9737667349019845,
            -0.5740153082819723, -0.8386505512671123, -0.9559546083928095,
            -0.9333906130210027, -0.1858803184178766, -1.0,
            -0.046523686102560795, -1.0, -0.9984458385556095,
            -0.3443622757534568, -0.8786920904831582,
          ],
        ],
      };

      var testx = [
        [-2, -1],
        [5, -2],
        [-1, -2],
        [2, -5],
        [-4, 5],
        [-1, 5],
        [-2, 0],
        [-1, 5],
        [-1, 2],
      ];
      var predictions = [0, 1, 0, 1, 1, 1, 0, 1, 0];
      var probs = [
        0.9934976, 0.99677775, 0.99349075, 0.99999448, 0.99999958, 0.98901676,
        0.99349654, 0.98901676, 0.99351481,
      ];

      for (var i = 0; i < testx.length; i++) {
        var predictionResult = service.predict(svmData, testx[i]);
        expect(predictionResult.predictionLabel).toEqual(predictions[i]);
        expect(
          Math.abs(predictionResult.predictionConfidence - probs[i])
        ).toBeLessThan(1e-3);
      }
    });

    it('should give correct labels and confidence estimations', () => {
      // This is classifier data of a pretrained SVM classifier trained
      // on a synthetic dataset created for the purpose of testing SVM
      // prediction service. The classes in the classifier data are equivalent
      // to two non-default answer groups of the training data of some
      // exploration.
      var svmData = {
        classes: [0, 1],
        kernel_params: {
          kernel: 'rbf',
          coef0: 0.0,
          degree: 3,
          gamma: 0.5,
        },
        intercept: [0.04554340162799716],
        n_support: [80, 66],
        probA: [-4.76812258346006],
        support_vectors: [
          [5.0, 0.0],
          [4.0, -2.0],
          [3.0, -4.0],
          [-5.0, 0.0],
          [4.0, -2.0],
          [0.0, 0.0],
          [2.0, -1.0],
          [3.0, 4.0],
          [-4.0, -2.0],
          [-3.0, 4.0],
          [-5.0, 0.0],
          [-4.0, -2.0],
          [0.0, -5.0],
          [-1.0, 1.0],
          [-1.0, -1.0],
          [-1.0, 0.0],
          [4.0, -2.0],
          [-2.0, -4.0],
          [-5.0, 0.0],
          [4.0, -3.0],
          [5.0, 0.0],
          [3.0, -4.0],
          [-5.0, 0.0],
          [5.0, 0.0],
          [0.0, -5.0],
          [-4.0, 3.0],
          [4.0, 1.0],
          [-1.0, 4.0],
          [4.0, 2.0],
          [1.0, 4.0],
          [4.0, 3.0],
          [-1.0, 4.0],
          [5.0, 0.0],
          [0.0, -1.0],
          [2.0, 0.0],
          [-4.0, -3.0],
          [-2.0, -4.0],
          [4.0, 3.0],
          [-4.0, -1.0],
          [0.0, 5.0],
          [4.0, 3.0],
          [-2.0, -1.0],
          [3.0, -4.0],
          [5.0, 0.0],
          [4.0, 1.0],
          [5.0, 0.0],
          [0.0, -2.0],
          [2.0, 1.0],
          [0.0, -5.0],
          [1.0, -4.0],
          [0.0, 2.0],
          [-4.0, 3.0],
          [-2.0, 0.0],
          [4.0, 3.0],
          [-3.0, 4.0],
          [-1.0, -2.0],
          [1.0, -2.0],
          [1.0, -4.0],
          [1.0, -1.0],
          [1.0, 2.0],
          [-4.0, 1.0],
          [-3.0, 4.0],
          [-3.0, -4.0],
          [3.0, -4.0],
          [-1.0, 2.0],
          [3.0, 4.0],
          [-4.0, 1.0],
          [-1.0, -4.0],
          [1.0, 1.0],
          [-4.0, 3.0],
          [0.0, 5.0],
          [2.0, 4.0],
          [-2.0, 1.0],
          [0.0, -5.0],
          [-3.0, -4.0],
          [-4.0, 3.0],
          [1.0, 4.0],
          [-4.0, -3.0],
          [0.0, -5.0],
          [-3.0, 4.0],
          [1.0, 5.0],
          [5.0, 1.0],
          [1.0, -5.0],
          [1.0, 5.0],
          [1.0, 5.0],
          [-5.0, 1.0],
          [-4.0, -4.0],
          [-1.0, -5.0],
          [-1.0, -5.0],
          [4.0, 4.0],
          [-5.0, 1.0],
          [1.0, -5.0],
          [-5.0, 3.0],
          [1.0, -5.0],
          [-5.0, 3.0],
          [5.0, 1.0],
          [4.0, -4.0],
          [-4.0, 4.0],
          [-3.0, 5.0],
          [5.0, 1.0],
          [-5.0, 1.0],
          [-4.0, -4.0],
          [-4.0, 4.0],
          [1.0, -5.0],
          [-1.0, 5.0],
          [4.0, -4.0],
          [4.0, -4.0],
          [5.0, -3.0],
          [-1.0, 5.0],
          [5.0, -1.0],
          [5.0, -1.0],
          [4.0, 4.0],
          [-4.0, 4.0],
          [-4.0, -4.0],
          [-2.0, 5.0],
          [1.0, -5.0],
          [5.0, 1.0],
          [4.0, 4.0],
          [4.0, 4.0],
          [-5.0, 1.0],
          [-1.0, 5.0],
          [3.0, 5.0],
          [5.0, -1.0],
          [-5.0, -1.0],
          [5.0, 1.0],
          [3.0, 5.0],
          [5.0, 1.0],
          [3.0, -5.0],
          [-5.0, -2.0],
          [5.0, 3.0],
          [5.0, -1.0],
          [-5.0, -3.0],
          [-5.0, -1.0],
          [3.0, -5.0],
          [-5.0, 1.0],
          [5.0, 3.0],
          [-3.0, -5.0],
          [-4.0, 4.0],
          [5.0, -1.0],
          [2.0, 5.0],
          [-1.0, 5.0],
          [-5.0, -5.0],
          [-3.0, 5.0],
          [-5.0, -2.0],
          [-2.0, -5.0],
          [-2.0, -5.0],
        ],
        probB: [-0.26830931608536374],
        dual_coef: [
          [
            1.0, 0.17963792804729697, 0.403550660516519, 1.0, 1.0,
            0.2174320339900639, 0.32237125746964795, 1.0, 0.23406746659599886,
            0.13107690381219206, 1.0, 1.0, 0.20357365261524915,
            0.3806808376092491, 0.07231536087203701, 0.052444785344018065,
            0.9373454934508193, 0.2887075426898694, 1.0, 1.0, 1.0, 1.0, 1.0,
            1.0, 1.0, 0.09794130691741577, 0.9371448493098987,
            0.8020377139435809, 0.10523314152848777, 0.7208368125926214,
            0.2745796264317118, 1.0, 1.0, 0.1341948623940462, 0.184427006132661,
            1.0, 1.0, 1.0, 0.23220617815321498, 1.0, 1.0, 0.34100990964941563,
            1.0, 1.0, 1.0, 1.0, 0.19085781720747444, 0.10622693983611159, 1.0,
            0.9733190570902237, 0.07976440321906088, 1.0, 0.18011727003205402,
            1.0, 1.0, 0.32629540304776156, 0.10535962914306607, 1.0,
            0.23429959940904435, 0.2026193359451537, 1.0, 1.0, 1.0, 1.0,
            0.20070585346407077, 1.0, 0.999293299134111, 0.1464060764902667,
            0.3890080385472037, 1.0, 1.0, 0.2940723495632226,
            0.10157585440791363, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
            0.9896220351629882, -1.0, -0.3138343321080787, -0.5646722753216991,
            -1.0, -1.0, -0.5049737412040376, -1.0, -1.0, -1.0,
            -0.3604481739923152, -1.0, -1.0, -0.3764883860788214, -1.0, -1.0,
            -1.0, -0.9796036195304532, -0.7935513623790754, -0.4880899543555701,
            -1.0, -1.0, -1.0, -1.0, -1.0, -0.06955956862215716, -1.0, -1.0,
            -0.9943549382257916, -1.0, -0.9093105655834473, -1.0, -1.0, -1.0,
            -0.959538326485444, -0.09018015160096478, -1.0, -1.0, -1.0,
            -0.9834004233163612, -1.0, -1.0, -0.025769464156420306, -1.0, -1.0,
            -1.0, -0.9975520403673716, -0.9930682257463449, -0.6822294518807557,
            -0.19520832795700704, -0.9197741907209602, -1.0,
            -0.8410687469551232, -1.0, -1.0, -0.9737667349019845,
            -0.5740153082819723, -0.8386505512671123, -0.9559546083928095,
            -0.9333906130210027, -0.1858803184178766, -1.0,
            -0.046523686102560795, -1.0, -0.9984458385556095,
            -0.3443622757534568, -0.8786920904831582,
          ],
        ],
      };
      spyOn(console, 'error');

      var testx = [-2];

      service.predict(svmData, testx);

      expect(console.error).toHaveBeenCalledWith(
        'Dimension of support vectors and given input is different.'
      );
    });
  });
});