본문 바로가기
Deep Learning (AI)/Pytorch skills

[Pytorch] Tensor 다차원 인덱싱

by 스프링섬머 2023. 7. 29.
728x90
  • torch tensor를 다루다 보면 3차원 혹은 4차원 이상의 텐서를 다룰 때, 인덱싱의 어려움이 생길 때가 있습니다.
  • 예를 들어, 4차원에서 각 batch마다 다른 channel의 특정 index만 추출하고 싶을 때가 그렇습니다.
  • 아래 코드는 각 배치마다 다른 채널의 값들을 가져오는 예시입니다.
import torch

a = torch.randn(2,3,2,3)
b,c,w,h = a.shape
batch_index = list([i] for i in range(b))  # [[0],[1]]
channel_index = [[0,0,0],[1,1,1]]  # 각 배치마다 가져올 채널 지정

k = a[batch_index, channel_index]  # 각 배치에 해당하는 각 채널들을 인덱싱하여 가져옴

왼쪽은 원본 a 텐서, 오른쪽은 배치마다 채널인덱싱한 k 텐서

  • a라는 원본 텐서에서 0번 배치에서는 채널 0만 3개 가져오고, 1번 배치에서는 채널 1만 3개를 가져왔습니다.
  • 원하는 채널에 대한 인덱스 번호만 설정하면 특정 채널만 추출 가능합니다.
  • 단, 주의할 점은 각 배치마다 추출한 인덱싱의 길이가 동일해야 합니다. 
  • 보통 사용할 때는 각 배치마다 특정 조건을 만족하는 20%정도의 채널만 사용하고자 할때 유용합니다.
    각 배치마다 특정 조건에 적한하는 채널들의 인덱스를 구하고 이를 모든 배치에 한 번에 인덱싱하여 가져오면 되니까요.
728x90

'Deep Learning (AI) > Pytorch skills' 카테고리의 다른 글

[Pytorch] Set seed  (1) 2023.11.05
[Pytorch] nn.Parameter()로 grad 확인하기  (5) 2023.08.06