高级：

# 实现线性回归：

## 案例代码：

def linear_regression():
"""
自实现一个线性回归
:return:
"""
with tf.compat.v1.variable_scope("prepare_data"):
# 1）准备数据
X = tf.compat.v1.random_normal(shape=[100, 1], name="feature")
y_true = tf.matmul(X, [[0.8]]) + 0.7

with tf.compat.v1.variable_scope("create_model"):
# 2）构造模型
# 定义模型参数 用 变量
weights = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1, 1]), name="Weights")
bias = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1, 1]), name="Bias")
y_predict = tf.matmul(X, weights) + bias

with tf.compat.v1.variable_scope("loss_function"):
# 3）构造损失函数
error = tf.reduce_mean(tf.square(y_predict - y_true))

with tf.compat.v1.variable_scope("optimizer"):
# 4）优化损失

# 2_收集变量
tf.summary.scalar("error", error)
tf.summary.histogram("weights", weights)
tf.summary.histogram("bias", bias)

# 3_合并变量
merged = tf.compat.v1.summary.merge_all()

# 创建Saver对象
saver = tf.compat.v1.train.Saver()

# 显式地初始化变量
init = tf.compat.v1.global_variables_initializer()

# 开启会话
with tf.compat.v1.Session() as sess:
# 初始化变量
sess.run(init)

# 1_创建事件文件
file_writer = tf.compat.v1.summary.FileWriter("./tmp/linear", graph=sess.graph)

# 查看初始化模型参数之后的值
print("训练前模型参数为：权重%f，偏置%f，损失为%f" % (weights.eval(), bias.eval(), error.eval()))

# 开始训练
# for i in range(100):
#     sess.run(optimizer)
#     print("第%d次训练后模型参数为：权重%f，偏置%f，损失为%f" % (i+1, weights.eval(), bias.eval(), error.eval()))
#
#     # 运行合并变量操作
#     summary = sess.run(merged)
#     # 将每次迭代后的变量写入事件文件
#
#     # 保存模型
#     if i % 10 ==0:
#         saver.save(sess, "./tmp/model/my_linear.ckpt")
# 加载模型
if os.path.exists("./tmp/model/checkpoint"):
saver.restore(sess, "./tmp/model/my_linear.ckpt")

print("训练后模型参数为：权重%f，偏置%f，损失为%f" % (weights.eval(), bias.eval(), error.eval()))

return None

## 命令行参数：

# 1）定义命令行参数
tf.compat.v1.app.flags.DEFINE_integer("max_step", 100, "训练模型的步数")
tf.compat.v1.app.flags.DEFINE_string("model_dir", "Unknown", "模型保存的路径+模型名字")

# 2）简化变量名
FLAGS = tf.compat.v1.app.flags.FLAGS

def command_demo():
"""
命令行参数演示
:return:
"""
print("max_step:\n", FLAGS.max_step)
print("model_dir:\n", FLAGS.model_dir)

return None

(0)
(0)