pointnet训练文件train.py注释

pointnet训练文件train.py注释

pointnet训练文件pointnet/sem_seg/train.py注释

原文如下:

import argparse
import math
import h5py
import numpy as np
import tensorflow as tf
import socket
import importlib
import os
import sys
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(BASE_DIR)
sys.path.append(os.path.join(BASE_DIR, 'models'))
sys.path.append(os.path.join(BASE_DIR, 'utils'))
import provider
import tf_util

#num_point与所用的数据集有关
#https://blog.csdn.net/chengxuyuanyonghu/article/details/59716405 
#解析命令行参数和选项的标准模块
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help='GPU to use [default: GPU 0]')
parser.add_argument('--model', default='pointnet_cls', help='Model name: pointnet_cls or pointnet_cls_basic [default: pointnet_cls]')
parser.add_argument('--log_dir', default='log', help='Log dir [default: log]')
parser.add_argument('--num_point', type=int, default=1024, help='Point Number [256/512/1024/2048] [default: 1024]')
parser.add_argument('--max_epoch', type=int, default=250, help='Epoch to run [default: 250]')
parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
parser.add_argument('--learning_rate', type=float, default=0.001, help='Initial learning rate [default: 0.001]')
parser.add_argument('--momentum', type=float, default=0.9, help='Initial learning rate [default: 0.9]')
parser.add_argument('--optimizer', default='adam', help='adam or momentum [default: adam]')
parser.add_argument('--decay_step', type=int, default=200000, help='Decay step for lr decay [default: 200000]')
parser.add_argument('--decay_rate', type=float, default=0.7, help='Decay rate for lr decay [default: 0.8]')
    FLAGS = parser.parse_args()
    
 #将定义的参数全部更改为大写字母的形式   
    BATCH_SIZE = FLAGS.batch_size
    NUM_POINT = FLAGS.num_point
    MAX_EPOCH = FLAGS.max_epoch
    BASE_LEARNING_RATE = FLAGS.learning_rate
    GPU_INDEX = FLAGS.gpu
    MOMENTUM = FLAGS.momentum
    OPTIMIZER = FLAGS.optimizer
    DECAY_STEP = FLAGS.decay_step   #计算learning_rate的衰减
    DECAY_RATE = FLAGS.decay_rate   #计算learning_rate的衰减
    
#加载网络模型
#https://blog.csdn.net/xie_0723/article/details/78004649
MODEL = importlib.import_module(FLAGS.model) # import network module
MODEL_FILE = os.path.join(BASE_DIR, 'models', FLAGS.model+'.py')

#记录训练日志,以及备份model.py和train.py
LOG_DIR = FLAGS.log_dir
if not os.path.exists(LOG_DIR): os.mkdir(LOG_DIR)
os.system('cp %s %s' % (MODEL_FILE, LOG_DIR)) # bkp of model def
os.system('cp train.py %s' % (LOG_DIR)) # bkp of train procedure
LOG_FOUT = open(os.path.join(LOG_DIR, 'log_train.txt'), 'w')
LOG_FOUT.write(str(FLAGS)+'\n')

#总共40个语义分类标签#BN_ 开头的4个变量用来计算 Batch Normalization 的Decay参数,即decay参数也随着训练逐渐decay
MAX_NUM_POINT = 2048
NUM_CLASSES = 40

#BN_ 开头的4个变量用来计算 Batch Normalization 的Decay参数,即decay参数也随着训练逐渐decay
BN_INIT_DECAY = 0.5
BN_DECAY_DECAY_RATE = 0.5
BN_DECAY_DECAY_STEP = float(DECAY_STEP)
BN_DECAY_CLIP = 0.99

HOSTNAME = socket.gethostname()

