dataloader 源码_DataLoader

dataloader 源码_DataLoaderimportpaddle.fluidasfluidimportnumpyasnpBATCH_NUM=10BATCH_SIZE=16EPOCH_NUM=4CLASS_NUM=10ITERABLE=True#whetherthecreatedDataLoaderobjectisiterableUSE_GPU=False#whethertous…

大家好,又见面了,我是你们的朋友全栈君。

import paddle.fluid as fluid

import numpy as np

BATCH_NUM = 10

BATCH_SIZE = 16

EPOCH_NUM = 4

CLASS_NUM = 10

ITERABLE = True # whether the created DataLoader object is iterable

USE_GPU = False # whether to use GPU

DATA_FORMAT = ‘batch_generator’ # data format of data source user provides

def simple_net(image, label):

fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)

cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)

loss = fluid.layers.reduce_mean(cross_entropy)

sgd = fluid.optimizer.SGD(learning_rate=1e-3)

sgd.minimize(loss)

return loss

def get_random_images_and_labels(image_shape, label_shape):

image = np.random.random(size=image_shape).astype(‘float32’)

label = np.random.random(size=label_shape).astype(‘int64’)

return image, label

# If the data generator yields one sample each time,

# use DataLoader.set_sample_generator to set the data source.

def sample_generator_creator():

def __reader__():

for _ in range(BATCH_NUM * BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

yield image, label

return __reader__

# If the data generator yield list of samples each time,

# use DataLoader.set_sample_list_generator to set the data source.

def sample_list_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

sample_list = []

for _ in range(BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

sample_list.append([image, label])

yield sample_list

return __reader__

# If the data generator yields a batch each time,

# use DataLoader.set_batch_generator to set the data source.

def batch_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])

yield batch_image, batch_label

return __reader__

# If DataLoader is iterable, use for loop to train the network

def train_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

for data in loader():

exe.run(prog, feed=data, fetch_list=[loss])

# If DataLoader is not iterable, use start() and reset() method to control the process

def train_non_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

loader.start() # call DataLoader.start() before each epoch starts

try:

while True:

exe.run(prog, fetch_list=[loss])

except fluid.core.EOFException:

loader.reset() # call DataLoader.reset() after catching EOFException

def set_data_source(loader, places):

if DATA_FORMAT == ‘sample_generator’:

loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)

elif DATA_FORMAT == ‘sample_list_generator’:

loader.set_sample_list_generator(sample_list_generator_creator(), places=places)

elif DATA_FORMAT == ‘batch_generator’:

loader.set_batch_generator(batch_generator_creator(), places=places)

else:

raise ValueError(‘Unsupported data format’)

image = fluid.layers.data(name=’image’, shape=[784], dtype=’float32′)

label = fluid.layers.data(name=’label’, shape=[1], dtype=’int64′)

# Define DataLoader

loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network

loss = simple_net(image, label)

# Set data source of DataLoader

#

# If DataLoader is iterable, places must be given and the number of places must be the same with device number.

# – If you are using GPU, call `fluid.cuda_places()` to get all GPU places.

# – If you are using CPU, call `fluid.cpu_places()` to get all CPU places.

#

# If DataLoader is not iterable, places can be None.

places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()

set_data_source(loader, places)

exe = fluid.Executor(places[0])

exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:

train_iterable(exe, prog, loss, loader)

else:

train_non_iterable(exe, prog, loss, loader)

”’

Users can use return_list = True in dygraph mode.

”’

with fluid.dygraph.guard(places[0]):

loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)

set_data_source(loader, places[0])

for image, label in loader():

relu = fluid.layers.relu(image)

assert image.shape == [BATCH_SIZE, 784]

assert label.shape == [BATCH_SIZE, 1]

assert relu.shape == [BATCH_SIZE, 784]

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/132199.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • linq 实现动态 orderby

    推荐:http://www.cnblogs.com/roucheng/p/dushubiji.html

    2021年12月26日
  • linux配置虚拟IP地址方法「建议收藏」

    linux配置虚拟IP地址方法「建议收藏」linux配置虚拟IP地址方法在日常linux管理工作中,需要为应用配置单独的IP地址,以达到主机与应用的分离,在应用切换与迁移过程中可以做到动态切换,特别是在使用HA的时候,这种方案可以保证主机与应用的隔离,对日常的运维有很大的益处.但在有些应用中还没有配置HA,后期需要配置HA时,我们可以先配置虚拟IP给在线的应用使用,这要后期的系统运维可以做到更好的可扩展性.本文主要是对IP地址

    2022年10月20日
  • linux抓包和分析工具_linux tcpdump 抓包

    linux抓包和分析工具_linux tcpdump 抓包实践中,通常在Linux里用tcpdump命令抓包,然后在Windows里用wireshark软件分析包。较通用的tcpdump命令:tcpdump-ieth0-s0-wpackage.cap注[对eth0进行完整数据包抓取,数据包输入保存到当前目录package.cap中,因为没有-c参数限制,须按Ctrl+C停止抓包]—————–

    2022年10月14日
  • 认识单片机-单片机最小系统

    认识单片机-单片机最小系统现在很火的STC类51单片机的最小系统,其中分几部分:1.电源部分为图右上解的电源开头,5V输入给单片机进行供,常用的单片机系统电源电压有5V,3.3V,STC单片机也是有这两种不同电压的片子的,大家在做设计时需先确认系统电压后来进行选择。2.晶振部分,在图的左下角连接到单片机中的X1,晶振是什么作用哪?对单片机来讲,他就是心脏,没了晶振就单片机就没了心跳,就不可能正常运行了,晶振是提供单

  • setPositiveButton,setNegativeButton,setNeutralButton各代表什么意思

    setPositiveButton,setNegativeButton,setNeutralButton各代表什么意思本质上都是三个Button并没有很大的区别:Positive:积极的Negative:否定的Neutral:中性的setPositiveButton表示设置弹框后的确定按钮。setNegativeButton表示设置弹框后的取消按钮,设置的是出现在最右边,一般把最右的button功能设置为“取消”,问也就是调用dlg.dismiss()。setNeutralButton:这个是相当于一个忽略操作的按钮。(中立)…

  • PCB设计资料:看到最后才知道是福利

    PCB设计资料:看到最后才知道是福利

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号