一、数据集介绍 数据格式如下: ```text 6551700932705387022_!101!news_culture!京城最值得你来场文化之旅的博物馆!保利集团,马未都,中国科学技术馆,博物馆,新中国 6552368441838272771!101!_news_cul…
一、数据集介绍
数据格式如下:
1 | 6551700932705387022_!_101_!_news_culture_!_京城最值得你来场文化之旅的博物馆_!_保利集团,马未都,中国科学技术馆,博物馆,新中国 |
1 | log.info("Building model...."); |
1 | vec.fit(); |
1 | 六、CNN网络结构 |
1 | List<String> trainLabelList = new ArrayList<>();// 训练集label |
1 | new FileInputStream(new File("/toutiao_cat_data/toutiao_data_type_word.txt")), "UTF-8")); |
}
1 | map.get(array[0]).add(array[1]);// 将样本中所有数据,按照类别归类 |
}
1 | for (Map.Entry<String, List<String>> entry : map.entrySet()) { |
}
}
1 | int batchSize = 64; |
1 | int cnnLayerFeatureMaps = 50; |
1 | ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder().weightInit(WeightInit.RELU) |
1 | ComputationGraph net = new ComputationGraph(config); |
1 | UIServer uiServer = UIServer.getInstance(); |
1 | // net.setListeners(new ScoreIterationListener(100), |
}
1 | private static DataSetIterator getDataSetIterator(WordVectors wordVectors, int minibatchSize, int maxSentenceLength, |
1 | LabeledSentenceProvider sentenceProvider = new CollectionLabeledSentenceProvider(sentences, lableList, rng); |
1 | return new CnnSentenceDataSetIterator.Builder().sentenceProvider(sentenceProvider).wordVectors(wordVectors) |
}
1 | 代码说明: |
if(sentencesAlongHeight){
featuresMask = Nd4j.create(currMinibatchSize, 1, maxLength, 1);
for (int i = 0; i < currMinibatchSize; i++) {
int sentenceLength = tokenizedSentences.get(i).getFirst().size();
if (sentenceLength >= maxLength) {
featuresMask.slice(i).assign(1.0);
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength), NDArrayIndex.point(0)).assign(1.0);
1 | } |
featuresMask = Nd4j.create(currMinibatchSize, 1, 1, maxLength);
featuresMask.get(NDArrayIndex.point(i), NDArrayIndex.point(0), NDArrayIndex.point(0), NDArrayIndex.interval(0, sentenceLength)).assign(1.0);
}
这里为什么有个if呢?生成句子张量的时候,可以任意指定句子的方向,可以沿着矩阵中height的方向,也可以是width的方向,方向不同,填掩模的那一维也就不同。
八、结果
运行了10个Epoch结果如下:
========================Evaluation Metrics========================
of classes: 15
Accuracy: 0.8420
Precision: 0.8362 (1 class excluded from average)
Recall: 0.7783
F1 Score: 0.8346 (1 class excluded from average)
Precision, recall & F1: macro-averaged (equally weighted avg. of 15 classes)
Warning: 1 class was never predicted by the model and was excluded from average precision
Classes excluded from average precision: [12]
=========================Confusion Matrix=========================
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14
973 35 114 2 9 8 11 19 14 6 19 11 0 22 13 | 0 = 0
17 4636 250 37 51 16 14 151 47 29 232 36 0 82 44 | 1 = 1
103 176 6980 108 16 8 31 62 83 41 53 77 0 36 163 | 2 = 2
9 78 244 6692 37 9 52 59 33 27 57 54 0 10 96 | 3 = 3
7 52 36 31 4072 96 101 107 581 20 64 108 0 135 37 | 4 = 4
12 18 22 8 150 3061 27 36 53 2 100 16 0 56 2 | 5 = 5
17 38 71 26 94 13 6443 43 174 31 121 39 0 32 34 | 6 = 6
17 157 93 49 62 20 34 4793 85 14 58 36 0 49 31 | 7 = 7
1 45 71 21 436 30 195 138 7018 48 54 49 0 45 148 | 8 = 8
24 74 84 47 24 1 57 50 68 3963 45 431 0 9 65 | 9 = 9
9 165 90 21 40 37 61 40 42 21 3428 111 0 78 30 | 10 = 10
47 78 173 52 114 20 48 67 93 320 140 4097 0 48 29 | 11 = 11
0 0 0 0 60 0 1 0 5 0 0 0 0 0 0 | 12 = 12
35 105 31 6 139 37 34 61 79 11 153 35 0 3187 12 | 13 = 13
14 36 210 128 31 2 19 20 164 44 38 15 0 19 5183 | 14 = 14
平均准确率0.8420,比原资源中给定的结果略好,F1 score要略差一点,混淆矩阵中,有一个类别,无法被预测到,是因为样本中改类别数据量本身很少,难以抓到共性特征。这里参数如果精心调节一番,迭代更多次数,理论上会有更好的表现。
九、后记
读Deeplearning4j是一种享受,优雅的架构,清晰的逻辑,多种设计模式,扩展性强,将有后续博客,对dl4j源码进行剖析。
快乐源于分享。
此博客乃作者原创, 转载请注明出处
本文标题: DL4J之CNN对今日头条新闻分类
发布时间: 2019年02月13日 00:00
最后更新: 2025年12月30日 08:54
原始链接: https://haoxiang.eu.org/739bf47e/
版权声明: 本文著作权归作者所有,均采用CC BY-NC-SA 4.0许可协议,转载请注明出处!

