dataloader 源码_DataLoader

dataloader 源码_DataLoaderimportpaddle.fluidasfluidimportnumpyasnpBATCH_NUM=10BATCH_SIZE=16EPOCH_NUM=4CLASS_NUM=10ITERABLE=True#whetherthecreatedDataLoaderobjectisiterableUSE_GPU=False#whethertous…

大家好,又见面了,我是你们的朋友全栈君。

import paddle.fluid as fluid

import numpy as np

BATCH_NUM = 10

BATCH_SIZE = 16

EPOCH_NUM = 4

CLASS_NUM = 10

ITERABLE = True # whether the created DataLoader object is iterable

USE_GPU = False # whether to use GPU

DATA_FORMAT = ‘batch_generator’ # data format of data source user provides

def simple_net(image, label):

fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)

cross_entropy = fluid.layers.softmax_with_cross_entropy(image, label)

loss = fluid.layers.reduce_mean(cross_entropy)

sgd = fluid.optimizer.SGD(learning_rate=1e-3)

sgd.minimize(loss)

return loss

def get_random_images_and_labels(image_shape, label_shape):

image = np.random.random(size=image_shape).astype(‘float32’)

label = np.random.random(size=label_shape).astype(‘int64’)

return image, label

# If the data generator yields one sample each time,

# use DataLoader.set_sample_generator to set the data source.

def sample_generator_creator():

def __reader__():

for _ in range(BATCH_NUM * BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

yield image, label

return __reader__

# If the data generator yield list of samples each time,

# use DataLoader.set_sample_list_generator to set the data source.

def sample_list_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

sample_list = []

for _ in range(BATCH_SIZE):

image, label = get_random_images_and_labels([784], [1])

sample_list.append([image, label])

yield sample_list

return __reader__

# If the data generator yields a batch each time,

# use DataLoader.set_batch_generator to set the data source.

def batch_generator_creator():

def __reader__():

for _ in range(BATCH_NUM):

batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])

yield batch_image, batch_label

return __reader__

# If DataLoader is iterable, use for loop to train the network

def train_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

for data in loader():

exe.run(prog, feed=data, fetch_list=[loss])

# If DataLoader is not iterable, use start() and reset() method to control the process

def train_non_iterable(exe, prog, loss, loader):

for _ in range(EPOCH_NUM):

loader.start() # call DataLoader.start() before each epoch starts

try:

while True:

exe.run(prog, fetch_list=[loss])

except fluid.core.EOFException:

loader.reset() # call DataLoader.reset() after catching EOFException

def set_data_source(loader, places):

if DATA_FORMAT == ‘sample_generator’:

loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)

elif DATA_FORMAT == ‘sample_list_generator’:

loader.set_sample_list_generator(sample_list_generator_creator(), places=places)

elif DATA_FORMAT == ‘batch_generator’:

loader.set_batch_generator(batch_generator_creator(), places=places)

else:

raise ValueError(‘Unsupported data format’)

image = fluid.layers.data(name=’image’, shape=[784], dtype=’float32′)

label = fluid.layers.data(name=’label’, shape=[1], dtype=’int64′)

# Define DataLoader

loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network

loss = simple_net(image, label)

# Set data source of DataLoader

#

# If DataLoader is iterable, places must be given and the number of places must be the same with device number.

# – If you are using GPU, call `fluid.cuda_places()` to get all GPU places.

# – If you are using CPU, call `fluid.cpu_places()` to get all CPU places.

#

# If DataLoader is not iterable, places can be None.

places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()

set_data_source(loader, places)

exe = fluid.Executor(places[0])

exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:

train_iterable(exe, prog, loss, loader)

else:

train_non_iterable(exe, prog, loss, loader)

”’

Users can use return_list = True in dygraph mode.

”’

with fluid.dygraph.guard(places[0]):

loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)

set_data_source(loader, places[0])

for image, label in loader():

relu = fluid.layers.relu(image)

assert image.shape == [BATCH_SIZE, 784]

assert label.shape == [BATCH_SIZE, 1]

assert relu.shape == [BATCH_SIZE, 784]

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

发布者:全栈程序员-用户IM,转载请注明出处:https://javaforall.cn/132199.html原文链接:https://javaforall.cn

【正版授权,激活自己账号】: Jetbrains全家桶Ide使用,1年售后保障,每天仅需1毛

【官方授权 正版激活】: 官方授权 正版激活 支持Jetbrains家族下所有IDE 使用个人JB账号...

(0)


相关推荐

  • 解决AMD CPU 启动Android模拟器时无法安装Intel HAXM 的问题

    解决AMD CPU 启动Android模拟器时无法安装Intel HAXM 的问题刚买的电脑,是用的AMD的CPU,在安装Android开发环境时,遇到以下问题:FailedtoinstallIntelHAXM.Fordetails,pleasechecktheinstallationlog:”C:\Users\zhangqs\AppData\Local\Temp\haxm_log9.txt”HAXMinstallationfailed.T…

  • NOIP2008_2012年12月16号农历是

    NOIP2008_2012年12月16号农历是NOIP2012DAY1T1Vigenère密码纯模拟#include<bits/stdc++.h>chara[105],b[1005],ans[1005];intlen1,len2,h=1,p;intc[105];usingnamespacestd;intmain(){// freopen(“vigenere.in”,”r”,stdin); …

  • Mysql使用到substring截取字符串[通俗易懂]

    Mysql使用到substring截取字符串[通俗易懂]mysql截取字符串的时候是从1开始的而不是从0开始的语法:substring(str,start,len)bz:*_*的形式例子:select* fromcost wheresubstring(bz,1,1)

  • kafka应用场景有哪些_kafka顺序性的消费

    kafka应用场景有哪些_kafka顺序性的消费序在学习一门新技术之前,我们需要先去了解一下这门技术的具体应用场景,使用它能够做什么,能够达到什么目的,学习kafka的初衷是用作消息队列;但是还可以使用KafkaStream进行一些实时的流计算,多用于大数据处理;也可以做日志收集汇总、网站活动跟踪等任务。消息队列kafka可以很好的替代一些传统的消息系统,kafka具有更好的吞吐量,内置的分区使kaf…

    2022年10月14日
  • BufferedWriter导出数据excel文件

    BufferedWriter导出数据excel文件BufferedWriter导出数据BufferedWriter将文本写入字符输出流,缓冲各个字符,从而提供单个字符、数组和字符串的高效写入。可以指定缓冲区的大小,或者接受默认的大小。在大多数情况下,默认值就足够大了js页面//导出数据functionexportData(){vardata={};…

  • drupal 6.0 入门教程 – 第一章

    drupal 6.0 入门教程 – 第一章
     
    由于工作项目的原因,需要采用drupal来部署,所以最近学习了drupalcms,天天到 drupal.org,drupalchina.org,zhupou.cn,5iphp.com上学习
     
     
    项目的核心是提供一款在线教学和互动社区,希望通过这个教程提供给大家一个比较全面的项目开发指导。首先,我近期的主要任务是熟悉drupalCMS,和设计主页的版式也就是themes。
     
    下面我们从drupal的介绍入手,开始讲解如果

发表回复

您的电子邮箱地址不会被公开。

关注全栈程序员社区公众号