社区所有版块导航
Python
python开源   Django   Python   DjangoApp   pycharm  
DATA
docker   Elasticsearch  
aigc
aigc   chatgpt  
WEB开发
linux   MongoDB   Redis   DATABASE   NGINX   其他Web框架   web工具   zookeeper   tornado   NoSql   Bootstrap   js   peewee   Git   bottle   IE   MQ   Jquery  
机器学习
机器学习算法  
Python88.com
反馈   公告   社区推广  
产品
短视频  
印度
印度  
Py学习  »  机器学习算法

深度学习pytorch训练代码模板(个人习惯)

深度学习这件小事 • 2 年前 • 228 次点击  

来源|知乎  作者|wfnian

地址|https://zhuanlan.zhihu.com/p/396666255

报道|人工智能前沿讲习

从参数定义,到网络模型定义,再到训练步骤,验证步骤,测试步骤,总结了一套较为直观的模板。目录如下:
1. 导入包以及设置随机种子
2. 以类的方式定义超参数
3. 定义自己的模型
4. 定义早停类(此步骤可以省略)
5. 定义自己的数据集Dataset,DataLoader
6. 实例化模型,设置loss,优化器等
7. 开始训练以及调整lr
8. 绘图
9. 预测

01

导入包以及设置随机种子
import numpy as npimport torchimport torch.nn as nnimport numpy as npimport pandas as pdfrom torch.utils.data import DataLoader, Datasetfrom sklearn.model_selection import train_test_splitimport matplotlib.pyplot as plt
import randomseed = 42torch.manual_seed(seed)np.random.seed(seed)random.seed(seed)

02

以类的方式定义超参数
class argparse():    pass
args = argparse()args.epochs, args.learning_rate, args.patience = [30, 0.001, 4]args.hidden_size, args.input_size= [40, 30] args.device, = [torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),]

03

定义自己的模型
class Your_model(nn.Module):    def __init__(self):        super(Your_model, self).__init__()        pass
def forward(self,x): pass        return x

04

定义早停类(此步骤可以省略)
class EarlyStopping():    def __init__(self,patience=7,verbose=False,delta=0):        self.patience = patience        self.verbose = verbose        self.counter = 0        self.best_score = None        self.early_stop = False        self.val_loss_min = np.Inf        self.delta = delta    def __call__(self,val_loss,model,path):        print("val_loss={}".format(val_loss))        score = -val_loss        if self.best_score is None:            self.best_score = score            self.save_checkpoint(val_loss,model,path)        elif score < self.best_score+self.delta:            self.counter+=1            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')            if self.counter>=self.patience:                self.early_stop = True        else:            self.best_score = score            self.save_checkpoint(val_loss,model,path)            self.counter = 0    def save_checkpoint(self,val_loss,model,path):        if self.verbose:            print(                f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')        torch.save(model.state_dict(), path+'/'+'model_checkpoint.pth')        self.val_loss_min = val_loss

05

定义自己的数据集Dataset,DataLoader
class Dataset_name(Dataset):    def __init__(self, flag='train'):        assert flag in ['train', 'test', 'valid']        self.flag = flag        self.__load_data__()
def __getitem__(self, index): pass def __len__(self): pass
def __load_data__(self, csv_paths: list): pass print( "train_X.shape:{}\ntrain_Y.shape:{}\nvalid_X.shape:{}\nvalid_Y.shape:{}\n" .format(self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))
train_dataset = Dataset_name(flag='train')train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)valid_dataset = Dataset_name(flag='valid')valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)

06

实例化模型,设置loss,优化器等
model = Your_model().to(args.device)criterion = torch.nn.MSELoss()optimizer = torch.optim.Adam(Your_model.parameters(),lr=args.learning_rate)
train_loss = []valid_loss = []train_epochs_loss = []valid_epochs_loss = []
early_stopping = EarlyStopping(patience=args.patience,verbose=True)

07

开始训练以及调整lr
for epoch in range(args.epochs):    Your_model.train()    train_epoch_loss = []    for idx,(data_x,data_y) in enumerate(train_dataloader,0):        data_x = data_x.to(torch.float32).to(args.device)        data_y = data_y.to(torch.float32).to(args.device)        outputs = Your_model(data_x)        optimizer.zero_grad()        loss = criterion(data_y,outputs)        loss.backward()        optimizer.step()        train_epoch_loss.append(loss.item())


    
        train_loss.append(loss.item())        if idx%(len(train_dataloader)//2)==0:            print("epoch={}/{},{}/{}of train, loss={}".format(                epoch, args.epochs, idx, len(train_dataloader),loss.item()))    train_epochs_loss.append(np.average(train_epoch_loss))
#=====================valid============================ Your_model.eval() valid_epoch_loss = [] for idx,(data_x,data_y) in enumerate(valid_dataloader,0): data_x = data_x.to(torch.float32).to(args.device) data_y = data_y.to(torch.float32).to(args.device) outputs = Your_model(data_x) loss = criterion(outputs,data_y) valid_epoch_loss.append(loss.item()) valid_loss.append(loss.item()) valid_epochs_loss.append(np.average(valid_epoch_loss)) #==================early stopping====================== early_stopping(valid_epochs_loss[-1],model=Your_model,path=r'c:\\your_model_to_save') if early_stopping.early_stop: print("Early stopping") break #====================adjust lr======================== lr_adjust = { 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6, 10: 5e-7, 15: 1e-7, 20: 5e-8 } if epoch in lr_adjust.keys(): lr = lr_adjust[epoch] for param_group in optimizer.param_groups: param_group['lr'] = lr        print('Updating learning rate to {}'.format(lr))

08

绘图
plt.figure(figsize=(12,4))plt.subplot(121)plt.plot(train_loss[:])plt.title("train_loss")plt.subplot(122)plt.plot(train_epochs_loss[1:],'-o',label="train_loss")plt.plot(valid_epochs_loss[1:],'-o',label="valid_loss")plt.title("epochs_loss")plt.legend()plt.show()

09

预测
# 此处可定义一个预测集的Dataloader。也可以直接将你的预测数据reshape,添加batch_size=1Your_model.eval()predict = Your_model(data)


技术交流群邀请函



△长按添加小助手

扫描二维码添加小助手微信(ID : HIT_NLP

请备注:姓名-学校/公司-研究方向-城市
(如:小事-浙大-对话系统-北京)
即可申请加入深度学习/机器学习等技术交流群
为您推荐

思考丨到底什么叫算法工程师的落地能力?

Transformer模型有多少种变体?看看这篇全面综述
从SGD到NadaMax,十种优化算法原理及实现
各种注意力机制的PyTorch实现
你写的ML代码占多少内存?这件事很重要

Python社区是高质量的Python/Django开发社区
本文地址:http://www.python88.com/topic/118959
 
228 次点击