当前位置:
首页
文章
后端
详情

MXNet动手学深度学习笔记:VGG神经网络实现

#coding:utf-8
'''
VGG网络
'''
from mxnet.gluon import nn
from mxnet import ndarray as nd
import mxnet as mx
from mxnet import init
import os
import sys
sys.path.append(os.getcwd())
import utils

def vgg_block(num_convs,channels):
    out = nn.Sequential()
    for _ in range(num_convs):
        out.add(
            nn.Conv2D(channels=channels,kernel_size=3,padding=1,activation='relu')
        )

    out.add(nn.MaxPool2D(pool_size=2,strides=2))

    return out

# 将 vgg_block堆起来
def vgg_stack(arhitechure):
    out = nn.Sequential()
    for(num_convs,channels) in arhitechure:
        out.add(vgg_block(num_convs,channels))

    return out

blk = vgg_block(2,128)
blk.initialize()

x = nd.random.uniform(shape=(2,3,16,16))
y = blk(x)
print(y.shape)

# 定义一个最简单的VGG结构,8个卷积层,3个全连接层,称为VGG11
ctx = mx.cpu() #utils.try_gpu()
num_outputs = 10
architechure = ((1,64),(1,128),(2,256),(2,512),(2,512))
net = nn.Sequential()
with net.name_scope():
    net.add(
        vgg_stack(architechure),
        nn.Flatten(),
        nn.Dense(4096,activation='relu'),
        nn.Dropout(0.5),
        nn.Dense(4096,activation='relu'),
        nn.Dropout(0.5),
        nn.Dense(num_outputs)
    )

# 训练模型
train_data,test_data = utils.load_data_fashion_mnist(batch_size=64,
            resize=96)
ctx = utils.try_gpu()
net.initialize(ctx=ctx,init=init.Xavier())

loss = mx.gluon.loss.SoftmaxCrossEntropyLoss()
trainer = mx.gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.05})

utils.train(train_data,test_data,net,loss,trainer,ctx,num_epochs=5)

免责申明:本站发布的内容(图片、视频和文字)以转载和分享为主,文章观点不代表本站立场,如涉及侵权请联系站长邮箱:xbc-online@qq.com进行反馈,一经查实,将立刻删除涉嫌侵权内容。

同类热门文章

深入了解C++中的new操作符:使用具体实例学习

C++中的new操作符是动态分配内存的主要手段之一。在程序运行时,我们可能需要动态地创建和销毁对象,而new就是为此提供了便利。但是,使用new也常常会引发一些问题,如内存泄漏、空指针等等。因此,本文将通过具体的示例,深入介绍C++中的new操作符,帮助读者更好地掌握其使用。


深入了解C++中的new操作符:使用具体实例学习

怎么用Java反射获取包下所有类? 详细代码实例操作

Java的反射机制就是在运行状态下,对于任何一个类,它能知道这个类的所有属性和方法;对于任何一个对象,都能调用这个对象的任意一个方法。本篇文章将通过具体的代码示例,展示如何通过Java反射来获取包下的所有类。


怎么用Java反射获取包下所有类? 详细代码实例操作

员工线上学习考试系统

有点播,直播,在线支付,三级分销等功能,可以对学员学习情况的监督监控,有源码,可二次开发。支持外网和局域网私有化部署,经过测试源码完整可用!1、视频点播:视频播放,图文资料,课件下载,章节试学,限时免

员工线上学习考试系统

了解Java中的volati关键字的作用 以及具体使用方法

本篇文章将和大家分享一下Java当中的volatile关键字,下面将为各位小伙伴讲述volatile关键字的作用以及它的具体使用方法。


了解Java中的volati关键字的作用 以及具体使用方法

Java Map 所有的值转为String类型

可以使用 Java 8 中的 Map.replaceAll() 方法将所有的值转为 String 类型: 上面的代码会将 map 中所有的值都转为 String 类型。 HashMap 是 Java

Java Map 所有的值转为String类型