Multi-layer neural network trained with backpropagation reaches only about 86% accuracy: is that normal?

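I recently started experimenting with training a neural network using backpropagation. The network structure is 784-512-10, and I use the sigmoid activation function. When I tested a single-layer network on the MNIST dataset, I got about 90% accuracy. With this multi-layer network I get about 86%. Is that normal? Did I get the backpropagation part wrong?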

Here is my code:

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.*;
import java.nio.file.Files;
import java.nio.file.StandardCopyOption;
import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import java.util.Scanner;

public class NeuralNetwork{
    public static double learningRate = 0.01;
    public static int epoch = 15;
    public static int ROWS = 28;
    public static int COLUMNS = 28;
    public static int INPUT = ROWS * COLUMNS;
    public static int outNum = 10;
    public static int hiddenNum = 512;
    public static double[][] weights2 = new double[outNum][hiddenNum];
    public static double[] bias2 = new double[outNum];
    public static double[][] weights1 = new double[hiddenNum][INPUT];
    public static double[] bias1 = new double[hiddenNum]; // one bias per hidden unit
    private static final double TRAININGSIZE = 10;
    public static double[][] inputs = new double[outNum][INPUT];
    private static final double[][] target = new double[outNum][outNum];

    private static final ArrayList<String> filenames = new ArrayList<>();
    private static final ArrayList<Integer> yetDone = new ArrayList<>();
    public static double[] actual = new double[outNum];

    public static Random rand = new SecureRandom();

    public static Scanner input = new Scanner(System.in);
    public static void main(String[]args) throws Exception {
        System.out.println("1. Learn the network");
        System.out.println("2. Guess a number");
        System.out.println("3. Guess file");
        System.out.println("4. Guess All Numbers");
        System.out.println("5. Guess image");
        switch (input.nextInt()){
            case 1:
                learn();
                break;
            case 2:
                guess();
                break;
            case 3:
                guessFile();
                break;
            case 4:
                guessAll();
                break;
        }
    }

    public static void guessAll() throws IOException,ClassNotFoundException {
        System.out.println("Recognizing...");
        /*
        for(int x = 1; x < 60000; x++){
            filenames.add("data/" + String.format("%05d",x) + ".txt");
        }

        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        Layers lay = (Layers) ois.readObject();
        int correct = 0;
        for (String z : filenames) {
            double[] a = scan(z,0);
            correct += getBestGuess(sigmoid(lay.step(a))) == actual[0] ? 1 : 0;
        }
        System.out.println("Training: " + correct + " / " + filenames.size() + " correct.");
        filenames.clear();

         */

        for(int x = 60000; x < 70000; x++){
            filenames.add("data/" + String.format("%05d",x) + ".txt");
        }

        ObjectInputStream oiss = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network1.ser")));
        Layers lays1 = (Layers) oiss.readObject();
        ObjectInputStream oiss2 = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network2.ser")));
        Layers lays2 = (Layers) oiss2.readObject();
        int corrects = 0;
        for (String z : filenames) {
            double[] a = scan(z,0);
            corrects += getBestGuess(sigmoid(lays2.step(sigmoid(lays1.step(a))))) == actual[0] ? 1 : 0;
        }
        System.out.println("Testing: " + corrects + " / " + filenames.size() + " correct.");
        
        System.out.println("Done!");
    }

    public static void makeList(){
        for(int index = 0; index < TRAININGSIZE; index++){
            int indices = rand.nextInt(yetDone.size()); // any remaining index, including 0
            filenames.add("data/" + String.format("%05d",yetDone.get(indices)) + ".txt");
            yetDone.remove(indices);
        }
        prepareData();
        for(int indices = 0; indices < outNum; indices++) {
            for(int index = 0; index < outNum; index++){
                target[indices][index] = 0;
            }
            target[indices][(int)actual[indices]] = 1;
        }
    }

    public static void prepareData(){
        for(int index = 0; index < outNum; index++){
            try {
                inputs[index] = scan(filenames.get(index),index);
            } catch (FileNotFoundException ex) {
                ex.printStackTrace();
            }
        }
    }

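    public static double[] scan(String filename,int index) throws FileNotFoundException {
        double[] a = new double[INPUT];
        // try-with-resources closes the file after each sample is read
        try (Scanner in = new Scanner(new File(filename))) {
            for(int i = 0; i < INPUT; i++){
                a[i] = in.nextDouble() / 255; // scale pixel values to [0, 1]
            }
            actual[index] = in.nextDouble(); // the last value in the file is the label
        }
        return a;
    }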

    public static void guessFile() throws IOException,ClassNotFoundException {
        System.out.print("Enter Filename: ");
        double[] a = scan(input.next(),0);
        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        Layers lay = (Layers) ois.readObject();
        double[] results = lay.step(a);
        System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
        System.out.println(Arrays.toString(results));
    }

    public static double guess(double[] a) throws IOException,ClassNotFoundException {
        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        Layers lay = (Layers) ois.readObject();
        double[] results = lay.step(a);
        return getBestGuess(sigmoid(results));
    }

    public static void guess() throws IOException,ClassNotFoundException {
        ObjectInputStream ois = new ObjectInputStream(new BufferedInputStream(new FileInputStream("network.ser")));
        System.out.println("Input number: ");
        Layers lay = (Layers) ois.readObject();
        double[] a = new double[INPUT];
        for(int index = 0; index < a.length; index++){
            a[index] = input.nextInt();
        }
        double[] results = lay.step(a);
        System.out.println("This is a " + getBestGuess(sigmoid(results)) + "!");
        System.out.println(Arrays.toString(sigmoid(results)));
    }

    public static void learn() {
        System.out.println("Learning...");
        initialise(weights2,outNum,hiddenNum);
        initialise(bias2);
        initialise(weights1,hiddenNum,INPUT);
        initialise(bias1);

        Layers lay2 = new Layers(weights2,bias2,hiddenNum);
        Layers lay1 = new Layers(weights1,bias1,INPUT);

        double[] result2 = new double[lay2.outNum];
        double[] result1 = new double[lay1.outNum];
        double[] a2;
        double[] a1;
        double cost = 0;
        double sumFinal;

        for(int x = 0; x < epoch; x++) {
            yetDone.clear();
            for(int y = 0; y < 60000; y++){
                yetDone.add(y);
            }

            for (int ind = 0; ind < 200; ind++) {
                filenames.clear();
                makeList();
                for (int n = 0; n < lay2.outNum; n++) {
                    a1 = inputs[n]; //number
                    result1 = sigmoid(lay1.step(a1));
                    a2 = result1;
                    result2 = sigmoid(lay2.step(a2));

                    for (int i = 0; i < lay2.outNum; i++) {
                        for (int j = 0; j < lay2.INPUT; j++) {
                            weights2[i][j] += learningRate * a2[j] * (target[n][i] - result2[i]);
                            cost += Math.pow((target[n][i] - result2[i]),2);
                        }
                    }

                    for(int i = 0; i < lay1.outNum; i++){
                        for(int j = 0; j < lay1.INPUT; j++){
                            sumFinal = 0;
                            for(int k = 0; k < lay2.outNum; k++){
                                // weight * derivSigma(outputHiddenLayer) * 2(out - expected)
                                sumFinal += result1[k] * (1 - result1[k]) * 2 * (result2[k] - target[n][k]); // * weights2[k][i]
                            }
                            weights1[i][j] -= learningRate * a1[j] * sumFinal * result1[i] * (1 - result1[i]);
                        }
                    }
                }
                lay1.update(weights1,bias1);
                lay2.update(weights2,bias2);
            }
            System.out.println("Epoch " + x + ": " + cost);
            cost = 0;
        }
        System.out.println(Arrays.toString(result1));
        System.out.println(Arrays.toString(result2));

        for(double[] arr : inputs) {
            System.out.println("This is a " + getBestGuess(sigmoid(lay2.step(sigmoid(lay1.step(arr))))) + "!");
        }

        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network1.ser")))) {
            oos.writeObject(lay1);
        } catch (IOException ex) {
            ex.printStackTrace();
        }

        try (ObjectOutputStream oos = new ObjectOutputStream(new BufferedOutputStream(new FileOutputStream("network2.ser")))) {
            oos.writeObject(lay2);
        } catch (IOException ex) {
            ex.printStackTrace();
        }

        System.out.println("Done! Saved to file.");
    }

    public static double sigmoid(double x){
        return 1 / (1 + Math.exp(-x));
    }

    public static double[] sigmoid(double[] weights){
        for(int index = 0; index < weights.length; index++){
            weights[index] = sigmoid(weights[index]);
        }
        return weights;
    }

    public static void initialise(double[] bias){
        Random random = new Random();
        for(int index = 0; index < bias.length; index++){
            bias[index] = random.nextGaussian();
        }
    }

    public static void initialise(double[][] weights,int outNum,int INPUT){
        Random random = new Random();
        for(int index = 0; index < outNum; index++){
            for(int indice = 0; indice < INPUT; indice++){
                weights[index][indice] = random.nextGaussian();
            }
        }
    }

    public static int getBestGuess(double[] result){
        double k = Integer.MIN_VALUE;
        double index = 0;
        int current = 0;
        for(double a : result){
            if(k < a){
                k = a;
                index = current;
            }
            current++;
        }

        return (int)index;
    }
}

class Layers implements Serializable {
    private static final long serialVersionUID = 8L;
    double[][] weights;
    double[] bias;
    int outNum;
    int INPUT;

    public Layers(double[][] weights,double[] bias,int INPUT){
        this.weights = weights;
        this.bias = bias;
        this.outNum = weights.length; // one output per weight row (was a self-assignment that left outNum at 0)
        this.INPUT = INPUT;
    }

    public void update(double[][] weights,double[] bias){
        this.weights = weights;
        this.bias = bias;
    }

    public double[] step(double[] aa){
        double[] out = new double[outNum];
        for (int index = 0; index < outNum; index++) {
            for (int indices = 0; indices < INPUT; indices++) {
                out[index] += weights[index][indices] * aa[indices];
            }
        }
        return out;
    }
}

Thanks!
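
One way to check whether a backpropagation step is right, independently of the accuracy figure, is a numerical gradient check. The following is a minimal sketch under stated assumptions, not the code above: `BackpropSketch` and all of its methods are hypothetical names, the network is a tiny 3-4-2 sigmoid network without biases, and the loss is assumed to be the sum of squared errors. It compares the analytic gradient for the first-layer weights with central finite differences.

```java
import java.util.Random;

// Hypothetical helper class for illustration only: a tiny sigmoid network
// without biases, trained on the sum of squared errors.
class BackpropSketch {
    static double sigmoid(double x) { return 1.0 / (1.0 + Math.exp(-x)); }

    // Forward pass for one layer: out[i] = sigmoid(sum_j w[i][j] * in[j]).
    static double[] forward(double[][] w, double[] in) {
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double z = 0;
            for (int j = 0; j < in.length; j++) z += w[i][j] * in[j];
            out[i] = sigmoid(z);
        }
        return out;
    }

    // Loss for one sample: sum_k (out_k - t_k)^2.
    static double loss(double[][] w1, double[][] w2, double[] x, double[] t) {
        double[] out = forward(w2, forward(w1, x));
        double sum = 0;
        for (int k = 0; k < out.length; k++) sum += (out[k] - t[k]) * (out[k] - t[k]);
        return sum;
    }

    // Analytic gradient of the loss with respect to the first-layer weights.
    static double[][] gradW1(double[][] w1, double[][] w2, double[] x, double[] t) {
        double[] h = forward(w1, x);
        double[] out = forward(w2, h);
        double[][] grad = new double[w1.length][x.length];
        for (int i = 0; i < w1.length; i++) {
            // Error reaching hidden unit i: each output delta is weighted by w2[k][i].
            double back = 0;
            for (int k = 0; k < out.length; k++) {
                back += 2 * (out[k] - t[k]) * out[k] * (1 - out[k]) * w2[k][i];
            }
            double deltaHidden = back * h[i] * (1 - h[i]);
            for (int j = 0; j < x.length; j++) grad[i][j] = deltaHidden * x[j];
        }
        return grad;
    }

    // Largest absolute gap between the analytic gradient and central finite differences.
    static double maxGradientError(double[][] w1, double[][] w2, double[] x, double[] t) {
        double eps = 1e-5, worst = 0;
        double[][] grad = gradW1(w1, w2, x, t);
        for (int i = 0; i < w1.length; i++) {
            for (int j = 0; j < x.length; j++) {
                double saved = w1[i][j];
                w1[i][j] = saved + eps;
                double plus = loss(w1, w2, x, t);
                w1[i][j] = saved - eps;
                double minus = loss(w1, w2, x, t);
                w1[i][j] = saved; // restore the weight before moving on
                worst = Math.max(worst, Math.abs((plus - minus) / (2 * eps) - grad[i][j]));
            }
        }
        return worst;
    }

    // Small Gaussian weights (standard deviation 0.5) so the sigmoids do not saturate.
    static double[][] randomMatrix(int rows, int cols, Random rnd) {
        double[][] m = new double[rows][cols];
        for (double[] row : m) for (int j = 0; j < cols; j++) row[j] = rnd.nextGaussian() * 0.5;
        return m;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42);
        double[][] w1 = randomMatrix(4, 3, rnd);
        double[][] w2 = randomMatrix(2, 4, rnd);
        double[] x = {0.2, 0.7, 0.1};
        double[] t = {1.0, 0.0};
        System.out.println("max gradient error: " + maxGradientError(w1, w2, x, t));
    }
}
```

On a network this small the two gradients should agree to many decimal places. If they do not, the update rule is the first thing to suspect; for example, a hidden-layer update that omits the `w2[k][i]` factor would be expected to fail this check.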
