/*
 * Decompiled with CFR 0.152.
 */
package com.hankcs.hanlp.dependency.nnparser;

import com.hankcs.hanlp.dependency.nnparser.Matrix;
import com.hankcs.hanlp.dependency.nnparser.option.LearnOption;
import com.hankcs.hanlp.dependency.nnparser.util.Log;
import java.util.List;
import java.util.Map;

/**
 * One-hidden-layer feed-forward classifier used by the {@code nnparser} dependency parser.
 * <p>
 * Scoring works as follows. Each input attribute (an object id) is looked up as a column of
 * the embedding matrix {@code E}. That column is projected through the matching column slice
 * of {@code W1`}, and the projections are summed with the bias {@code b1}. The sum is passed
 * through {@code Matrix.cube()} and projected to class scores by {@code W2}. For feature ids
 * listed in {@code precomputation_id_encoder}, the projection is replaced by a cached column
 * of {@code saved}.
 */
public class NeuralNetworkClassifier {
    /** Hidden-layer weights: hidden_layer_size x (embedding_size * nr_feature_types). */
    Matrix W1;
    /** Output-layer weights: nr_classes x hidden_layer_size. */
    Matrix W2;
    /** Embedding table: embedding_size x nr_objects; column k is the embedding of object k. */
    Matrix E;
    /** Hidden-layer bias: hidden_layer_size x 1. */
    Matrix b1;
    /** Cached hidden-layer contributions, one column per encoded (precomputed) feature id. */
    Matrix saved;
    // Gradient buffers, created with the same shapes as the parameters they mirror.
    Matrix grad_W1;
    Matrix grad_W2;
    Matrix grad_E;
    Matrix grad_b1;
    Matrix grad_saved;
    // NOTE(review): presumably AdaGrad squared-gradient histories (cf. ada_eps / ada_alpha);
    // only their zero-initialisation is visible here -- confirm against the trainer.
    Matrix eg2W1;
    Matrix eg2W2;
    Matrix eg2E;
    Matrix eg2b1;
    double loss;
    double accuracy;
    int embedding_size;
    int hidden_layer_size;
    int nr_objects;
    int nr_feature_types;
    int nr_classes;
    int batch_size;
    int nr_threads;
    boolean fix_embeddings;
    double dropout_probability;
    double lambda;
    double ada_eps;
    double ada_alpha;
    /** Maps a feature id (see {@link #score}) to its column index in {@code saved}. */
    Map<Integer, Integer> precomputation_id_encoder;
    boolean initialized = false;

