python类参数定义及数据扩展方式unsqueeze/expand
时间:2022-09-07 11:12
【相关推荐:Python3视频教程 】 将conda环境设置为ai, 这个文件的由来: 由于在yolov1的pytorch实现的损失函数中,看到继承了nn.Module,并且其中两个参数不像c++那里指定类型,那么他们的类型是哪里来的 这里就是在探索这样一件事 操作逻辑: 探究unsqueeze以及expand的使用方法,unsqueeze可以增加一个纬度,但是维度的siz只是1而已,而expand就可以将数据进行复制,将数据变为n 【相关推荐:Python3视频教程 】 以上就是python类参数定义及数据扩展方式unsqueeze/expand的详细内容,更多请关注gxlsystem.com其它相关文章!类的参数定义
conda activate ai
N = box1.size(0) M = box2.size(0)
说明了它是类似一个矩阵的东西,对应的box1的定义就是`torch.rand(10,4)import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
#探究属性S,B是如何产生的,以及box1、box2是如何产生的、如何调用
class yoloLoss(nn.Module):
def __init__(self,S,B):
self.S=S
self.B=B
def compute_iot(self,box1,box2):
N = box1.size(0) #调用方式就表示了变量是什么类型,这里是一个张量,其中每个元素是一个tensor,所以是N*4的张量
M = box2.size(0)
print(M,N)
yoloLoss1 =yoloLoss(10, 11)
yoloLoss1.compute_iot(torch.rand(10,4),torch.rand(11,4))
数据扩展
# 获得一开始的初始化数值:tensor([[a1,a2,a3]])
nn1=torch.rand(1,3)
print(nn1)
# unsqueeze是解压的意思,在第i个维度上进行扩展,将其扩展为tensor([[[a1,a2,a3]]])
nn1=nn1.unsqueeze(0)
print("*"*100)
print(nn1)
#利用expand对数据进行扩展
nn1=nn1.expand(1,3,3)
print("*"*100)
print(nn1)