TensorFlow High-Level Machine Learning API (tf.contrib.learn)
1. tf.contrib.learn.datasets.base.load_csv_with_header: load data in CSV format
2. tf.contrib.learn.DNNClassifier: build the DNN model (classifier)
3. classifier.fit: train the model
4. classifier.evaluate: evaluate the model
5. classifier.predict: predict new samples
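Step 1 depends on a particular CSV layout, so a small stand-in may help make it concrete. The sketch below assumes the first row is a header of the form n_samples,n_features,... and that every following row holds the feature values with the class label in the last column. The names load_csv_with_header_sketch, Dataset, and SAMPLE_CSV are hypothetical and the sample rows are illustrative only. It uses plain Python lists, whereas the real loader returns NumPy arrays, so treat it as an approximation of the behavior rather than the library's implementation.

```python
import csv
import io
from collections import namedtuple

# Hypothetical stand-in for the container returned by the real loader.
Dataset = namedtuple("Dataset", ["data", "target"])


def load_csv_with_header_sketch(csv_file, target_type=int, feature_type=float):
    """Parse a CSV whose first row is `n_samples,n_features,...`.

    Assumption: each following row holds the features, with the label last.
    """
    reader = csv.reader(csv_file)
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])

    data, target = [], []
    for row in reader:
        if not row:  # skip blank lines
            continue
        target.append(target_type(row[-1]))
        data.append([feature_type(value) for value in row[:-1]])

    # Check the counts announced in the header against the rows actually read.
    if len(data) != n_samples or any(len(r) != n_features for r in data):
        raise ValueError("header does not match the rows that follow")
    return Dataset(data=data, target=target)


# Tiny in-memory file in the assumed layout (illustrative values only).
SAMPLE_CSV = (
    "2,4,setosa,versicolor,virginica\n"
    "6.4,2.8,5.6,2.2,2\n"
    "5.0,2.3,3.3,1.0,1\n"
)
sample = load_csv_with_header_sketch(io.StringIO(SAMPLE_CSV))
print(sample.data)
print(sample.target)
```

The data and target fields mirror the training_set.data and training_set.target attributes that the training and evaluation steps consume.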
Complete code:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

# Data sets
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

# Specify that all features have real-value data
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# Build 3 layer DNN with 10, 20, 10 units respectively.
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="/tmp/iris_model")

# Fit model.
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)

# Evaluate accuracy.
accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

# Classify two new flower samples.
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))
Result:
Accuracy: 0.966667
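To make the reported number easier to read: accuracy here is the fraction of test samples whose predicted class matches the true class. The sketch below is a plain-Python illustration of that ratio, not the evaluator that classifier.evaluate runs internally. The function name accuracy and the label lists are hypothetical. It assumes a test set of 30 samples with exactly one misclassification, which would reproduce the value above; the source does not state the test set size or which samples were missed.

```python
def accuracy(predicted, expected):
    """Fraction of positions where the predicted label equals the expected one."""
    if len(predicted) != len(expected):
        raise ValueError("predicted and expected must have the same length")
    if not expected:
        raise ValueError("need at least one sample")
    correct = sum(1 for p, e in zip(predicted, expected) if p == e)
    return correct / len(expected)


# Made-up labels: 30 samples with one mistake (not the real Iris test data).
expected = [0, 1, 2] * 10
predicted = list(expected)
predicted[0] = 1  # the single wrong prediction

print('Accuracy: {0:f}'.format(accuracy(predicted, expected)))
```

Under that assumption, 29 correct out of 30 gives 29/30, which the same format string prints as 0.966667.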
This article is reposted from 张昺华-sky's blog on cnblogs (博客园). Original link: http://www.cnblogs.com/bonelee/p/7903436.html. To repost it, please contact the original author.