    /**
     * Randomly initialises all parameters and builds the precomputation index.
     * <p>
     * A second call logs an error and leaves the classifier unchanged.
     *
     * @param _nr_objects          number of objects, i.e. the number of columns of {@code E}
     * @param _nr_classes          number of output classes, i.e. the number of rows of {@code W2}
     * @param _nr_feature_types    number of feature slots per example
     * @param opt                  hyper-parameters (layer sizes, batch size, AdaGrad settings, ...)
     * @param embeddings           rows of the form {@code [objectId, v0, v1, ...]}; each row
     *                             overwrites the first entries of column {@code objectId} of {@code E}
     * @param precomputed_features feature ids that receive a column in {@code saved}; a repeated
     *                             id keeps the column assigned at its first occurrence
     */
    void initialize(int _nr_objects, int _nr_classes, int _nr_feature_types, LearnOption opt, List<List<Double>> embeddings, List<Integer> precomputed_features) {
        if (this.initialized) {
            Log.ERROR_LOG("classifier: weight should not be initialized twice!", new Object[0]);
            return;
        }
        this.batch_size = opt.batch_size;
        this.fix_embeddings = opt.fix_embeddings;
        this.dropout_probability = opt.dropout_probability;
        this.lambda = opt.lambda;
        this.ada_eps = opt.ada_eps;
        this.ada_alpha = opt.ada_alpha;
        this.nr_feature_types = _nr_feature_types;
        this.nr_objects = _nr_objects;
        this.nr_classes = _nr_classes;
        this.embedding_size = opt.embedding_size;
        this.hidden_layer_size = opt.hidden_layer_size;

        // Weights are scaled by sqrt(6 / (rows + cols)). If Matrix.random draws from [-1, 1],
        // this is Glorot/Xavier uniform initialisation -- TODO confirm Matrix.random's range.
        // The call order (W1, b1, W2, E) is kept in case Matrix.random shares an RNG.
        int inputWidth = this.embedding_size * this.nr_feature_types;
        double hiddenScale = Math.sqrt(6.0 / (double) (this.hidden_layer_size + inputWidth));
        this.W1 = Matrix.random(this.hidden_layer_size, inputWidth).times(hiddenScale);
        // b1 deliberately reuses W1's scale factor.
        this.b1 = Matrix.random(this.hidden_layer_size, 1).times(hiddenScale);
        double outputScale = Math.sqrt(6.0 / (double) (_nr_classes + this.hidden_layer_size));
        this.W2 = Matrix.random(_nr_classes, this.hidden_layer_size).times(outputScale);
        this.E = Matrix.random(this.embedding_size, _nr_objects).times(opt.init_range);

        // Overwrite random embeddings with the supplied ones: element 0 is the object id,
        // and the remaining elements fill that object's column from row 0 downwards.
        for (List<Double> embedding : embeddings) {
            int id = embedding.get(0).intValue();
            for (int j = 1; j < embedding.size(); ++j) {
                this.E.set(j - 1, id, embedding.get(j));
            }
        }

        this.grad_W1 = Matrix.zero(this.W1.rows(), this.W1.cols());
        this.grad_b1 = Matrix.zero(this.b1.rows(), 1);
        this.grad_W2 = Matrix.zero(this.W2.rows(), this.W2.cols());
        this.grad_E = Matrix.zero(this.E.rows(), this.E.cols());

        // Assign dense column indices 0..k-1 to the distinct precomputed feature ids.
        // Only the first occurrence of an id claims a column. Re-ranking duplicates would
        // leave gaps and produce indices >= encoder.size(), i.e. outside the width of `saved`.
        // NOTE(review): assumes the encoder passed to the constructor is empty here.
        Map<Integer, Integer> encoder = this.precomputation_id_encoder;
        int rank = 0;
        for (int fid : precomputed_features) {
            if (!encoder.containsKey(fid)) {
                encoder.put(fid, rank++);
            }
        }
        this.saved = Matrix.zero(this.hidden_layer_size, encoder.size());
        this.grad_saved = Matrix.zero(this.hidden_layer_size, encoder.size());

        this.initialize_gradient_histories();
        this.initialized = true;
        this.info();
        Log.INFO_LOG("classifier: size of batch = %d", this.batch_size);
        Log.INFO_LOG("classifier: alpha = %e", this.ada_alpha);
        Log.INFO_LOG("classifier: eps = %e", this.ada_eps);
        Log.INFO_LOG("classifier: lambda = %e", this.lambda);
        Log.INFO_LOG("classifier: fix embedding = %s", this.fix_embeddings ? "true" : "false");
    }

    /** Resets the eg2* history matrices to zeros shaped like W1, b1, W2 and E. */
    void initialize_gradient_histories() {
        this.eg2W1 = Matrix.zero(this.W1.rows(), this.W1.cols());
        this.eg2b1 = Matrix.zero(this.b1.rows(), 1);
        this.eg2W2 = Matrix.zero(this.W2.rows(), this.W2.cols());
        this.eg2E = Matrix.zero(this.E.rows(), this.E.cols());
    }

