Contents: 1. MNIST dataset  2. Dependencies  3. Handwritten digit training and inference  4. Further reading: deeplearning4j-examples, the example project that ships with deeplearning4j

1. MNIST dataset
Download link. The dataset contains 60,000 training images and 10,000 test images.
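The training code in section 3 expects the archive unpacked under mnist_png/training and mnist_png/testing, with each image stored in a folder named after its digit (0-9), and it takes that folder name as the label. A minimal plain-Java sketch of that convention follows. The class FolderLabel, the method labelOf, and the file name example.png are hypothetical and not part of deeplearning4j; in the real code ParentPathLabelGenerator does this work.

```java
import java.io.File;

/** Sketch: derive a digit label from the name of an image's parent folder. */
final class FolderLabel {

    /** Returns the digit (0-9) encoded by the parent directory name of the image. */
    public static int labelOf(File image) {
        File parent = image.getParentFile();
        if (parent == null) {
            throw new IllegalArgumentException("No parent directory: " + image);
        }
        // Throws NumberFormatException if the folder name is not numeric.
        int label = Integer.parseInt(parent.getName());
        if (label < 0 || label > 9) {
            throw new IllegalArgumentException("Folder name is not a digit 0-9: " + parent.getName());
        }
        return label;
    }

    public static void main(String[] args) {
        // Hypothetical file name; only the parent folder "7" matters.
        System.out.println(labelOf(new File("mnist_png/training/7/example.png")));
    }
}
```

If a folder is misnamed, the sketch fails fast, which is usually preferable to training on silently mislabelled images.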
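2. Dependencies
The dependencies are mainly deeplearning4j and javacv packages; the jar built from this example is 1.3 GB. The pom is taken from the dl4j-examples module of the deeplearning4j-examples project on GitHub.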
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.9</version>
        <relativePath/>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>demo</name>
    <description>demo</description>
    <properties>
        <dl4j-master.version>1.0.0-M2.1</dl4j-master.version>
        <nd4j.backend>nd4j-native</nd4j.backend>
        <java.version>17</java.version>
        <maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
        <maven.minimum.version>3.3.1</maven.minimum.version>
        <exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <jcommon.version>1.0.23</jcommon.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <logback.version>1.1.7</logback.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <junit.version>5.8.0-M1</junit.version>
        <javacv.version>1.5.9</javacv.version>
    </properties>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.bytedeco</groupId>
                <artifactId>javacv-platform</artifactId>
                <version>${javacv.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>resources</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- ParallelWrapper and ParallelInference live here -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- Used in the feedforward/classification/MLP* and feedforward/regression/RegressionMathFunctions example -->
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>
        <!-- Used for downloading data in some of the examples -->
        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.5</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>
        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>1.0.0-M2.1</version>
        </dependency>
    </dependencies>
    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>17</source>
                    <target>17</target>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
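3. Handwritten digit training and inference
Training for one epoch takes about 100 s and reaches roughly 97% accuracy; see the code comments for details. The framework's API is fairly pleasant to use.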
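The listing below configures a learning rate that changes at fixed iteration numbers. As a rough illustration of how such a step schedule can be read, here is a plain-Java sketch that assumes each rate applies from its start iteration until the next entry begins. StepSchedule and its methods are hypothetical names, not the deeplearning4j MapSchedule API. The sketch also works out the iterations per epoch for 60,000 images at the batch size of 64 used in the listing.

```java
import java.util.Map;
import java.util.TreeMap;

/** Sketch of an iteration-keyed step schedule: each rate holds until the next entry. */
final class StepSchedule {

    private final TreeMap<Integer, Double> steps = new TreeMap<>();

    /** Registers a learning rate that applies from the given iteration onward. */
    public StepSchedule add(int fromIteration, double learningRate) {
        steps.put(fromIteration, learningRate);
        return this;
    }

    /** Returns the rate of the latest entry whose start iteration is <= iteration. */
    public double valueAt(int iteration) {
        Map.Entry<Integer, Double> entry = steps.floorEntry(iteration);
        if (entry == null) {
            throw new IllegalArgumentException("No entry at or before iteration " + iteration);
        }
        return entry.getValue();
    }

    /** Number of mini-batches needed to see every example once (last batch may be partial). */
    public static int iterationsPerEpoch(int examples, int batchSize) {
        return (examples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        // Same breakpoints as the schedule in the listing.
        StepSchedule lr = new StepSchedule()
                .add(0, 0.06).add(200, 0.05).add(600, 0.028).add(800, 0.006).add(1000, 0.001);

        int perEpoch = iterationsPerEpoch(60_000, 64);
        System.out.println("iterations per epoch: " + perEpoch);                           // 938
        System.out.println("rate at last iteration of epoch 1: " + lr.valueAt(perEpoch - 1)); // 0.006
    }
}
```

Under these assumptions an epoch is 938 iterations, so a one-epoch run ends on the 0.006 step, and the 0.001 entry at iteration 1000 would only take effect from the second epoch onward (assuming the iteration counter keeps running across epochs).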
package ai;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.io.File;
import java.util.Random;

@Slf4j
public class LeNetMNISTReLu {

    private static final String DATASET_PATH_BASE = "D:\\";

    public static void main(String[] args) throws Exception {
        int height = 28;
        int width = 28;
        // Grayscale images have a single channel
        int channels = 1;
        // Ten classes, the digits 0-9
        int outputNum = 10;
        int batchSize = 64;
        // One epoch takes about 100 s here; 3 epochs reach 99% accuracy
        int nEpochs = 1;

        Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(),
                "Please download the archive and extract it to " + DATASET_PATH_BASE);

        // This label generator uses the name of each file's parent directory as its label.
        // The directory names must be numeric; the MNIST images here are stored in folders 0-9.
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // Scale pixel values to the range 0-1
        DataNormalization normalization = new ImagePreProcessingScaler();
        Random random = new Random(12345);

        log.info("Training set: 60,000 images...");
        File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        trainRecordReader.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);
        normalization.fit(trainIter);
        trainIter.setPreProcessor(normalization); // normalize the pixels first

        log.info("Validation set: 10,000 images...");
        File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");
        FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        validateRecordReader.initialize(validateSplit);
        DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);
        validateIter.setPreProcessor(normalization);

        // 60,000 training images at batchSize 64 give about 938 (roughly 1000) iterations per epoch.
        // The learning rate changes every few hundred iterations, starting with a larger step.
        // It could also be updated once per epoch instead.
        MapSchedule mapSchedule = new MapSchedule.Builder(ScheduleType.ITERATION)
                .add(0, 0.06)
                .add(200, 0.05)
                .add(600, 0.028)
                .add(800, 0.006)
                .add(1000, 0.001)
                .build();

        // Hyperparameters
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(1)
                .l2(0.0005)
                .updater(new Nesterovs(mapSchedule))
                //.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) // with this optimizer the model failed to converge for a long time
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder()
                        .activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
                .build();

        // Build the network
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        // Training monitor: print the loss every 10 iterations
        net.setListeners(new ScoreIterationListener(10));

        // Web UI for monitoring the training process
        //UIServer uiServer = UIServer.getInstance();
        //FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));
        //uiServer.attach(statsStorage);
        //net.setListeners(new StatsListener(statsStorage));

        log.info("Number of network parameters: {}", net.numParams());
        long startTime = System.currentTimeMillis();
        // Train for nEpochs epochs, evaluating on the validation set after each one
        for (int i = 0; i < nEpochs; i++) {
            log.info("Epoch " + i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(validateIter);
            log.info(eval.stats());
            trainIter.reset();
            validateIter.reset();
        }
        log.info("Training took {} ms", System.currentTimeMillis() - startTime);

        // Save the model
        File ministModelPath = new File(DATASET_PATH_BASE + "/ministModel.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);

        // Inference: load the saved model, load the test images, predict
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(ministModelPath);
        NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
        FileUtils.listFiles(new File(DATASET_PATH_BASE + "/mnist_png/testing"), null, true).parallelStream().forEach(file -> {
            try {
                INDArray matrix = imageLoader.asMatrix(file);
                // Apply the same 0-1 pixel scaling that was used during training
                normalization.transform(matrix);
                INDArray output = network.output(matrix);
                // Take the most likely class as the prediction
                int predictedValue = Nd4j.argMax(output, 1).getInt(0);
                // Images are stored in one folder per digit, so the parent folder name is the true value
                String realValue = file.getParentFile().getName();
                log.info("Actual: {}, predicted: {}", realValue, predictedValue);
                Assert.isTrue(predictedValue == Integer.parseInt(realValue), file.getAbsolutePath() + " was predicted incorrectly");
            } catch (Exception e) {
                log.warn(e.getMessage(), e);
            }
        });
    }
}

4. Further reading: deeplearning4j-examples, the example project that ships with deeplearning4j
deeplearning4j-examples: see the project's README.