# ModelNet40 official train/test split
TRAIN_FILES = provider.getDataFiles( \
    os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/train_files.txt'))
TEST_FILES = provider.getDataFiles(\
    os.path.join(BASE_DIR, 'data/modelnet40_ply_hdf5_2048/test_files.txt'))

#log_string(out_str)函数用来log训练日志
def log_string(out_str):
    LOG_FOUT.write(out_str+'\n')
    LOG_FOUT.flush()    #清空缓存区
    print(out_str)

#计算指数衰减的学习率。训练时学习率最好随着训练衰减。
#decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)   #此处 global_step = batch * BATCH_SIZE
https://blog.csdn.net/wuguangbin1230/article/details/77658229
def get_learning_rate(batch):
    learning_rate = tf.train.exponential_decay(
                        BASE_LEARNING_RATE,  # Base learning rate.
                        batch * BATCH_SIZE,  # Current index into the dataset.
                        DECAY_STEP,          # Decay step.
                        DECAY_RATE,          # Decay rate.
                        staircase=True)
    learning_rate = tf.maximum(learning_rate, 0.00001) # CLIP THE LEARNING RATE!
    https://blog.csdn.net/YZXnuaa/article/details/79733700
    return learning_rate        

#if the argument staircase is True, then global_step /decay_steps is an integer division and the decayed learning rate follows a staircase function.
#计算衰减的Batch Normalization 的 decay。基本同上。
def get_bn_decay(batch):
    bn_momentum = tf.train.exponential_decay(
                      BN_INIT_DECAY,
                      batch*BATCH_SIZE,
                      BN_DECAY_DECAY_STEP,
                      BN_DECAY_DECAY_RATE,
                      staircase=True)
    bn_decay = tf.minimum(BN_DECAY_CLIP, 1 - bn_momentum)
    return bn_decay

	
#这一段主要是placeholder,以及batch初始化为0。
def train():
    with tf.Graph().as_default():  #表示将这个类实例,也就是新生成的图作为整个 tensorflow 运行环境的默认图,如果只有一个主线程不写也没有关系,tensorflow 里面已经存好了一张默认图,可以使用tf.get_default_graph() 来调用(显示这张默认纸),当你有多个线程就可以创造多个tf.Graph(),就是你可以有一个画图本,有很多张图纸,这时候就会有一个默认图的概念了。
        with tf.device('/gpu:'+str(GPU_INDEX)):   #如果需要切换成CPU运算,可以调用tf.device(device_name)函数,其中device_name格式如/cpu:0其中的0表示设备号,TF不区分CPU的设备号,设置为0即可。GPU区分设备号/gpu:0和/gpu:1表示两张不同的显卡
            pointclouds_pl, labels_pl = MODEL.placeholder_inputs(BATCH_SIZE, NUM_POINT)
            is_training_pl = tf.placeholder(tf.bool, shape=())  #placeholder()函数是在神经网络构建graph的时候在模型中的占位,此时并没有把要输入的数据传入模型,它只会分配必要的内存
            print(is_training_pl)
            
            # Note the global_step=batch parameter to minimize. 
            #globe_step初始化为0,每次自动加1
            # That tells the optimizer to helpfully increment the 'batch' parameter for you every time it trains.
            batch = tf.Variable(0)    #创建一个变量
            bn_decay = get_bn_decay(batch)  
            tf.summary.scalar('bn_decay', bn_decay) 
           https://blog.csdn.net/heiheiya/article/details/80972322

			#预测值为pred,调用model.py 中的 get_model()得到。由get_model()可知,pred的维度为B×N×40,40为Channel数,对应40个分类标签。每个点的这40个值最大的一个的下标即为所预测的分类标签。
            # Get model and loss 
            pred, end_points = MODEL.get_model(pointclouds_pl, is_training_pl, bn_decay=bn_decay)
            loss = MODEL.get_loss(pred, labels_pl, end_points)
            tf.summary.scalar('loss', loss)
			
			#tf.argmax(pred, 2) 返回pred C 这个维度的最大值索引
            #tf.equal() 比较两个张量对应位置是否想等,返回相同维度的bool值矩阵
            correct = tf.equal(tf.argmax(pred, 1), tf.to_int64(labels_pl))
            accuracy = tf.reduce_sum(tf.cast(correct, tf.float32)) / float(BATCH_SIZE)
            #压缩求和,用于降维   https://blog.csdn.net/arjick/article/details/78415675
            tf.summary.scalar('accuracy', accuracy)

			
			#获得衰减后的学习率,以及选择优化器optimizer。
            # Get training operator
            learning_rate = get_learning_rate(batch)
            tf.summary.scalar('learning_rate', learning_rate)
            if OPTIMIZER == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=MOMENTUM)
            elif OPTIMIZER == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            train_op = optimizer.minimize(loss, global_step=batch)
            
            # Add ops to save and restore all the variables.
            saver = tf.train.Saver()
			
        #配置session 运行参数。    
        # Create a session
        config = tf.ConfigProto()   #创建sess的时候对sess进行参数配置
        config.gpu_options.allow_growth = True  # =True是让TensorFlow在运行过程中动态申请显存,避免过多的显存占用。
        config.allow_soft_placement = True  #当指定的设备不存在时,允许选择一个存在的设备运行。比如gpu不存在,自动降到cpu上运行。
        config.log_device_placement = False  #在终端打印出各项操作是在哪个设备上运行的
        sess = tf.Session(config=config)

        # Add summary writers
        #merged = tf.merge_all_summaries()
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'train'),
                                  sess.graph)
        test_writer = tf.summary.FileWriter(os.path.join(LOG_DIR, 'test'))
		
		#初始化参数,开始训练。
		#train_one_epoch 函数用来训练一个epoch,eval_one_epoch函数用来每运行一个epoch后evaluate在测试集的accuracy和loss。每10个epoch保存1次模型。
        # Init variables
        init = tf.global_variables_initializer()
        # To fix the bug introduced in TF 0.12.1 as in
        # http://stackoverflow.com/questions/41543774/invalidargumenterror-for-tensor-bool-tensorflow-0-12-1
        #sess.run(init)
        sess.run(init, {is_training_pl: True})

        ops = {'pointclouds_pl': pointclouds_pl,
               'labels_pl': labels_pl,
               'is_training_pl': is_training_pl,
               'pred': pred,
               'loss': loss,
               'train_op': train_op,
               'merged': merged,
               'step': batch}

        for epoch in range(MAX_EPOCH):
            log_string('**** EPOCH %03d ****' % (epoch))
            sys.stdout.flush()   #在同一个位置刷新输出
             
            train_one_epoch(sess, ops, train_writer)  #train_one_epoch 函数用来训练一个epoch
            eval_one_epoch(sess, ops, test_writer)  #eval_one_epoch函数用来每运行一个epoch后evaluate在测试集的accuracy和loss
            
			# Save the variables to disk.,每10个epoch保存1次模型
            if epoch % 10 == 0:  
                save_path = saver.save(sess, os.path.join(LOG_DIR, "model.ckpt"))
                log_string("Model saved in file: %s" % save_path)


