How Synchronous Training Works in TensorFlow
1. A synchronous training example:
#coding=utf-8
#python sync_dist_train.py --job_name=ps --task_index=0 --issync=1
#python sync_dist_train.py --job_name=worker --task_index=0 --issync=1
#python sync_dist_train.py --job_name=worker --task_index=1 --issync=1
import numpy as np
import tensorflow as tf
# Define parameters
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_float('learning_rate', 0.00003, 'Initial learning rate.')
tf.app.flags.DEFINE_integer('steps_to_validate', 1000,
                            'Steps to validate and print loss')
# For distributed training
tf.app.flags.DEFINE_string("ps_hosts", "127.0.0.1:2222",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "127.0.0.1:2224,127.0.0.1:2225,127.0.0.1:2226",
                           "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", "worker", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("issync", 0,
                            "Distributed mode: 1 for synchronous, 0 for asynchronous")
# Hyperparameters
learning_rate = FLAGS.learning_rate
steps_to_validate = FLAGS.steps_to_validate
def main(_):
    ps_hosts = FLAGS.ps_hosts.split(",")
    worker_hosts = FLAGS.worker_hosts.split(",")
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
    server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)
    issync = FLAGS.issync
    if FLAGS.job_name == "ps":
        # Parameter servers only host variables; block and serve requests forever.
        server.join()
    elif FLAGS.job_name == "worker":
        # Place variables on the ps job and computation on this worker.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):
            global_step = tf.Variable(0, name='global_step', trainable=False)
            X = tf.placeholder(tf.float32)
            Y = tf.placeholder(tf.float32)
            w = tf.Variable(0.0, name="weight")
            b = tf.Variable(0.0, name="bias")
            y = w * X + b
            loss = tf.reduce_mean(tf.square(y - Y))
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)
            if issync == 1:
                # Synchronous mode: aggregate gradients from all workers
                # before applying a single update per global step.
                optimizer = tf.train.SyncReplicasOptimizer(
                    optimizer,
                    replicas_to_aggregate=len(worker_hosts),
                    total_num_replicas=len(worker_hosts),
                    use_locking=True)
                sync_replicas_hook = optimizer.make_session_run_hook(FLAGS.task_index == 0)
            # Compute and apply gradients.
            train_op = optimizer.minimize(loss, global_step=global_step)

        hooks = [tf.train.StopAtStepHook(last_step=1000000)]
        if issync == 1:
            hooks.append(sync_replicas_hook)
        with tf.train.MonitoredTrainingSession(
                master=server.target, is_chief=(FLAGS.task_index == 0),
                checkpoint_dir="./train_logs", hooks=hooks) as mon_sess:
            while not mon_sess.should_stop():
                # Synthetic data for y = 2x + 10 with Gaussian noise.
                train_x = np.random.randn(1)
                train_y = 2 * train_x + np.random.randn(1) * 0.33 + 10
                _, loss_v, step = mon_sess.run([train_op, loss, global_step],
                                               feed_dict={X: train_x, Y: train_y})
                if step % steps_to_validate == 0:
                    w_, b_ = mon_sess.run([w, b])
                    print("step: %d, weight: %f, bias: %f, loss: %f" % (step, w_, b_, loss_v))

if __name__ == "__main__":
    tf.app.run()
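Because replicas_to_aggregate and total_num_replicas are both set to len(worker_hosts), each global step in synchronous mode completes only after every one of the three workers has contributed a gradient. The chief worker (task_index 0) has an extra duty: the hook returned by make_session_run_hook initializes the synchronization state (the accumulators and token queue) before training begins, while the other workers wait for that initialization to finish.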
2. How synchronous training works:
TensorFlow implements synchronous training in tensorflow/python/training/sync_replicas_optimizer.py. Internally, SyncReplicasOptimizer places one gradient accumulator per trainable variable on the parameter servers: every worker pushes its locally computed gradient into these accumulators, and only once replicas_to_aggregate gradients have been collected is the averaged gradient applied to the variables and the global step incremented.
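The mechanism is easier to see with the TensorFlow plumbing stripped away. Below is a minimal plain-Python sketch of the accumulate-then-apply idea, not the actual implementation: the class and method names (GradientAccumulator, apply_grad, take_grad) only loosely mirror the ConditionalAccumulator API, and the threading, gradient staleness checks, and sync token queue of the real code are omitted.

# Minimal plain-Python sketch of synchronous gradient aggregation.
# Illustrative only: names loosely mirror tf.ConditionalAccumulator,
# and blocking/staleness/token-queue logic is left out.
import numpy as np

class GradientAccumulator(object):
    """Collects gradients from workers; releases their mean once enough arrive."""
    def __init__(self, replicas_to_aggregate):
        self.replicas_to_aggregate = replicas_to_aggregate
        self.grads = []

    def apply_grad(self, grad):
        # Each worker pushes its locally computed gradient here.
        self.grads.append(np.asarray(grad, dtype=np.float64))

    def take_grad(self):
        # The real accumulator blocks until replicas_to_aggregate gradients
        # have arrived; this sketch simply asserts that they have.
        assert len(self.grads) >= self.replicas_to_aggregate
        mean = sum(self.grads[:self.replicas_to_aggregate]) / self.replicas_to_aggregate
        self.grads = []  # leftover gradients are discarded
        return mean

if __name__ == "__main__":
    # Three workers (as in the example above) each contribute one gradient;
    # only then is a single averaged update applied, advancing the global step.
    acc = GradientAccumulator(replicas_to_aggregate=3)
    for g in ([1.0], [2.0], [3.0]):
        acc.apply_grad(g)
    print(acc.take_grad())  # -> [2.] : the gradient the chief applies

In the real optimizer, once the averaged gradient has been applied and the global step incremented, the chief enqueues one token per worker into a sync token queue; each worker must dequeue a token before starting its next step, which is what keeps all replicas in lockstep.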