[Pytorch] Tensor 다차원 인덱싱
torch tensor를 다루다 보면 3차원 혹은 4차원 이상의 텐서를 다룰 때, 인덱싱의 어려움이 생길 때가 있습니다. 예를 들어, 4차원에서 각 batch마다 다른 channel의 특정 index만 추출하고 싶을 때가 그렇습니다. 아래 코드는 각 배치마다 다른 채널의 값들을 가져오는 예시입니다. import torch a = torch.randn(2,3,2,3) b,c,w,h = a.shape batch_index = list([i] for i in range(b)) # [[0],[1]] channel_index = [[0,0,0],[1,1,1]] # 각 배치마다 가져올 채널 지정 k = a[batch_index, channel_index] # 각 배치에 해당하는 각 채널들을 인덱싱하여 가져옴 a라는..
2023. 7. 29.