利用九天深度学习平台复现SSA-GAN
时间:2023-04-10 14:07:00
目录
- 一、计算能力收集
- 1.1.填写比赛问卷
- 1.2.正式报名参加比赛
- 1.3、领取算力
- 二、复现SSA-GAN
- 2.1、创建实例
- 2.2.下载代码和数据集
- 2.3.下载预训练 DAMSM 模型
- 2.4、环境配置
- 2.5、训练
- 最后
Semantic-Spatial Aware GAN语义空间感知是2021年10月发表的GAN(SSA-GAN)主要提出框架:
- 语义空间感知卷积网络(SSACN)该模块不仅可以根据当前生成的图像特征映射草图,不仅可以决定在哪里添加文本信息,还可以决定在某一部分加强多少文本信息。
- 一种新的仿射参数计算方法,计算方法中SCBN作为空间条件,然后从编码的文本向量中学习仿射参数,并将语义空间条件分批整合。
文章精读报告:https://blog.csdn.net/air__Heaven/article/details/124469059
本篇文章将使用中国移动云-九天深度学习平台复制SSA-GAN。
一、计算能力收集
九天深度学习平台是中国移动旗下的机器学习平台CPU、V100、T等高性能计算资源调度管理,集成主流人工智能开源算法框架,Jupyter lab开发工具、主流公共数据集、参考源代码和预培训模型,为模型培训、服务部署和在线推理提供一站式服务。我们可以免费参加中国移动云杯比赛300小时计算深度平台9天。
1.1.填写比赛问卷
首先,我们填写比赛问卷预报名:中国移动云杯比赛报名问卷
填写姓名和手机号码(相应的计算能力),任意提交作品(有能力的可以提交,比赛奖励也很丰富),注册轨道可以填写大学-全国轨道
1.2.正式报名参加比赛
其次,点击竞赛正式注册链接:2022年移动云杯计算网络应用创新竞赛(学生可选择大学轨道)
这一步需要注册中国移动-移动云账号(非中国移动手机号也可以注册)。
注册成功后,点击刚刚的比赛链接左上注册比赛(报名渠道填CSDN),如果你已经注册,你可以开始获得算力。
1.3、领取算力
九天人工智能平台深度学习平台为已完成实名认证的互联网用户提供免费试用活动。开放后,各种资源可免费使用300小时(3个月内有效)。
使用刚注册的移动云账号,深度学习平台计算能力300小时9天,点击9天深度学习平台免费试用申请
这一步可能需要实名认证,实名认证后才能成功申请。
当你进入移动云控制台时,你可以看到我的云产品申请算力:九天深度学习平台
点击进入后,我们可以看到平台提供了非常高的性能CPU、GPU免费使用,其中提供300小时的T4、V100使用时长,更有趣的是,你可以看到他的计算能力是分开计算的,也就是说,你每一个都可以用足足300个小时,白嫖党福音:
二、复现SSA-GAN
2.1、创建实例
首先,我们点击控制台页面进入中国移动云-九天深度学习平台:
点击左侧notebook建模,然后点击新的例子创建我们ssagan模型。
因为我们已经申请了计算力试用,所以不用担心费用,直接选择vGPU或者V创建100套餐的例子。
如上图所示,例子,如上图所示,点击操作,进入熟悉的例子juypter界面:
2.2.下载代码和数据集
点击底部terminal,打开终端,然后使用git克隆代码:git clone https://github.com/wtliao/text2image.git
成功下载代码。
然后下载元数据包,使用命令行cd 进入text2image目录下载为鸟类准备的预处理元数据。元数据的谷歌链接无法打开CSDN下载链接1或链接2,然后上传原始数据data目录并使用unzip命令解压成文件夹:
最后,下载鸟数据集,下载链接:
http://www.vision.caltech.edu/visipedia/CUB-200-2011.html,上传数据(可以发现上传速度很快)
然后使用命令tar zxvf CUB_200_2011.tgz
解压保存在data/birds/中:
另外还要用unzip text.zip
终端命令解压text.zip文件:
2.3.下载预训练 DAMSM 模型
下载预处理DAMSM模型不能打开可以访问CSDN链接下载。
然后上传到 DAMSMencoders目录下:
我们也用unzip bird.zip
命令解压。
2.4、环境配置
到目前为止,我们基本上已经拥有了我们需要的资源。下一步是我们安装所需的虚拟环境:
首先我们conda create -n ssagan
创造一个新的虚拟环境,环境被称为ssagan(也可以随意命名)
然后可以通过nvcc命令看到cuda版本为10.1,
所以我们首先,激活虚拟环境:conda activate ssagan
然后安装pytorch,首先,我们使用它conda search pytorch,找到可安装的版本:
因为cuda版本是10.所以我们优先考虑1,cuda101的:
终端输入:conda install pytorch=1. 7.1=cuda101py36h42dc283_ 1
安装pytorch
安装方式相同torchvision
按提示安装其它环境:
conda install tensorboardX
conda install python-dateutil
conda install tqdm
conda install matplotlib
pip install scikit-image
pip install easydict
pip install nltk
pip install pandas
pip install pyyaml
2.5、训练
将bird.yml中的B_VALIDATION改为 False
cd进入text2image输入终端命名:python main.py
开始训练
可能的报错1:load() missing 1 required positional argument: ‘Loader‘
解决方案:这是因为.yaml文件在load()缺少必填loader参数,只需将 pyyaml 版本降级或将config.py的yaml_cfg = edict(yaml.load(f))
改为safeload
可能的报错2:module ‘torchvision.transforms’ has no attribute ‘Resize’
解决方案:pip install --upgrade torchvision
可能的报错3:TypeError: init() got an unexpected keyword argument ‘serialized_options’
解决方案:终端 protoc 版本 与python库内的protobuf不同的版本。我们只需要pip install -U protobuf
,如仍有错误,建议卸载删除低版本protobuf,再重新安装
可能的报错4:urllib.eror.HTTPError: HTTP Error 403: Forbidden
问题原因:网站设置了白名单,大部分网站不让访问,故Downloading: “https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth” to /root/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth时被拒绝。
解决方案:打开https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth下载pth文件然后上传到/root/.cache/torch/hub/checkpoints目录。
运行成功如下:
也可以下载已经训练好的SSA-GAN模型进行采用生成。
可以看到nf=64时,SSA-GAN的消耗为每轮epoch要14分钟左右,共600轮epoch
最后
💖 个人简介:人工智能领域研究生,目前主攻文本生成图像(text to image)方向
📝 关注我:中杯可乐多加冰
🔥 限时免费订阅:文本生成图像T2I专栏
🎉 支持我:点赞👍+收藏⭐️+留言📝
如果这篇文章帮助到你很多,希望能点击下方打赏我一杯可乐!多加冰哦