guimc233/lgz-bot

View on GitHub
src/main/java/huzpsb/ll4j/model/Model.java

Summary

Maintainability
D
2 days
Test Coverage
package huzpsb.ll4j.model;

import huzpsb.ll4j.layer.*;
import huzpsb.ll4j.utils.data.DataEntry;
import huzpsb.ll4j.utils.data.DataSet;
import huzpsb.ll4j.utils.pair.Pair;
import huzpsb.ll4j.utils.random.NRandom;
import huzpsb.ll4j.utils.random.RandomSeedGenerator;
import org.jetbrains.annotations.NotNull;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Scanner;

/**
 * A sequential model: an ordered array of {@link AbstractLayer}s evaluated front to back, whose
 * final layer is expected to be a {@link JudgeLayer} exposing the predicted class index in its
 * {@code result} field.
 *
 * <p>Every pass overwrites fields on the shared layer objects ({@code training}, {@code input},
 * {@code output_error}, ...), so concurrent calls on the same model would interfere with each
 * other.
 */
@SuppressWarnings({"unused"})
public class Model {
    /** Message thrown whenever the layer stack does not end with a {@link JudgeLayer}. */
    private static final String NO_OUTPUT_LAYER = "Last layer is not output layer";

    /** Learning rate used by {@link #trainOn(DataSet)} when the caller does not supply one. */
    private static final double DEFAULT_LEARNING_RATE = 8e-7;

    /** The layer stack, in evaluation order. */
    public AbstractLayer[] layers;

    /**
     * Creates a model over the given layers. The array is stored as-is (no copy).
     *
     * @param layers layers in evaluation order
     */
    public Model(AbstractLayer... layers) {
        this.layers = layers;
    }

    /**
     * Trains for one pass over {@code dataSet} with the default learning rate ({@code 8e-7}).
     *
     * @param dataSet entries to train on
     * @return pair of (correct, wrong) prediction counts, measured before each entry's update
     * @throws RuntimeException if the last layer is not a {@link JudgeLayer}
     */
    public Pair<Integer, Integer> trainOn(DataSet dataSet) {
        return trainOn(dataSet, DEFAULT_LEARNING_RATE);
    }

    /**
     * Trains for one pass over {@code dataSet}, updating every layer after each entry.
     *
     * @param dataSet      entries to train on
     * @param learningRate value handed to each layer's {@code update}
     * @return pair of (correct, wrong) prediction counts, measured before each entry's update
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public Pair<Integer, Integer> trainOn(DataSet dataSet, double learningRate) {
        AbstractLayer[] stack = this.layers;
        JudgeLayer judgeLayer = requireJudgeLayer(stack);
        int correct = 0;
        int wrong = 0;
        for (DataEntry dataEntry : dataSet.split) {
            if (trainStep(stack, judgeLayer, dataEntry, learningRate)) {
                correct++;
            } else {
                wrong++;
            }
        }
        return Pair.create(correct, wrong);
    }

    /**
     * Trains for one pass over {@code dataSet} like {@link #trainOn(DataSet, double)}, but reports
     * the misclassified entries instead of counts ("WA" presumably stands for "wrong answer").
     *
     * @param dataSet      entries to train on
     * @param learningRate value handed to each layer's {@code update}
     * @return the entries whose prediction differed from their label before the update, in
     *         iteration order
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public List<DataEntry> trainAndGetWA(DataSet dataSet, double learningRate) {
        AbstractLayer[] stack = this.layers;
        JudgeLayer judgeLayer = requireJudgeLayer(stack);
        List<DataEntry> misclassified = new ArrayList<>();
        for (DataEntry dataEntry : dataSet.split) {
            if (!trainStep(stack, judgeLayer, dataEntry, learningRate)) {
                misclassified.add(dataEntry);
            }
        }
        return misclassified;
    }

    /**
     * Evaluates the model on {@code dataSet} without updating any layer.
     *
     * @param dataSet entries to evaluate
     * @return pair of (correct, wrong) prediction counts
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public Pair<Integer, Integer> testOn(DataSet dataSet) {
        AbstractLayer[] stack = this.layers;
        JudgeLayer judgeLayer = requireJudgeLayer(stack);
        int correct = 0;
        int wrong = 0;
        for (DataEntry dataEntry : dataSet.split) {
            forward(stack, dataEntry.values, false);
            int predicted = judgeLayer.result;
            int actual = dataEntry.type;
            if (predicted == actual) {
                correct++;
            } else {
                wrong++;
            }
        }
        return Pair.create(correct, wrong);
    }

    /**
     * Evaluates the model on {@code dataSet} without updating any layer and collects the
     * misclassified entries.
     *
     * @param dataSet entries to evaluate
     * @return the entries whose prediction differed from their label, in iteration order
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public List<DataEntry> testAndGetWA(DataSet dataSet) {
        AbstractLayer[] stack = this.layers;
        JudgeLayer judgeLayer = requireJudgeLayer(stack);
        List<DataEntry> misclassified = new ArrayList<>();
        for (DataEntry dataEntry : dataSet.split) {
            forward(stack, dataEntry.values, false);
            int predicted = judgeLayer.result;
            int actual = dataEntry.type;
            if (predicted != actual) {
                misclassified.add(dataEntry);
            }
        }
        return misclassified;
    }

    /**
     * Runs a non-training forward pass and returns the judge layer's result.
     *
     * @param input values fed to the first layer
     * @return the class index reported by the final {@link JudgeLayer}
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public int predict(double[] input) {
        return classify(this.layers, input).result;
    }

    /**
     * Runs a non-training forward pass and returns the predicted class together with the value
     * that was fed into the judge layer at that class index.
     *
     * @param input values fed to the first layer
     * @return (predicted class index, judge-layer input value at that index)
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public kotlin.Pair<Integer, Double> predictDebug(double[] input) {
        AbstractLayer[] stack = this.layers;
        JudgeLayer judgeLayer = classify(stack, input);
        return new kotlin.Pair<>(judgeLayer.result, stack[stack.length - 1].input[judgeLayer.result]);
    }

    /**
     * Runs a non-training forward pass and returns the full vector that was fed into the judge
     * layer. The returned array is the layer's own {@code input} reference, not a copy.
     *
     * @param input values fed to the first layer
     * @return the judge layer's input array
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public double @NotNull [] predictAllResult(double[] input) {
        AbstractLayer[] stack = this.layers;
        classify(stack, input);
        return stack[stack.length - 1].input;
    }

    /**
     * Predicts the class of a data entry from its {@code values}.
     *
     * @param de entry to classify
     * @return the class index reported by the final {@link JudgeLayer}
     * @throws RuntimeException if the layer stack is empty or its last layer is not a
     *                          {@link JudgeLayer}
     */
    public int predict(DataEntry de) {
        return predict(de.values);
    }

