Pytorch 1.7.1代码规范

全功能AI开发平台 BML

  • 版本发布记录
  • 快速开始
    • 用BML实现表格预测
    • 用BML实现序列标注
    • 用BML实现文本实体抽取
    • 用BML实现图片分类
    • 用BML实现实例分割
    • 用BML评价短文本相似度
    • 用BML实现开源大模型的预训练(Post-pretrain)
    • 用BML实现文本分类
    • 用BML实现物体检测
  • 模型仓库
    • 从训练任务导入模型
    • 查看模型
    • 创建模型
    • 模型仓库简介
    • 从本地导入模型
    • 校验模型
    • 服务代码文件示例
      • Sklearn服务代码文件示例
      • XGBoost服务代码文件示例
  • 平台管理
    • 权限管理
    • 在BML平台使用并行文件系统PFS和对象存储BOS
    • 在BML平台使用容器镜像服务CCR
    • 在BML使用外部镜像
    • 项目空间管理
    • 镜像管理
      • 镜像使用
      • 镜像管理简介
      • 常见问题
      • 自定义镜像
    • 资源管理
      • 资源池管理简介
      • 资源池使用简介
  • 预测部署
    • 批量预测(用户资源池)API
    • 文字识别模型部署
      • 文字识别任务API参考文档
      • 文字识别任务公有云部署
    • 通用模型部署
      • 标准接口规范参考
      • Paddle框架API调用文档
      • sklearn框架API调用文档
      • 公有云部署
      • XGBoost框架API调用文档
      • tensorflow框架API调用文档
      • Pytorch框架API调用文档
      • 通用类模型API参考
      • 错误码
    • 语音技术模型部署
      • 声音分类API调用文档
    • 视觉模型部署
      • 智能边缘控制台-多节点版
      • 端云协同服务部署
      • 智能边缘控制台-单节点版
      • 视觉任务模型部署整体说明
      • 软硬一体方案部署
        • 视觉任务Jetson专用SDK集成文档
        • 如何获取视觉任务软硬一体产品
        • 视觉任务EdgeBoard(VMX)专用SDK集成文档
        • 视觉任务EdgeBoard(FZ)专用SDK集成文档
        • 视觉任务专用辨影SDK集成开发文档
      • 私有服务器部署
        • 视觉模型如何部署在私有服务器
        • 私有API
          • 如何发布私有API
          • 图像分类-单图单标签私有API集成文档
          • 图像分类-单图多标签私有API集成文档
          • 物体检测私有API集成文档
        • 服务器端SDK
          • 视觉任务服务器端LinuxSDK集成文档-Python
          • 视觉任务服务器端LinuxSDK集成文档-C++
          • 如何发布服务器端SDK
          • 视觉任务服务器端WindowsSDK集成文档
          • 视觉任务服务器端SDK简介
      • 设备端SDK部署
        • 视觉任务WindowsSDK集成文档
        • 视觉任务iOSSDK集成文档
        • 视觉任务LinuxSDK集成文档-Python
        • 视觉任务LinuxSDK集成文档-C++
        • 视觉任务设备端SDK使用说明
        • 如何发布视觉任务设备端SDK
        • 视觉任务AndroidSDK集成文档
      • 公有云部署
        • 文字识别API参考文档
        • 视觉任务公有云部署
        • 物体检测API参考文档
        • 图像分类-单图单标签API参考文档
        • 实例分割API参考文档
        • 图像分类-单图多标签API参考文档
    • 表格预测模型部署
      • 整体说明
      • 公有云部署
    • 公有云部署管理
      • 配置AB测试版本
      • 批量预测服务
      • 公有云部署
      • 公有云部署简介
      • 配置休眠策略
    • NLP模型部署
      • 自然语言处理任务模型部署整体说明
      • 私有服务器部署
        • 如何部署在私有服务器
        • 私有服务API说明
          • 私有部署说明-短文本相似度
          • 私有化部署接口说明-文本分类
          • 私有部署文档-序列标注
          • 文本实体抽取API调用文档
      • 公有云部署
        • 短文本匹配API调用文档
        • 文本实体抽取私有API调用说明
        • 如何发布自然语言处理任务API
        • 文本分类-多标签API调用文档
        • 文本分类API调用文档
        • 序列标注API调用文档
  • 数据服务
    • 数据服务简介
    • 智能数据API
    • 公有云服务调用数据反馈
    • 智能标注
      • 文本智能标注介绍及原理说明
      • 图像智能标注介绍说明
    • 管理视觉数据
      • 实例分割数据导入与标注
        • 数据标注说明
        • 导入未标注数据
        • 导入已标注数据
      • 物体检测数据导入与标注
        • 物体检测数据标注说明
        • 物体检测导入未标注数据
        • 物体检测导入已标注数据
      • 图像分类数据导入与标注
        • 图像分类导入未标注数据
        • 图像分类导入已标注数据
        • 图像分类数据标注说明
    • 管理文本数据
      • 文本分类数据导入与标注
        • 文本分类数据标注说明
        • 文本分类数据导入与标注
        • 数据去重策略
      • 序列标注数据导入与标注
        • 序列标注标注说明
        • 序列标注数据导入
        • 数据去重策略
      • 文本实体抽取数据标注
        • 文本实体抽取数据标注
        • 文本实体抽取数据导入
        • 数据去重策略
      • 短文本匹配数据导入与标注
        • 短文本匹配数据导入与标注
        • 数据去重策略说明
        • 短文本匹配数据标注
  • 产品简介
    • BML平台升级公告
    • 平台重点升级介绍
    • 产品优势
    • 产品功能
    • 什么是BML
    • 文心大模型
  • 产品定价
    • 服务器部署价格说明
    • 专项适配硬件部署价格说明
    • 公有云部署计费说明
    • 批量预测计费说明
    • 模型训练计费说明
    • 通用小型设备部署价格说明
  • 模型训练
    • Notebook建模
      • 创建并启动Notebook
      • Notebook导入数据集
      • 保存Notebook中的模型
      • Notebook使用参考
      • 常见问题
      • 数据模型可视化功能说明
      • Notebook简介
      • 发布模型
      • 配置模型
      • 使用Notebook开发模型
      • 如何使用Notebook SSH 功能
      • Notebook从训练到部署快速入门
        • Codelab Notebook自定义环境部署最佳实践
        • 基于Notebook的图像分类模板使用指南
        • 基于 Notebook 的 NLP 通用模板使用指南
        • Notebook 模板使用指南概述
        • 基于 Notebook 的通用模板使用指南
        • 基于 Notebook 的物体检测模板使用指南
    • 自定义作业建模
      • 自定义作业简介
      • 训练作业API
      • 训练作业
        • 使用训练作业训练模型
        • 创建训练作业
        • 发布模型
        • 训练作业代码示例
          • TensorFlow 1.13.2
          • AIAK- Training Pytorch版
          • TensorFlow 2.3.0
          • Blackhole 1.0.0
          • Pytorch 1.7.1
          • Sklearn 0.23.2
          • XGBoost 1.3.1
          • PaddlePaddle 2.0.0rc
      • 自动搜索作业
        • 创建自动搜索作业
        • yaml文件编写规范
        • 自动搜索作业简介
        • 自动搜索作业代码编写规范
        • 自动搜索作业代码示例
          • XGBoost 1.3.1代码规范
          • TensorFlow 1.13.2代码规范
          • Sklearn 0.23.2代码规范
          • Pytorch 1.7.1代码规范
          • Tensorflow2.3.0代码规范
          • PaddlePaddle 2.1.1代码规范
    • 可视化建模
      • 快速入门
      • 概述
      • 组件菜单
        • 001-基本操作
        • 003-查看模型特征溯源
        • 007-组件状态
        • 008-组件资源配置
        • 006-组件列选择
        • 002-查看模型可解释性
        • 004-查看特征重要性
      • 组件说明
        • 015-图算法
        • 004-特征工程组件
        • 003-数据处理组件
        • 012-预测组件
        • 008-聚类算法
        • 009-Python算法组件
        • 002-数据集组件
        • 014-自然语言处理组件
        • 010-NLP算法
        • 016-统计分析组件
        • 006-回归算法
        • 007-异常检测算法
        • 013-模型评估组件
        • 005-分类算法
        • 018-时间序列组件
      • 画布操作说明
        • 005-AutoML(自动调参)
        • 002-开始训练
        • 001-概述
    • 预置模型调参建模
      • 预置模型调参简介
      • 神经网络训练搜索
      • 开发视觉模型
        • 视觉任务简介
        • 查看训练结果
        • 创建视觉任务
        • 配置视觉任务
        • 开发参考
          • 视觉预训练模型
          • 超参数配置参考
          • 评估报告参考
          • 自动超参搜索配置参考
          • 数据增强算子参考
          • 训练时长设置参考
          • 网络选型参考
      • 开发表格预测模型
        • 创建表格预测任务
        • 配置专家模式表格数据预测任务
        • 查看训练结果
        • 配置AUTOML模式表格数据预测任务
        • 表格预测任务简介
      • 开发文字识别模型
        • 文字识别任务简介
        • 文字识别任务操作流程
      • 开发自然语言处理模型
        • 查看训练结果
        • 自然语言处理任务简介
        • 配置NLP任务
        • 创建NLP任务
        • 代码模板升级及迁移说明