#provider.shuffle_data 函数随机打乱数据,返回打乱后的数据。 
#num_batches = file_size // BATCH_SIZE,计算在指定BATCH_SIZE下,训练1个epoch 需要几个mini-batch训练。
def train_one_epoch(sess, ops, train_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = True
    
    # Shuffle train files
    train_file_idxs = np.arange(0, len(TRAIN_FILES))
    np.random.shuffle(train_file_idxs)
    
    for fn in range(len(TRAIN_FILES)):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(TRAIN_FILES[train_file_idxs[fn]])
        current_data = current_data[:,0:NUM_POINT,:]
        current_data, current_label, _ = provider.shuffle_data(current_data, np.squeeze(current_label))            
        current_label = np.squeeze(current_label) 
        
        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        
        total_correct = 0
        total_seen = 0
        loss_sum = 0
       #在一个epoch 中逐个mini-batch训练直至遍历完一遍训练集。计算总分类正确数total_correct和已遍历样本数total_senn,总损失loss_sum.
        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx+1) * BATCH_SIZE
            
            # Augment batched point clouds by rotation and jittering
            rotated_data = provider.rotate_point_cloud(current_data[start_idx:end_idx, :, :])
            jittered_data = provider.jitter_point_cloud(rotated_data)
            feed_dict = {ops['pointclouds_pl']: jittered_data,
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training,}
            summary, step, _, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                ops['train_op'], ops['loss'], ops['pred']], feed_dict=feed_dict)
            train_writer.add_summary(summary, step)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += loss_val
			
        #记录平均loss,以及平均accuracy。
        log_string('mean loss: %f' % (loss_sum / float(num_batches)))
        log_string('accuracy: %f' % (total_correct / float(total_seen)))

        