    /**
     * Writes {@link #toScript()} to {@code path} as UTF-8, replacing any existing file.
     *
     * @param path destination file path
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be opened or
     *                          written
     */
    public void save(String path) {
        // Serialize before opening the file so a failure here cannot leave a truncated file behind.
        String script = toScript();
        try (PrintWriter pw = new PrintWriter(path, StandardCharsets.UTF_8)) {
            pw.print(script);
            // PrintWriter never throws on write; its error flag is the only failure signal.
            if (pw.checkError()) {
                throw new IOException("Failed to write model to " + path);
            }
        } catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Builds a model from a reader containing one serialized layer per line. Blank lines are
     * skipped. The reader is not closed.
     *
     * @param reader source of script lines
     * @return the parsed model
     */
    public static Model read(BufferedReader reader) {
        Builder builder = new Builder();
        reader.lines().forEachOrdered(line -> appendLayer(builder, line));
        return builder.build();
    }

    /**
     * Builds a model from a scanner, consuming it line by line until exhausted. Blank lines are
     * skipped. The scanner is not closed.
     *
     * @param scanner source of script lines
     * @return the parsed model
     */
    public static Model read(Scanner scanner) {
        Builder builder = new Builder();
        while (scanner.hasNextLine()) {
            appendLayer(builder, scanner.nextLine());
        }
        return builder.build();
    }

    /**
     * Builds a model from a script string containing one serialized layer per line. Any of
     * {@code \n}, {@code \r\n} or {@code \r} is accepted as a line terminator; blank lines are
     * skipped.
     *
     * @param script the serialized model
     * @return the parsed model
     */
    public static Model read(String script) {
        Builder builder = new Builder();
        script.lines().forEachOrdered(line -> appendLayer(builder, line));
        return builder.build();
    }

    /**
     * Loads a model from a UTF-8 text file (the encoding {@link #save(String)} writes).
     *
     * @param path file to read
     * @return the parsed model
     * @throws RuntimeException wrapping the {@link IOException} if the file cannot be read
     */
    public static Model readFrom(String path) {
        try (BufferedReader reader = new BufferedReader(new FileReader(path, StandardCharsets.UTF_8))) {
            return read(reader);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Parses one serialized layer. The first whitespace-separated token selects the layer type
     * ({@code D}, {@code L}, {@code S}, {@code Dropout}, {@code Softmax}, {@code Th}, {@code J});
     * the remaining tokens are that layer's parameters. Surrounding whitespace is ignored.
     *
     * @param line one line of a model script
     * @return the layer described by {@code line}
     * @throws AssertionError if the type token is not one of the known tags
     */
    public static AbstractLayer parseLine(String line) {
        String[] tokens = line.trim().split("\\s+");
        switch (tokens[0]) {
            case "D":
                return parseDense(tokens);
            case "L":
                return new LeakyRelu(Integer.parseInt(tokens[1]));
            case "S":
                return new Sigmoid(Integer.parseInt(tokens[1]));
            case "Dropout":
                return new DropoutLayer(Integer.parseInt(tokens[1])).dropout(Double.parseDouble(tokens[2]));
            case "Softmax":
                return new Softmax(Integer.parseInt(tokens[1]));
            case "Th":
                return new Tanh(Integer.parseInt(tokens[1]));
            case "J":
                return new JudgeLayer(Integer.parseInt(tokens[1]));
            default:
                throw new AssertionError("Unknown layer type: " + tokens[0]);
        }
    }

    /**
     * Serializes the model by letting every layer append itself to one buffer, in order.
     *
     * @return the model script
     */
    public String toScript() {
        StringBuilder sb = new StringBuilder();
        for (AbstractLayer layer : layers) {
            layer.serialize(sb);
        }
        return sb.toString();
    }

    /**
     * Returns a model without the layers marked {@link AbstractLayer.TrainOnlyLayer}.
     *
     * <p>The remaining layer objects are shared with this model, and because they are re-added
     * through a {@link Builder} each of them gets a fresh {@code random} assigned.
     *
     * @return a new model over the retained layers
     */
    public Model clean() {
        Builder builder = new Builder();
        for (AbstractLayer layer : layers) {
            if (layer instanceof AbstractLayer.TrainOnlyLayer) {
                continue;
            }
            builder.layer(layer);
        }
        return builder.build();
    }

    /** Returns the trailing judge layer of {@code stack}, or throws if there is none. */
    private static JudgeLayer requireJudgeLayer(AbstractLayer[] stack) {
        if (stack.length == 0 || !(stack[stack.length - 1] instanceof JudgeLayer)) {
            throw new IllegalStateException(NO_OUTPUT_LAYER);
        }
        return (JudgeLayer) stack[stack.length - 1];
    }

    /** Feeds {@code input} through every layer in order, wiring each output into the next input. */
    private static void forward(AbstractLayer[] stack, double[] input, boolean training) {
        double[] signal = input;
        for (AbstractLayer layer : stack) {
            layer.training = training;
            layer.input = signal;
            layer.forward();
            signal = layer.output;
        }
    }

    /** Runs backward + update on every layer from last to first, passing errors down the stack. */
    private static void backward(AbstractLayer[] stack, double learningRate) {
        for (int i = stack.length - 1; i >= 0; i--) {
            AbstractLayer layer = stack[i];
            layer.backward();
            layer.update(learningRate);
            if (i > 0) {
                stack[i - 1].output_error = layer.input_error;
            }
        }
    }

    /** Trains on a single entry and reports whether the pre-update prediction matched its label. */
    private static boolean trainStep(AbstractLayer[] stack, JudgeLayer judgeLayer,
                                     DataEntry dataEntry, double learningRate) {
        forward(stack, dataEntry.values, true);
        int predicted = judgeLayer.result;
        int actual = dataEntry.type;
        // The label is written into the judge layer before the backward pass; presumably
        // JudgeLayer.backward() reads it as the training target.
        judgeLayer.result = actual;
        backward(stack, learningRate);
        return predicted == actual;
    }

    /** Validates the stack, runs a non-training forward pass and returns the judge layer. */
    private static JudgeLayer classify(AbstractLayer[] stack, double[] input) {
        JudgeLayer judgeLayer = requireJudgeLayer(stack);
        forward(stack, input, false);
        return judgeLayer;
    }

    /** Parses {@code line} into a layer and adds it to {@code builder}; blank lines are ignored. */
    private static void appendLayer(Builder builder, String line) {
        if (!line.isBlank()) {
            builder.layer(parseLine(line));
        }
    }

    /** Parses a {@code D <in> <out> <weights...>} line; weights are listed row by row. */
    private static DenseLayer parseDense(String[] tokens) {
        int inputSize = Integer.parseInt(tokens[1]);
        int outputSize = Integer.parseInt(tokens[2]);
        DenseLayer denseLayer = new DenseLayer(inputSize, outputSize);
        double[][] weights = denseLayer.weights;
        int next = 3; // first weight token
        for (int i = 0; i < inputSize; i++) {
            for (int j = 0; j < outputSize; j++) {
                weights[i][j] = Double.parseDouble(tokens[next++]);
            }
        }
        return denseLayer;
    }

    /**
     * Collects layers and produces a {@link Model}. Each added layer gets a new {@link NRandom}
     * seeded from the configured {@link RandomSeedGenerator}
     * ({@code RandomSeedGenerator.Builtin.CONSTANT} unless changed).
     */
    @SuppressWarnings({"UnusedReturnValue", "unused"})
    public static class Builder {
        private List<AbstractLayer> layers = new ArrayList<>();
        private RandomSeedGenerator randomGenerator = RandomSeedGenerator.Builtin.CONSTANT;

        public Builder() {}

        /**
         * Appends a layer and assigns it a freshly seeded random source.
         *
         * @param layer layer to add; {@code null} is ignored
         * @return this builder
         */
        public Builder layer(AbstractLayer layer) {
            if (layer != null) {
                layer.random = new NRandom(randomGenerator.generateSeed());
                layers.add(layer);
            }
            return this;
        }

        /**
         * Sets the seed source used for layers added from now on.
         *
         * @param randomGenerator seed generator
         * @return this builder
         */
        public Builder setRandomGenerator(RandomSeedGenerator randomGenerator) {
            this.randomGenerator = randomGenerator;
            return this;
        }

        /** @return the seed generator currently in use */
        public RandomSeedGenerator getRandomGenerator() {
            return randomGenerator;
        }

        /** @return the live list of layers collected so far (not a copy) */
        public List<AbstractLayer> getLayers() {
            return layers;
        }

        /**
         * Replaces the collected layers with the given list. The list is stored as-is and its
         * layers are not re-seeded.
         *
         * @param layers new backing list
         * @return this builder
         */
        public Builder setLayers(List<AbstractLayer> layers) {
            this.layers = layers;
            return this;
        }

        /** @return a model over a snapshot array of the collected layers */
        public Model build() {
            return new Model(layers.toArray(AbstractLayer[]::new));
        }
    }


    /**
     * A {@link Builder} that tracks the current layer width, so activation layers can be added
     * without repeating sizes. The width starts at the constructor's {@code inputSize} and
     * changes only when {@link #DenseLayer(int)} is called; layers added directly through
     * {@link #layer(AbstractLayer)} leave it untouched.
     */
    @SuppressWarnings({"UnusedReturnValue", "unused"})
    public static class ContextBuilder extends Builder {
        private int size;
        private boolean hasOutput = false;

        /**
         * @param inputSize width of the model input
         */
        public ContextBuilder(int inputSize) {
            this.size = inputSize;
        }

        /**
         * Appends a layer. Adding a {@link JudgeLayer} this way also marks the model as having an
         * output layer, so {@link #build()} will not append a second one.
         *
         * @param layer layer to add; {@code null} is ignored
         * @return this builder
         */
        @Override
        public ContextBuilder layer(AbstractLayer layer) {
            if (layer instanceof JudgeLayer) {
                hasOutput = true;
            }
            super.layer(layer);
            return this;
        }

        /**
         * Appends a dense layer from the current width to {@code outputSize} and makes
         * {@code outputSize} the current width.
         */
        public ContextBuilder DenseLayer(int outputSize) {
            layer(new DenseLayer(size, outputSize));
            size = outputSize;
            return this;
        }

        /** Appends a {@link LeakyRelu} of the current width. */
        public ContextBuilder LeakyReluLayer() {
            return layer(new LeakyRelu(size));
        }

        /** Appends a {@link Sigmoid} of the current width. */
        public ContextBuilder SigmoidLayer() {
            return layer(new Sigmoid(size));
        }

        /** Appends a {@link Softmax} of the current width. */
        public ContextBuilder SoftmaxLayer() {
            return layer(new Softmax(size));
        }

        /** Appends a {@link Tanh} of the current width. */
        public ContextBuilder TanhLayer() {
            return layer(new Tanh(size));
        }

        /**
         * Appends a {@link DropoutLayer} of the current width.
         *
         * @param dropout value passed to the layer's {@code dropout(..)} setter
         */
        public ContextBuilder DropoutLayer(double dropout) {
            return layer(new DropoutLayer(size).dropout(dropout));
        }

        /** Appends a {@link JudgeLayer} of the current width and records that an output exists. */
        public ContextBuilder JudgeLayer() {
            hasOutput = true;
            return layer(new JudgeLayer(size));
        }

        @Override
        public ContextBuilder setRandomGenerator(RandomSeedGenerator randomGenerator) {
            super.setRandomGenerator(randomGenerator);
            return this;
        }

        /**
         * Builds the model, first appending a {@link JudgeLayer} of the current width if none has
         * been added yet.
         */
        @Override
        public Model build() {
            if (!hasOutput) {
                JudgeLayer();
            }
            return super.build();
        }
    }
}