所有文档
menu
没有找到结果,请重新输入

全功能AI开发平台 BML

  • 版本发布记录
  • 快速开始
    • 用BML实现表格预测
    • 用BML实现序列标注
    • 用BML实现文本实体抽取
    • 用BML实现图片分类
    • 用BML实现实例分割
    • 用BML评价短文本相似度
    • 用BML实现开源大模型的预训练(Post-pretrain)
    • 用BML实现文本分类
    • 用BML实现物体检测
  • 模型仓库
    • 从训练任务导入模型
    • 查看模型
    • 创建模型
    • 模型仓库简介
    • 从本地导入模型
    • 校验模型
    • 服务代码文件示例
      • Sklearn服务代码文件示例
      • XGBoost服务代码文件示例
  • 平台管理
    • 权限管理
    • 在BML平台使用并行文件系统PFS和对象存储BOS
    • 在BML平台使用容器镜像服务CCR
    • 在BML使用外部镜像
    • 项目空间管理
    • 镜像管理
      • 镜像使用
      • 镜像管理简介
      • 常见问题
      • 自定义镜像
    • 资源管理
      • 资源池管理简介
      • 资源池使用简介
  • 预测部署
    • 批量预测(用户资源池)API
    • 文字识别模型部署
      • 文字识别任务API参考文档
      • 文字识别任务公有云部署
    • 通用模型部署
      • 标准接口规范参考
      • Paddle框架API调用文档
      • sklearn框架API调用文档
      • 公有云部署
      • XGBoost框架API调用文档
      • tensorflow框架API调用文档
      • Pytorch框架API调用文档
      • 通用类模型API参考
      • 错误码
    • 语音技术模型部署
      • 声音分类API调用文档
    • 视觉模型部署
      • 智能边缘控制台-多节点版
      • 端云协同服务部署
      • 智能边缘控制台-单节点版
      • 视觉任务模型部署整体说明
      • 软硬一体方案部署
        • 视觉任务Jetson专用SDK集成文档
        • 如何获取视觉任务软硬一体产品
        • 视觉任务EdgeBoard(VMX)专用SDK集成文档
        • 视觉任务EdgeBoard(FZ)专用SDK集成文档
        • 视觉任务专用辨影SDK集成开发文档
      • 私有服务器部署
        • 视觉模型如何部署在私有服务器
        • 私有API
          • 如何发布私有API
          • 图像分类-单图单标签私有API集成文档
          • 图像分类-单图多标签私有API集成文档
          • 物体检测私有API集成文档
        • 服务器端SDK
          • 视觉任务服务器端LinuxSDK集成文档-Python
          • 视觉任务服务器端LinuxSDK集成文档-C++
          • 如何发布服务器端SDK
          • 视觉任务服务器端WindowsSDK集成文档
          • 视觉任务服务器端SDK简介
      • 设备端SDK部署
        • 视觉任务WindowsSDK集成文档
        • 视觉任务iOSSDK集成文档
        • 视觉任务LinuxSDK集成文档-Python
        • 视觉任务LinuxSDK集成文档-C++
        • 视觉任务设备端SDK使用说明
        • 如何发布视觉任务设备端SDK
        • 视觉任务AndroidSDK集成文档
      • 公有云部署
        • 文字识别API参考文档
        • 视觉任务公有云部署
        • 物体检测API参考文档
        • 图像分类-单图单标签API参考文档
        • 实例分割API参考文档
        • 图像分类-单图多标签API参考文档
    • 表格预测模型部署
      • 整体说明
      • 公有云部署
    • 公有云部署管理
      • 配置AB测试版本
      • 批量预测服务
      • 公有云部署
      • 公有云部署简介
      • 配置休眠策略
    • NLP模型部署
      • 自然语言处理任务模型部署整体说明
      • 私有服务器部署
        • 如何部署在私有服务器
        • 私有服务API说明
          • 私有部署说明-短文本相似度
          • 私有化部署接口说明-文本分类
          • 私有部署文档-序列标注
          • 文本实体抽取API调用文档
      • 公有云部署
        • 短文本匹配API调用文档
        • 文本实体抽取私有API调用说明
        • 如何发布自然语言处理任务API
        • 文本分类-多标签API调用文档
        • 文本分类API调用文档
        • 序列标注API调用文档
  • 数据服务
    • 数据服务简介
    • 智能数据API
    • 公有云服务调用数据反馈
    • 智能标注
      • 文本智能标注介绍及原理说明
      • 图像智能标注介绍说明
    • 管理视觉数据
      • 实例分割数据导入与标注
        • 数据标注说明
        • 导入未标注数据
        • 导入已标注数据
      • 物体检测数据导入与标注
        • 物体检测数据标注说明
        • 物体检测导入未标注数据
        • 物体检测导入已标注数据
      • 图像分类数据导入与标注
        • 图像分类导入未标注数据
        • 图像分类导入已标注数据
        • 图像分类数据标注说明
    • 管理文本数据
      • 文本分类数据导入与标注
        • 文本分类数据标注说明
        • 文本分类数据导入与标注
        • 数据去重策略
      • 序列标注数据导入与标注
        • 序列标注标注说明
        • 序列标注数据导入
        • 数据去重策略
      • 文本实体抽取数据标注
        • 文本实体抽取数据标注
        • 文本实体抽取数据导入
        • 数据去重策略
      • 短文本匹配数据导入与标注
        • 短文本匹配数据导入与标注
        • 数据去重策略说明
        • 短文本匹配数据标注
  • 产品简介
    • BML平台升级公告
    • 平台重点升级介绍
    • 产品优势
    • 产品功能
    • 什么是BML
    • 文心大模型
  • 产品定价
    • 服务器部署价格说明
    • 专项适配硬件部署价格说明
    • 公有云部署计费说明
    • 批量预测计费说明
    • 模型训练计费说明
    • 通用小型设备部署价格说明
  • 模型训练
    • Notebook建模
      • 创建并启动Notebook
      • Notebook导入数据集
      • 保存Notebook中的模型
      • Notebook使用参考
      • 常见问题
      • 数据模型可视化功能说明
      • Notebook简介
      • 发布模型
      • 配置模型
      • 使用Notebook开发模型
      • 如何使用Notebook SSH 功能
      • Notebook从训练到部署快速入门
        • Codelab Notebook自定义环境部署最佳实践
        • 基于Notebook的图像分类模板使用指南
        • 基于 Notebook 的 NLP 通用模板使用指南
        • Notebook 模板使用指南概述
        • 基于 Notebook 的通用模板使用指南
        • 基于 Notebook 的物体检测模板使用指南
    • 自定义作业建模
      • 自定义作业简介
      • 训练作业API
      • 训练作业
        • 使用训练作业训练模型
        • 创建训练作业
        • 发布模型
        • 训练作业代码示例
          • TensorFlow 1.13.2
          • AIAK- Training Pytorch版
          • TensorFlow 2.3.0
          • Blackhole 1.0.0
          • Pytorch 1.7.1
          • Sklearn 0.23.2
          • XGBoost 1.3.1
          • PaddlePaddle 2.0.0rc
      • 自动搜索作业
        • 创建自动搜索作业
        • yaml文件编写规范
        • 自动搜索作业简介
        • 自动搜索作业代码编写规范
        • 自动搜索作业代码示例
          • XGBoost 1.3.1代码规范
          • TensorFlow 1.13.2代码规范
          • Sklearn 0.23.2代码规范
          • Pytorch 1.7.1代码规范
          • Tensorflow2.3.0代码规范
          • PaddlePaddle 2.1.1代码规范
    • 可视化建模
      • 快速入门
      • 概述
      • 组件菜单
        • 001-基本操作
        • 003-查看模型特征溯源
        • 007-组件状态
        • 008-组件资源配置
        • 006-组件列选择
        • 002-查看模型可解释性
        • 004-查看特征重要性
      • 组件说明
        • 015-图算法
        • 004-特征工程组件
        • 003-数据处理组件
        • 012-预测组件
        • 008-聚类算法
        • 009-Python算法组件
        • 002-数据集组件
        • 014-自然语言处理组件
        • 010-NLP算法
        • 016-统计分析组件
        • 006-回归算法
        • 007-异常检测算法
        • 013-模型评估组件
        • 005-分类算法
        • 018-时间序列组件
      • 画布操作说明
        • 005-AutoML(自动调参)
        • 002-开始训练
        • 001-概述
    • 预置模型调参建模
      • 预置模型调参简介
      • 神经网络训练搜索
      • 开发视觉模型
        • 视觉任务简介
        • 查看训练结果
        • 创建视觉任务
        • 配置视觉任务
        • 开发参考
          • 视觉预训练模型
          • 超参数配置参考
          • 评估报告参考
          • 自动超参搜索配置参考
          • 数据增强算子参考
          • 训练时长设置参考
          • 网络选型参考
      • 开发表格预测模型
        • 创建表格预测任务
        • 配置专家模式表格数据预测任务
        • 查看训练结果
        • 配置AUTOML模式表格数据预测任务
        • 表格预测任务简介
      • 开发文字识别模型
        • 文字识别任务简介
        • 文字识别任务操作流程
      • 开发自然语言处理模型
        • 查看训练结果
        • 自然语言处理任务简介
        • 配置NLP任务
        • 创建NLP任务
        • 代码模板升级及迁移说明
  • 文档中心
  • arrow
  • 全功能AI开发平台BML
  • arrow
  • 模型训练
  • arrow
  • 自定义作业建模
  • arrow
  • 自动搜索作业
  • arrow
  • 自动搜索作业代码示例
  • arrow
  • Pytorch 1.7.1代码规范
