博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow高层次机器学习API (tf.contrib.learn)
阅读量:7048 次
发布时间:2019-06-28

本文共 1764 字,大约阅读时间需要 5 分钟。

TensorFlow高层次机器学习API (tf.contrib.learn)

1.tf.contrib.learn.datasets.base.load_csv_with_header 加载csv格式数据

2.tf.contrib.learn.DNNClassifier 建立DNN模型(classifier)

3.classifer.fit 训练模型

4.classifier.evaluate 评价模型

5.classifier.predict 预测新样本

完整代码:

1 from __future__ import absolute_import 2 from __future__ import division 3 from __future__ import print_function 4 5 import tensorflow as tf 6 import numpy as np 7 8 # Data sets 9 IRIS_TRAINING = "iris_training.csv" 10 IRIS_TEST = "iris_test.csv" 11 12 # Load datasets. 13 training_set = tf.contrib.learn.datasets.base.load_csv_with_header( 14 filename=IRIS_TRAINING, 15 target_dtype=np.int, 16 features_dtype=np.float32) 17 test_set = tf.contrib.learn.datasets.base.load_csv_with_header( 18 filename=IRIS_TEST, 19 target_dtype=np.int, 20 features_dtype=np.float32) 21 22 # Specify that all features have real-value data 23 feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] 24 25 # Build 3 layer DNN with 10, 20, 10 units respectively. 26 classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, 27 hidden_units=[10, 20, 10], 28 n_classes=3, 29 model_dir="/tmp/iris_model") 30 31 # Fit model. 32 classifier.fit(x=training_set.data, 33 y=training_set.target, 34 steps=2000) 35 36 # Evaluate accuracy. 37 accuracy_score = classifier.evaluate(x=test_set.data, 38 y=test_set.target)["accuracy"] 39 print('Accuracy: {0:f}'.format(accuracy_score)) 40 41 # Classify two new flower samples. 42 new_samples = np.array( 43 [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float) 44 y = list(classifier.predict(new_samples, as_iterable=True)) 45 print('Predictions: {}'.format(str(y)))

 结果:

Accuracy:0.966667

本文转自张昺华-sky博客园博客,原文链接:http://www.cnblogs.com/bonelee/p/7903436.html,如需转载请自行联系原作者

你可能感兴趣的文章
What are words
查看>>
android之bundle传递数据--两个activities之间
查看>>
centos You don't have permission to access 解决
查看>>
WPF仿windows图片查看器(附源码)
查看>>
我的友情链接
查看>>
ubuntu 超级管理员修改Mysql数据库密码
查看>>
社会化分享功能百度分享代码示例
查看>>
我的友情链接
查看>>
java爬虫学习日记1-基本爬虫原理介绍
查看>>
bash的功能简介
查看>>
Python中的and和or
查看>>
Linux下TFTP+NFS+PXE安装FreeBSD操作系统
查看>>
企业网络部署和运维
查看>>
win7/win8右键在目录当前打开命令cmd窗口
查看>>
定时任务1.基本原理
查看>>
linux文件操作之系统调用
查看>>
《飞机大战》安卓游戏开发源码(二)
查看>>
2017总结、计划——IT人应该拥有什么样子的价值观与实践能力
查看>>
Linux设备驱动入门之hello驱动
查看>>
vim 的一些简单使用
查看>>