def eval_one_epoch(sess, ops, test_writer):
    """ ops: dict mapping from string to tf ops """
    is_training = False
    total_correct = 0
    total_seen = 0
    loss_sum = 0
    total_seen_class = [0 for _ in range(NUM_CLASSES)]
    total_correct_class = [0 for _ in range(NUM_CLASSES)]
    
    for fn in range(len(TEST_FILES)):
        log_string('----' + str(fn) + '-----')
        current_data, current_label = provider.loadDataFile(TEST_FILES[fn])
        current_data = current_data[:,0:NUM_POINT,:]
        current_label = np.squeeze(current_label)
        
        file_size = current_data.shape[0]
        num_batches = file_size // BATCH_SIZE
        
        for batch_idx in range(num_batches):
            start_idx = batch_idx * BATCH_SIZE
            end_idx = (batch_idx+1) * BATCH_SIZE

            feed_dict = {ops['pointclouds_pl']: current_data[start_idx:end_idx, :, :],
                         ops['labels_pl']: current_label[start_idx:end_idx],
                         ops['is_training_pl']: is_training}
            summary, step, loss_val, pred_val = sess.run([ops['merged'], ops['step'],
                ops['loss'], ops['pred']], feed_dict=feed_dict)
            pred_val = np.argmax(pred_val, 1)
            correct = np.sum(pred_val == current_label[start_idx:end_idx])
            total_correct += correct
            total_seen += BATCH_SIZE
            loss_sum += (loss_val*BATCH_SIZE)
            for i in range(start_idx, end_idx):
                l = current_label[i]
                total_seen_class[l] += 1
                total_correct_class[l] += (pred_val[i-start_idx] == l)
            
    log_string('eval mean loss: %f' % (loss_sum / float(total_seen)))
    log_string('eval accuracy: %f'% (total_correct / float(total_seen)))
    log_string('eval avg class acc: %f' % (np.mean(np.array(total_correct_class)/np.array(total_seen_class,dtype=np.float))))
         


if __name__ == "__main__":
    train()
    LOG_FOUT.close()
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/2169.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • webstorm2021激活码_通用破解码

    webstorm2021激活码_通用破解码,https://javaforall.cn/100143.html。详细ieda激活码不妨到全栈程序员必看教程网一起来了解一下吧!

  • 转一篇难得的好文章-CPU流水线的探秘之旅

    转一篇难得的好文章-CPU流水线的探秘之旅作为程序员,CPU在我们的工作中扮演了核心角色,因此了解处理器内部的工作方式对程序员来说不无裨益。  CPU是如何工作的呢?一条指令执行需要多长时间?当我们讨论某个新款处理器拥有12级流水线还是18级流水线,甚至是更深的31级流水线时,这到些都意味着什么呢?  应用程序通常会将CPU看作是黑盒子。程序中的指令按照顺序依次进入CPU,执行完之后再按顺序依次从C

  • 学习大数据需要掌握哪些Java技术

    学习大数据需要掌握哪些Java技术大数据产业已进入发展的”快车道”,急需大量优秀的大数据人才作为后盾。如果你是Java编程出身,那学习大数据自然是锦上添花;但如果你是刚刚接触大数据技术,还在Java编程基础阶段,这篇文章非常值得你看!首先,我们学习大数据,为什么要先掌握Java技术?Java是目前使用非常广泛的编程语言,它具有的众多特性,特别适合作为大数据应用的开发语言。Java不仅吸收了C++语言的各种优点…

  • 推荐系统(Recommendation system )介绍[通俗易懂]

    推荐系统(Recommendation system )介绍[通俗易懂]前言随着电子商务的发展,网络购物成为一种趋势,当你打开某个购物网站比如淘宝、京东的时候,会看到很多给你推荐的产品,你是否觉得这些推荐的产品都是你似曾相识或者正好需要的呢。这个就是现在电子商务里面的推

  • 决策树算法的应用python实现_python怎么画出决策树的分支

    决策树算法的应用python实现_python怎么画出决策树的分支决策数(DecisionTree)在机器学习中也是比较常见的一种算法,属于监督学习中的一种。看字面意思应该也比较容易理解,相比其他算法比如支持向量机(SVM)或神经网络,似乎决策树感觉“亲切”许多。优点:计算复杂度不高,输出结果易于理解,对中间值的缺失值不敏感,可以处理不相关特征数据。缺点:可能会产生过度匹配的问题。使用数据类型:数值型和标称型。简单介绍完毕,让我们来通过一个例子让决策树“

  • JVM成长之路,记录一次内存溢出导致频繁FGC的问题排查及解决「建议收藏」

    JVM成长之路,记录一次内存溢出导致频繁FGC的问题排查及解决「建议收藏」现象:现象截图:内存:命令:jmap-heap30069GC截图:FGC次数19529次!!!何等的恐怖!!!!!命令:jstat-gcutil300691000现象描述:Node模块启动后收到请求却未能响应。一直在频繁的FGC。新生代内

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号