本页目录
  • Pytorch 1.7.1代码规范

Pytorch 1.7.1代码规范

更新时间:2025-08-21

Pytorch 1.7.1代码规范

基于Pytorch 1.7.1框架的MNIST图像分类,训练数据集pytorch_train_data.zip点击这里下载。
如下所示是其超参搜索任务中一个超参数组合的训练代码,代码会通过argparse模块接受在平台中填写的信息,请保持一致。
特别注意,示例采用的是进化算法进行超参搜索,每个试验在训练时会继承之前试验的权重,resume_checkpoint_path是权重的保存路径,由搜索算法自身提供,与job_id及trial_id一样,只需要在argparse中提供对应参数即可。

pytorch1.7.1_autosearch.py示例代码

Python
1# -*- coding:utf-8 -*-
2""" pytorch train demo """
3import argparse
4import torch
5import torch.nn as nn
6import torch.nn.functional as F
7import torch.optim as optim
8import torch.utils.data as data
9from torchvision import transforms
10import codecs
11import errno
12import gzip
13import numpy as np
14import os
15import time
16from PIL import Image
17from rudder_autosearch.sdk.amaas_tools import AMaasTools
18
19def parse_arg():
20    """parse arguments"""
21    parser = argparse.ArgumentParser(description='PyTorch1.7.1 MNIST Example')
22    parser.add_argument('--train_dir', type=str, default='./train_data',
23                        help='input data dir for training (default: ./train_data)')
24    parser.add_argument('--test_dir', type=str, default='./test_data',
25                        help='input data dir for test (default: ./test_data)')
26    parser.add_argument('--output_dir', type=str, default='./output',
27                        help='output dir for auto_search job (default: ./output)')
28    parser.add_argument('--job_id', type=str, default="job-1234",
29                        help='auto_search job id (default: "job-1234")')
30    parser.add_argument('--trial_id', type=str, default="0-0",
31                        help='auto_search id of a single trial (default: "0-0")')
32    parser.add_argument('--metric', type=str, default="acc",
33                        help='evaluation metric of the model')
34    parser.add_argument('--data_sampling_scale', type=float, default=1.0,
35                        help='sampling ratio of the data (default: 1.0)')
36    parser.add_argument('--batch_size', type=int, default=64,
37                        help='number of images input in an iteration (default: 64)')
38    parser.add_argument('--lr', type=float, default=0.01,
39                        help='learning rate (default: 0.01)')
40    parser.add_argument('--momentum', type=float, default=0.5,
41                        help='SGD momentum (default: 0.5)')
42    parser.add_argument('--no_cuda', action='store_true', default=False,
43                        help='disables CUDA training')
44    parser.add_argument('--log_interval', type=int, default=10,
45                        help='how many batches to wait before logging training status')
46    parser.add_argument('--perturb_interval', type=int, default=10,
47                        help='number of epochs to train (default: 10)')
48    parser.add_argument('--resume_checkpoint_path', type=str, default="",
49                        help='inherit the initial weight of the previous trial')
50    args = parser.parse_args()
51    args.output_dir = os.path.join(args.output_dir, args.job_id, args.trial_id)
52    if not os.path.exists(args.output_dir):
53        os.makedirs(args.output_dir)
54    args.cuda = not args.no_cuda and torch.cuda.is_available()
55
56
57    print("job_id: {}, trial_id: {}".format(args.job_id, args.trial_id))
58    return args
59
60# 定义MNIST数据集的dataset
61class MNIST(data.Dataset):
62    """
63    MNIST dataset
64    """
65    training_file = 'training.pt'
66    test_file = 'test.pt'
67    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
68               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
69
70    def __init__(self, root, train=True, transform=None, target_transform=None, data_sampling_scale=1):
71        self.root = os.path.expanduser(root)
72        self.transform = transform
73        self.target_transform = target_transform
74        self.train = train  # training set or test set
75        self.data_sampling_scale = data_sampling_scale
76        self.preprocess(root, train, False)
77        if self.train:
78            data_file = self.training_file
79        else:
80            data_file = self.test_file
81        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
82
83    def __getitem__(self, index):
84        """
85        Args:
86            index (int): Index
87        Returns:
88            tuple: (image, target) where target is index of the target class.
89        """
90        img, target = self.data[index], int(self.targets[index])
91        # doing this so that it is consistent with all other datasets
92        # to return a PIL Image
93        img = Image.fromarray(img.numpy(), mode='L')
94        if self.transform is not None:
95            img = self.transform(img)
96        if self.target_transform is not None:
97            target = self.target_transform(target)
98
99        return img, target
100
101    def __len__(self):
102        return len(self.data)
103
104    @property
105    def raw_folder(self):
106        """
107        raw folder
108        """
109        return os.path.join('/tmp', 'raw')
110
111    @property
112    def processed_folder(self):
113        """
114        processed folder
115        """
116        return os.path.join('/tmp', 'processed')
117
118    # data preprocessing
119    def preprocess(self, train_dir, train, remove_finished=False):
120        """
121        preprocess
122        """
123        makedir_exist_ok(self.raw_folder)
124        makedir_exist_ok(self.processed_folder)
125        train_list = ['train-images-idx3-ubyte.gz', 'train-labels-idx1-ubyte.gz']
126        test_list = ['t10k-images-idx3-ubyte.gz', 't10k-labels-idx1-ubyte.gz']
127        zip_list = train_list if train else test_list
128        for zip_file in zip_list:
129            print('Extracting {}'.format(zip_file))
130            zip_file_path = os.path.join(train_dir, zip_file)
131            raw_folder_path = os.path.join(self.raw_folder, zip_file)
132            with open(raw_folder_path.replace('.gz', ''), 'wb') as out_f, gzip.GzipFile(zip_file_path) as zip_f:
133                out_f.write(zip_f.read())
134            if remove_finished:
135                os.unlink(zip_file_path)
136        if train:
137            x_train = read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte'))
138            y_train = read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
139            np.random.seed(0)
140            sample_data_num = int(self.data_sampling_scale * len(x_train))
141            idx = np.arange(len(x_train))
142            np.random.shuffle(idx)
143            x_train, y_train = x_train[0:sample_data_num], y_train[0:sample_data_num]
144            training_set = (x_train, y_train)
145            with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
146                torch.save(training_set, f)
147        else:
148            test_set = (
149                read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
150                read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
151            )
152            with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
153                torch.save(test_set, f)
154
155def get_int(b):
156    """
157    get int
158    """
159    return int(codecs.encode(b, 'hex'), 16)
160
161def read_label_file(path):
162    """
163    read label file
164    """
165    with open(path, 'rb') as f:
166        data = f.read()
167        assert get_int(data[:4]) == 2049
168        length = get_int(data[4:8])
169        parsed = np.frombuffer(data, dtype=np.uint8, offset=8)
170        return torch.from_numpy(parsed).view(length).long()
171
172def read_image_file(path):
173    """
174    read image file
175    """
176    with open(path, 'rb') as f:
177        data = f.read()
178        assert get_int(data[:4]) == 2051
179        length = get_int(data[4:8])
180        num_rows = get_int(data[8:12])
181        num_cols = get_int(data[12:16])
182        parsed = np.frombuffer(data, dtype=np.uint8, offset=16)
183        return torch.from_numpy(parsed).view(length, num_rows, num_cols)
184
185def makedir_exist_ok(dirpath):
186    """
187    Python2 support for os.makedirs(.., exist_ok=True)
188    """
189    try:
190        os.makedirs(dirpath)
191    except OSError as e:
192        if e.errno == errno.EEXIST:
193            pass
194        else:
195            raise
196
197def load_data(args):
198    """load_data"""
199    # 若无测试集,训练集做验证集
200    if not os.path.exists(args.test_dir) or not os.listdir(args.test_dir):
201        args.test_dir = args.train_dir
202    # 将数据进行转化,从PIL.Image/numpy.ndarray的数据进转化为torch.FloadTensor
203    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
204    train_set = MNIST(root=args.train_dir, train=True, transform=trans, data_sampling_scale=args.data_sampling_scale)
205    test_set = MNIST(root=args.test_dir, train=False, transform=trans)
206    # 定义data reader
207    train_loader = torch.utils.data.DataLoader(
208        dataset=train_set,
209        batch_size=args.batch_size,
210        shuffle=True)
211    test_loader = torch.utils.data.DataLoader(
212        dataset=test_set,
213        batch_size=args.batch_size,
214        shuffle=False)
215    return train_loader, test_loader
216
217# 定义网络模型
218class Net(nn.Module):
219    """
220    Net
221    """
222    def __init__(self):
223        super(Net, self).__init__()
224        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
225        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
226        self.conv2_drop = nn.Dropout2d()
227        self.fc1 = nn.Linear(320, 50)
228        self.fc2 = nn.Linear(50, 10)
229
230    def forward(self, x):
231        """
232        forward
233        """
234        x = F.relu(F.max_pool2d(self.conv1(x), 2))
235        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
236        x = x.view(-1, 320)
237        x = F.relu(self.fc1(x))
238        x = F.dropout(x, training=self.training)
239        x = self.fc2(x)
240        return F.log_softmax(x)
241
242def load_state_dict(model, resume_checkpoint_path):
243    """load_state_dict"""
244    if resume_checkpoint_path:
245        model.load_state_dict(torch.load(resume_checkpoint_path))
246
247def run_train(model, args, train_loader):
248    """run_train"""
249    if args.cuda:
250        # Move model to GPU.
251        model.cuda()
252    # 选择优化器
253    optimizer = optim.SGD(model.parameters(), lr=args.lr,
254                          momentum=args.momentum)
255    for epoch in range(1, args.perturb_interval + 1):
256        train(model, args, train_loader, optimizer, epoch)
257
258def train(model, args, train_loader, optimizer, epoch):
259    """
260    train
261    """
262    model.train()
263    for batch_idx, (data, target) in enumerate(train_loader):
264        if args.cuda:
265            data, target = data.cuda(), target.cuda()
266        optimizer.zero_grad()
267        output = model(data)  # 获取预测值
268        loss = F.nll_loss(output, target)  # 计算loss
269        loss.backward()
270        optimizer.step()
271        if batch_idx % args.log_interval == 0:
272            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
273                epoch, batch_idx, len(train_loader),
274                100. * batch_idx / len(train_loader), loss.item()))
275
276def evaluate(model, args, test_loader):
277    """evaluate"""
278    model.eval()
279    test_loss = 0.
280    test_accuracy = 0.
281    for data, target in test_loader:
282        if args.cuda:
283            data, target = data.cuda(), target.cuda()
284        output = model(data)
285        # sum up batch loss
286        test_loss += F.nll_loss(output, target, size_average=False).item()
287        # get the index of the max log-probability
288        pred = output.data.max(1, keepdim=True)[1]
289        test_accuracy += pred.eq(target.data.view_as(pred)).cpu().float().sum()
290    test_loss /= len(test_loader) * args.batch_size
291    test_accuracy /= len(test_loader) * args.batch_size
292    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
293        test_loss, 100. * test_accuracy))
294    return float(test_accuracy)
295
296def save(model, output_dir):
297    """
298    save
299    """
300    if not os.path.exists(output_dir):
301        os.makedirs(output_dir)
302    # 保存模型
303    torch.save(model.state_dict(), os.path.join(output_dir, 'model.pkl'))
304
305def report_final(args, metric):
306    """report_final_result"""
307    # 结果上报sdk
308    amaas_tools = AMaasTools(args.job_id, args.trial_id)
309    metric_dict = {args.metric: metric}
310    checkpoint_path = os.path.join(args.output_dir, 'model.pkl')
311    for i in range(3):
312        flag, ret_msg = amaas_tools.report_final_result(metric=metric_dict,
313                                                        export_model_path=args.output_dir,
314                                                        checkpoint_path=checkpoint_path)
315        print("End Report, metric:{}, ret_msg:{}".format(metric, ret_msg))
316        if flag:
317            break
318        time.sleep(1)
319    assert flag, "Report final result to manager failed! Please check whether manager'address or manager'status " \
320                 "is ok! "
321
322def main():
323    """main"""
324    # 获取参数
325    args = parse_arg()
326    # 加载数据集
327    train_loader, test_loader = load_data(args)
328    # 模型定义
329    model = Net()
330    # 继承之前实验的模型参数
331    load_state_dict(model, args.resume_checkpoint_path)
332    # 模型训练
333    run_train(model, args, train_loader)
334    # 模型保存
335    save(model, args.output_dir)
336    # 模型评估
337    acc = evaluate(model, args, test_loader)
338    # 上报结果
339    report_final(args, metric=acc)
340
341if __name__ == '__main__':
342    main()

