teco-kit/PointAndControl

View on GitHub
IGS/Classifier/KNNClassifier.cs

Summary

Maintainability
A
0 mins
Test Coverage
using System;
using System.Collections.Generic;
using numl;
using numl.Supervised.KNN;
using numl.Model;


namespace PointAndControl.Classifier
{
    public class KNNClassifier
    {
        public List<WallProjectionSample> trainingSamples { get; set; }
        KNNGenerator generator { get; set; }
        LearningModel learned { get; set; }


            
        public int trainingSetSize { get; set; }
       

        public KNNClassifier(List<WallProjectionSample> list, int trainingsSampleCount)
        {
        
            trainingSamples = new List<WallProjectionSample>();
 
            var descriptor = Descriptor.Create<WallProjectionSample>();
            generator = new KNNGenerator();
            generator.Descriptor = descriptor;
            trainingSetSize = trainingsSampleCount;
            
            if (list != null && list.Count != 0)
            {
                learnBatch(list);
            }
                       
        }




        public void trainClassifier()
        {
           
            if (trainingSamples.Count > 10)
            {
                int k = (int)(Math.Floor(Math.Sqrt(trainingSamples.Count)));
                if (k == 0)
                {
                    k = 1;
                }

                generator.K = k;
                
                learned = Learner.Learn(trainingSamples, 0.90, 1, generator);
            }
            else Console.WriteLine("Please create samples first!");

        }

        public void trainClassifier(List<WallProjectionSample> trainingSet)
        {
            if (trainingSet.Count > 0)
            {
                generator.K = (int)Math.Sqrt(trainingSamples.Count);
                learned = Learner.Learn(trainingSet, 0.99, 1, generator);
            }
        }


        public void trainClassifier(int iterations)
        {
            if (trainingSamples.Count > 0)
            {
                generator.K = (int)Math.Sqrt(trainingSamples.Count);
                learned = Learner.Learn(trainingSamples, 0.99, iterations, generator);
            }
        }
        public WallProjectionSample classify(WallProjectionSample sample)
        {
            if (learned != null)
            {
                return learned.Model.Predict(sample);
            }
            else return null;
        }

        public void learnBatch(List<WallProjectionSample> trainingsSamples)
        {
            
            if (trainingsSamples == null || trainingsSamples.Count == 0) { return; }

            int diff = Math.Abs(trainingSamples.Count - trainingSetSize);

                foreach (WallProjectionSample s in trainingsSamples)
                {
                    trainingSamples.Add(s);

                    if (trainingSetSize > 0 && trainingSamples.Count > trainingSetSize)
                    {
                        

                        while (diff >= 0)
                        {
                            trainingSamples.RemoveAt(0);

                            diff--;
                        }

                    }
                }
                trainClassifier();
        }

        public void addSampleAndLearn(WallProjectionSample sample)
        {

            if (sample == null) { return; }

            trainingSamples.Add(sample);

            if (trainingSetSize > 0 && trainingSamples.Count > trainingSetSize)
            {
                Random r = new Random();
                int half = (int)Math.Floor((double)(trainingSamples.Count / 2));
                int pos = r.Next(half);
                trainingSamples.RemoveAt(pos);
            }
            trainClassifier();
        }    
    }
}