Load a Dataset with TensorFlow

In this article, we go through a brief hands-on notebook that loads data with TensorFlow. TensorFlow can load data directly from local files, and it can also load data stored in Google Cloud Storage if you give it a bucket path.

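# Written for the TensorFlow 1.x API. Under TensorFlow 2.x, use
# `import tensorflow.compat.v1 as tf` followed by `tf.disable_v2_behavior()` instead.
import tensorflow as tf
import os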

Specify data file path

First, import the modules. Then specify where the data lives: either a local file or a path in Google Cloud Storage. With the Google Cloud Storage pattern below, TensorFlow reads every .csv file in “gs://bucket_name/folder_name/”.

  • Directly load the local files:
input_file_names = tf.train.match_filenames_once('file_name.csv')
  • Load from Google Storage:
input_dir = 'gs://bucket_name'
file_prefix = 'folder_name/'
input_file_names = tf.train.match_filenames_once(os.path.join(input_dir, '{}*{}'.format(file_prefix, '.csv')))
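To see which files the Google Cloud Storage pattern picks up, here is a small standard-library sketch. It rebuilds the same pattern string and filters a hypothetical list of object names with `fnmatch.fnmatchcase`, which serves only as a stand-in for TensorFlow's own glob matching; the object names are made up for illustration.

```python
import fnmatch
import posixpath

input_dir = 'gs://bucket_name'
file_prefix = 'folder_name/'

# Same pattern as above; posixpath.join always uses '/', as gs:// paths require
pattern = posixpath.join(input_dir, '{}*{}'.format(file_prefix, '.csv'))

# Hypothetical object names, for illustration only
object_names = [
    'gs://bucket_name/folder_name/part-0.csv',
    'gs://bucket_name/folder_name/part-1.csv',
    'gs://bucket_name/folder_name/readme.txt',
    'gs://bucket_name/other_folder/part-0.csv',
]

# fnmatchcase stands in for TensorFlow's glob matching (case-sensitive, OS-independent)
matched = [name for name in object_names if fnmatch.fnmatchcase(name, pattern)]

print(pattern)  # gs://bucket_name/folder_name/*.csv
print(matched)  # the two part-N.csv files under folder_name/
```

The sketch uses `posixpath.join` rather than `os.path.join` because on Windows `os.path.join` inserts a backslash, which is wrong for a gs:// path. Also note that fnmatch's `*` can match `/`, so this stand-in does not model nested folders faithfully.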

Shuffle and read data

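Queue up the matched files and read them line by line. Setting shuffle=True shuffles the order in which the files are read (not the rows within a file), num_epochs=15 passes over the files 15 times, and skip_header_lines=1 skips the first (header) line of every file. Each line that is read goes into the ‘example’ variable.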

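# Queue the matched file names in shuffled order, cycling through them 15 times
filename_queue = tf.train.string_input_producer(input_file_names, num_epochs=15, shuffle=True)
# Read one line per call, skipping the header line of every file
reader = tf.TextLineReader(skip_header_lines=1)
_, example = reader.read(filename_queue)  # read() returns (key, value); keep the line text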
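The reading step can be mimicked in plain Python to make its behaviour concrete. This is a hedged sketch, not TensorFlow code: `read_lines` is a hypothetical helper that takes an in-memory dict of file name to CSV text, shuffles only the file order on each epoch (rows within a file keep their order), and drops the header line from every file, roughly what `string_input_producer` plus `TextLineReader(skip_header_lines=1)` do.

```python
import random


def read_lines(files, num_epochs, shuffle=True, skip_header_lines=1, seed=0):
    """Yield data lines from `files` (a dict mapping file name -> CSV text).

    Hypothetical helper mimicking a shuffled filename queue feeding a line
    reader: only the order of files is shuffled, once per epoch, and the
    first `skip_header_lines` lines of every file are dropped.
    """
    rng = random.Random(seed)
    for _ in range(num_epochs):
        names = list(files)
        if shuffle:
            rng.shuffle(names)
        for name in names:
            for line in files[name].splitlines()[skip_header_lines:]:
                yield line


# Two small in-memory CSV files, each with a header line
files = {
    'a.csv': 'x1,x2,target\n1.0,red,True\n2.0,blue,False\n',
    'b.csv': 'x1,x2,target\n3.0,green,True\n',
}

for line in read_lines(files, num_epochs=2):
    print(line)
```

Every data row appears once per epoch, and no header line is ever yielded.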

Give column names to data

The data has three columns: the first is numeric, the second is a string, and the third is the target. ‘record_defaults’ declares the type of each column and the value used when a field is empty. Then we map the three decoded columns to the keys ‘x1’, ‘x2’ and ‘target’.

record_defaults = [[0.], ['red'], ['False']]
c1, c2, c3 = tf.decode_csv(example, record_defaults=record_defaults)
features = {'x1': c1, 'x2': c2, 'target': c3}
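To illustrate what `record_defaults` does, the following sketch reimplements the idea in plain Python. `decode_csv_line` is a hypothetical helper, not part of TensorFlow: each field is converted to the type of its default, and an empty field falls back to the default value. One consequence worth noticing is that `target` comes back as the string 'True' or 'False', not a boolean, because its default is the string 'False'.

```python
import csv


def decode_csv_line(line, record_defaults):
    """Hypothetical plain-Python illustration of record_defaults:
    each column is parsed as the type of its default, and an empty
    field is replaced by that default."""
    fields = next(csv.reader([line]))
    if len(fields) != len(record_defaults):
        raise ValueError('expected {} fields, got {}'.format(len(record_defaults), len(fields)))
    values = []
    for field, (default,) in zip(fields, record_defaults):
        if field == '':
            values.append(default)            # empty field -> default value
        else:
            values.append(type(default)(field))  # parse with the default's type
    return values


record_defaults = [[0.], ['red'], ['False']]
c1, c2, c3 = decode_csv_line('2.5,blue,True', record_defaults)
features = {'x1': c1, 'x2': c2, 'target': c3}
print(features)                                # {'x1': 2.5, 'x2': 'blue', 'target': 'True'}
print(decode_csv_line(',,', record_defaults))  # [0.0, 'red', 'False']
```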

Begin a TensorFlow session and print the first 10 records

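Here we begin the TensorFlow session. First create a session, then run the global and local variable initializers. The local initializer is required because match_filenames_once and the num_epochs counter of string_input_producer are stored as local variables. Then start the queue-runner threads under a coordinator, print the first 10 records read from the queue, and finally stop and join the threads.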

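with tf.Session() as sess:
    # match_filenames_once and num_epochs create local variables, so run both initializers
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Start the background threads that feed the filename queue
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(10):
        print(sess.run(features))  # each run dequeues and decodes one record

    coord.request_stop()
    coord.join(threads)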
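The coordinator and queue runners follow an ordinary producer/consumer pattern. The sketch below reproduces it with the standard `threading` and `queue` modules, as an analogy rather than a description of how TensorFlow implements it: a background thread fills a bounded queue, the main thread pulls 10 records, then it signals the thread to stop and joins it. The record strings are hypothetical.

```python
import queue
import threading


def producer(records, q, stop_event):
    """Background worker playing the role of a queue runner."""
    for record in records:
        # Retry with a short timeout so a blocked thread still notices a stop request
        while True:
            if stop_event.is_set():
                return
            try:
                q.put(record, timeout=0.1)
                break
            except queue.Full:
                pass


records = ['{},red,False'.format(i) for i in range(100)]  # hypothetical data
q = queue.Queue(maxsize=8)
stop_event = threading.Event()  # plays the role of tf.train.Coordinator

threads = [threading.Thread(target=producer, args=(records, q, stop_event))]
for t in threads:
    t.start()

# Like calling sess.run(features) ten times: each get() takes one record
first_ten = [q.get() for _ in range(10)]
for record in first_ten:
    print(record)

stop_event.set()   # like coord.request_stop()
for t in threads:  # like coord.join(threads)
    t.join()
```

The timeout on `put` matters: without it, a producer blocked on a full queue would never see the stop request and `join` would hang.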