示例代码对应的yaml配置如下,请保持格式一致

pbt_search_demo.yml示例内容

Plain Text
1#搜索算法参数
2search_strategy:
3  algo: PBT_SEARCH #搜索策略:进化算法
4  params:
5    population_num: 8 #种群个体数量 | [1,10] int类型
6    round: 10 #迭代轮数  |[5,50] int类型
7    perturb_interval: 10 # 扰动间隔  | [1,20] int类型
8    quantile_frac: 0.5 #扰动比例  | (0,0.5] float类型
9    explore_prob: 0.25 #扰动概率 | (0,0.5] float类型
10
11#单次训练时数据的采样比例,单位%
12data_sampling_scale: 100 #|(0,100] int类型
13
14#评价指标参数
15metrics:
16  name: acc #评价指标 | 任意字符串 str类型
17  goal: MAXIMIZE #最大值/最小值 | str类型   MAXIMIZE or MINIMIZE   必须为这两个之一(也即支持大写)
18  expected_value: 100 #早停标准值,评价指标超过该值则结束整个超参搜索,单位% |无限制 int类型
19
20#搜索参数空间
21search_space:
22  batch_size:
23    htype: choice
24    value: [64, 128, 256, 512]
25  lr:
26    htype: loguniform
27    value: [0.0001, 0.1]
28  momentum:
29    htype: uniform
30    value: [0.1, 0.9]