    /**
     * Wraps already-built parameter matrices, for example ones loaded from a model.
     * <p>
     * All size fields start at zero; call {@link #canonical()} to derive them from the
     * matrix shapes before scoring.
     *
     * @param _W1     hidden-layer weights
     * @param _W2     output-layer weights
     * @param _E      embedding table
     * @param _b1     hidden-layer bias
     * @param _saved  cached hidden-layer contributions, one column per encoded feature
     * @param encoder feature id -> column of {@code _saved}; stored by reference, not copied
     */
    NeuralNetworkClassifier(Matrix _W1, Matrix _W2, Matrix _E, Matrix _b1, Matrix _saved, Map<Integer, Integer> encoder) {
        this.W1 = _W1;
        this.W2 = _W2;
        this.E = _E;
        this.b1 = _b1;
        this.saved = _saved;
        this.precomputation_id_encoder = encoder;
        this.embedding_size = 0;
        this.hidden_layer_size = 0;
        this.nr_objects = 0;
        this.nr_feature_types = 0;
        this.nr_classes = 0;
    }

    /**
     * Computes the raw output scores for one example.
     *
     * @param attributes object id for each feature slot; slot {@code i} uses columns
     *                   {@code [i * embedding_size, (i + 1) * embedding_size)} of {@code W1}
     * @param retval     cleared, then filled with {@code nr_classes} scores
     */
    void score(List<Integer> attributes, List<Double> retval) {
        Map<Integer, Integer> encoder = this.precomputation_id_encoder;
        Matrix hidden_layer = Matrix.zero(this.hidden_layer_size, 1);
        for (int i = 0, off = 0; i < attributes.size(); ++i, off += this.embedding_size) {
            int aid = attributes.get(i);
            // Combine (object id, slot index) into the key used by the precomputation encoder.
            int fid = aid * this.nr_feature_types + i;
            Integer rep = encoder.get(fid);
            if (rep != null) {
                // A cached column stands in for W1_slice * E[:, aid].
                hidden_layer.plusEquals(this.saved.col(rep));
            } else {
                hidden_layer.plusEquals(this.W1.block(0, off, this.hidden_layer_size, this.embedding_size).times(this.E.col(aid)));
            }
        }
        hidden_layer.plusEquals(this.b1);
        // NOTE(review): Matrix.cube() is presumably an element-wise cube activation -- confirm.
        Matrix output = this.W2.times(new Matrix(hidden_layer.cube()));
        retval.clear();
        for (int k = 0; k < this.nr_classes; ++k) {
            retval.add(output.get(k, 0));
        }
    }

    /** @return the value stored in {@code loss} */
    double get_cost() {
        return this.loss;
    }

    /** @return the value stored in {@code accuracy} */
    double get_accuracy() {
        return this.accuracy;
    }

    /**
     * Derives the size fields from the shapes of the parameter matrices.
     * Intended for classifiers built via the constructor, whose size fields start at zero.
     */
    void canonical() {
        this.hidden_layer_size = this.b1.rows();
        this.nr_feature_types = this.W1.cols() / this.E.rows();
        this.nr_classes = this.W2.rows();
        this.embedding_size = this.E.rows();
        // E is embedding_size x nr_objects (see initialize), so its width is the object count.
        this.nr_objects = this.E.cols();
    }

    /** Logs matrix shapes and size fields at INFO level. */
    void info() {
        Log.INFO_LOG("classifier: E(%d,%d)", this.E.rows(), this.E.cols());
        Log.INFO_LOG("classifier: W1(%d,%d)", this.W1.rows(), this.W1.cols());
        Log.INFO_LOG("classifier: b1(%d)", this.b1.rows());
        Log.INFO_LOG("classifier: W2(%d,%d)", this.W2.rows(), this.W2.cols());
        Log.INFO_LOG("classifier: saved(%d,%d)", this.saved.rows(), this.saved.cols());
        Log.INFO_LOG("classifier: precomputed size=%d", this.precomputation_id_encoder.size());
        Log.INFO_LOG("classifier: hidden layer size=%d", this.hidden_layer_size);
        Log.INFO_LOG("classifier: embedding size=%d", this.embedding_size);
        Log.INFO_LOG("classifier: number of classes=%d", this.nr_classes);
        Log.INFO_LOG("classifier: number of feature types=%d", this.nr_feature_types);
    }
}

