public class Gan {

    static double lr = 0.01;

    public static void main(String[] args) throws Exception {
        final NeuralNetConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                .updater(new Sgd(lr))
                .weightInit(WeightInit.XAVIER);

        final GraphBuilder graphBuilder = builder.graphBuilder()
                .backpropType(BackpropType.Standard)
                .addInputs("input1", "input2")
                // Generator: noise (10) -> 128 -> 512 -> 784 (a 28 x 28 image)
                .addLayer("g1", new DenseLayer.Builder().nIn(10).nOut(128).activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER).build(), "input1")
                .addLayer("g2", new DenseLayer.Builder().nIn(128).nOut(512).activation(Activation.RELU)
                        .build(), "g1")
                .addLayer("g3", new DenseLayer.Builder().nIn(512).nOut(28 * 28).activation(Activation.RELU)
                        .build(), "g2")
                // Stack the real batch ("input2") and the generated batch ("g3") along dimension 0
                .addVertex("stack", new StackVertex(), "input2", "g3")
                // Discriminator: 784 -> 256 -> 128 -> 128 -> 1 (probability of "real")
                .addLayer("d1", new DenseLayer.Builder().nIn(28 * 28).nOut(256).activation(Activation.RELU)
                        .build(), "stack")
                .addLayer("d2", new DenseLayer.Builder().nIn(256).nOut(128).activation(Activation.RELU)
                        .build(), "d1")
                .addLayer("d3", new DenseLayer.Builder().nIn(128).nOut(128).activation(Activation.RELU)
                        .build(), "d2")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.XENT).nIn(128).nOut(1)
                        .activation(Activation.SIGMOID).build(), "d3")
                .setOutputs("out");
        ComputationGraph net = new ComputationGraph(graphBuilder.build());
        net.init();
        System.out.println(net.summary());

        // Attach the DL4J training UI and a score listener
        UIServer uiServer = UIServer.getInstance();
        StatsStorage statsStorage = new InMemoryStatsStorage();
        uiServer.attach(statsStorage);
        net.setListeners(new ScoreIterationListener(100));

        DataSetIterator train = new MnistDataSetIterator(30, true, 12345);
        // Discriminator labels for the stacked batch: 1 for the 30 real rows, 0 for the 30 generated rows
        INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));
        // Generator labels: all 1, so the generator is pushed to make the discriminator answer "real"
        INDArray labelG = Nd4j.ones(60, 1);
        for (int i = 1; i <= 100000; i++) {
            if (!train.hasNext()) {
                train.reset();
            }
            INDArray trueExp = train.next().getFeatures();

            // Train the discriminator 10 times on the current real batch plus a fresh noise batch
            INDArray z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
            MultiDataSet dataSetD = new org.nd4j.linalg.dataset.MultiDataSet(
                    new INDArray[] { z, trueExp }, new INDArray[] { labelD });
            for (int m = 0; m < 10; m++) {
                trainD(net, dataSetD);
            }

            // Then train the generator once on new noise
            z = Nd4j.rand(new long[] { 30, 10 }, new NormalDistribution());
            MultiDataSet dataSetG = new org.nd4j.linalg.dataset.MultiDataSet(
                    new INDArray[] { z, trueExp }, new INDArray[] { labelG });
            trainG(net, dataSetG);

            // Checkpoint every 10,000 iterations
            if (i % 10000 == 0) {
                net.save(new File("E:/gan.zip"), true);
            }
        }
    }
    // Update only the discriminator: the generator layers are "frozen" by setting their learning rate to 0
    public static void trainD(ComputationGraph net, MultiDataSet dataSet) {
        net.setLearningRate("g1", 0);
        net.setLearningRate("g2", 0);
        net.setLearningRate("g3", 0);
        net.setLearningRate("d1", lr);
        net.setLearningRate("d2", lr);
        net.setLearningRate("d3", lr);
        net.setLearningRate("out", lr);
        net.fit(dataSet);
    }
    // Update only the generator: the discriminator layers are "frozen" by setting their learning rate to 0
    public static void trainG(ComputationGraph net, MultiDataSet dataSet) {
        net.setLearningRate("g1", lr);
        net.setLearningRate("g2", lr);
        net.setLearningRate("g3", lr);
        net.setLearningRate("d1", 0);
        net.setLearningRate("d2", 0);
        net.setLearningRate("d3", 0);
        net.setLearningRate("out", 0);
        net.fit(dataSet); // apply the generator update
    }
}
```

Notes:

1. DL4J does not provide a Keras-style way to freeze the parameters of selected layers, so here the layers are "frozen" by setting their learning rate to 0.
2. The updater has to be SGD; adaptive updaters such as Adam or RMSProp cannot be used here, because they fold gradients from earlier batches into the current update, so a learning rate of 0 would not actually keep the parameters fixed.
3. A StackVertex is used to concatenate tensors along the first (batch) dimension, i.e. to stack the real samples and the samples produced by the Generator, so that the Discriminator is trained on both at once.
4. In each iteration the Discriminator's parameters are updated several times, so that it can better measure the distance between the real and generated distributions, and only then is the Generator updated once.
5. Training runs for 100,000 iterations.
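To make note 3 concrete: the StackVertex concatenates its two inputs along dimension 0 (the batch dimension), with the real batch ("input2") first and the generated batch ("g3") second, which is exactly the ordering that labelD encodes. Below is a minimal, standalone ND4J sketch of that layout; the class and variable names are only for illustration and are not part of the model above.

```java
import java.util.Arrays;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class StackSketch {
    public static void main(String[] args) {
        INDArray real = Nd4j.ones(30, 28 * 28);   // stands in for the 30 real MNIST rows ("input2")
        INDArray fake = Nd4j.zeros(30, 28 * 28);  // stands in for the 30 generator outputs ("g3")

        // Same layout the StackVertex produces: concatenation along dimension 0, real rows first
        INDArray stacked = Nd4j.vstack(real, fake);                          // shape [60, 784]

        // Discriminator labels line up with that ordering: rows 0..29 are real (1), rows 30..59 are fake (0)
        INDArray labelD = Nd4j.vstack(Nd4j.ones(30, 1), Nd4j.zeros(30, 1));  // shape [60, 1]

        System.out.println(Arrays.toString(stacked.shape()));  // [60, 784]
        System.out.println(Arrays.toString(labelD.shape()));   // [60, 1]
    }
}
```

This only illustrates the data layout; in the network itself the stacking happens inside the graph, so each fit() call presents the Discriminator with a single 60-row batch.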
3. Generating handwritten digits with the Generator

Load the trained model, draw some random noise from a NormalDistribution, feed it into the network, run feedForward, and take the activations of the last Generator layer ("g3"): those activations are the generated image. The code is as follows:

```java
public class LoadGan {