计算机系统应用教程网站

网站首页 > 技术文章 正文

Pytorch:使用scatter 来产生one-hot 编码

btikc 2024-09-25 15:14:41 技术文章 18 ℃ 0 评论

例如我们处理DNA序列, DNA序列是由 ACGT四个字符产生的字符串。

我们希望 序列 “ACGTTT” 中 A编码为1000, C编码0100,G编码0010 T编码 0001


import torch

>>> aseq="ACGTTT"
# 定义一个字典。数字从0 开始,后面还需要这些数字做索引。
>>> adb={'A':0,'C':1,'G':2,'T':3}  
#按照字典adb 将字符映射为数字。
>>> target=[adb[ch] for ch in aseq]
>>> target
[0, 1, 2, 3, 3, 3]
#python的列表转化为 张量
>>> target=torch.tensor(target)
>>> target
tensor([0, 1, 2, 3, 3, 3])


# 开始转化one-hot
#1. 按照序列长度和 编码长度产生0矩阵。
#    这里len(adb),可改为4. 就是说用4长的向量表示一个核苷酸(A、C、G、T)
>>> target_onehot=torch.zeros(target.shape[0],len(adb))
>>> target_onehot
tensor([[0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.],
 [0., 0., 0., 0.]])

# 2 使用scatter 产生onehot编码。
# scatter()  不改变原来变量值; scatter_() 改变原来变量值。
# scatter_(dim, index, val/src)  这里使用val
# 就是说,在target_onehot 的 列维度(dim=1), 按照 
# 索引值  target.unsqueeze(1),  填满
#  val  值 1
>>> target_onehot.scatter_(1,target.unsqueeze(1),1)
tensor([[1., 0., 0., 0.],
 [0., 1., 0., 0.],
 [0., 0., 1., 0.],
 [0., 0., 0., 1.],
 [0., 0., 0., 1.],
 [0., 0., 0., 1.]])

# 3 最后的结果
>>> target_onehot
tensor([[1., 0., 0., 0.],
 [0., 1., 0., 0.],
 [0., 0., 1., 0.],
 [0., 0., 0., 1.],
 [0., 0., 0., 1.],
 [0., 0., 0., 1.]])

# 注:
# unsqueeze 帮助target添加一个额外维度,为下一步做索引值使用。
>>> target.unsqueeze(1)
tensor([[0],
 [1],
 [2],
 [3],
 [3],
 [3]])

Tags:

本文暂时没有评论,来添加一个吧(●'◡'●)

欢迎 发表评论:

最近发表
标签列表