pytorch_predict.py示例代码

Python
1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4@license: Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved.
5@desc: 图像预测算法示例
6"""
7import logging
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11import base64
12import json
13from PIL import Image
14from io import BytesIO
15from torchvision import datasets, models, transforms
16MODEL_FILE_NAME = 'model.pkl'  # 模型文件名称
17def get_image_transform():
18    """获取图片处理的transform
19    Args:
20        data_type: string, type of data(train/test)
21    Returns:
22        torchvision.transforms.Compose
23    """
24    trans = transforms.Compose([transforms.Resize((28, 28)),
25                                transforms.ToTensor(),
26                                transforms.Normalize((0.5,), (1.0,))])
27    return trans
28def model_fn(model_dir):
29    """模型加载
30    Args:
31        model_dir: 模型路径,该目录存储的文件为在自动搜索作业中选择的输出路径下产出的文件
32    Returns:
33        加载好的模型对象
34    """
35    class Net(nn.Module):
36        """Net"""
37        def __init__(self):
38            super(Net, self).__init__()
39            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
40            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
41            self.conv2_drop = nn.Dropout2d()
42            self.fc1 = nn.Linear(320, 50)
43            self.fc2 = nn.Linear(50, 10)
44        def forward(self, x):
45            """
46            forward
47            """
48            x = F.relu(F.max_pool2d(self.conv1(x), 2))
49            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
50            x = x.view(-1, 320)
51            x = F.relu(self.fc1(x))
52            x = F.dropout(x, training=self.training)
53            x = self.fc2(x)
54            return F.log_softmax(x)
55    model = Net()
56    meta_info_path = "%s/%s" % (model_dir, MODEL_FILE_NAME)
57    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58    model.load_state_dict(torch.load(meta_info_path, map_location=device))
59    model.to(device)
60    logging.info("device type: " + str(device))
61    return model
62def input_fn(request):
63    """对输入进行格式化,处理为预测需要的输入格式
64    Args:
65        request: api请求的json
66    Returns:
67        预测需要的输入数据,一般为tensor
68    """
69    instances = request['instances']
70    transform_composes = get_image_transform()
71    arr_tensor_data = []
72    for instance in instances:
73        decoded_data = base64.b64decode(instance['data'].encode("utf8"))
74        byte_stream = BytesIO(decoded_data)
75        roiImg = Image.open(byte_stream)
76        target_data = transform_composes(roiImg)
77        arr_tensor_data.append(target_data)
78    tensor_data = torch.stack(arr_tensor_data, dim=0)
79    return tensor_data
80def output_fn(predict_result):
81    """进行输出格式化
82    Args:
83        predict_result: 预测结果
84    Returns:
85        格式化后的预测结果,需能够json序列化以便接口返回
86    """
87    js_str = None
88    if type(predict_result) == torch.Tensor:
89        list_prediction = predict_result.detach().cpu().numpy().tolist()
90        js_str = json.dumps(list_prediction)
91    return js_str

上一篇
Sklearn 0.23.2代码规范
下一篇
Tensorflow2.3.0代码规范