728x90
- torch tensor를 다루다 보면 3차원 혹은 4차원 이상의 텐서를 다룰 때, 인덱싱의 어려움이 생길 때가 있습니다.
- 예를 들어, 4차원에서 각 batch마다 다른 channel의 특정 index만 추출하고 싶을 때가 그렇습니다.
- 아래 코드는 각 배치마다 다른 채널의 값들을 가져오는 예시입니다.
import torch
a = torch.randn(2,3,2,3)
b,c,w,h = a.shape
batch_index = list([i] for i in range(b)) # [[0],[1]]
channel_index = [[0,0,0],[1,1,1]] # 각 배치마다 가져올 채널 지정
k = a[batch_index, channel_index] # 각 배치에 해당하는 각 채널들을 인덱싱하여 가져옴
- a라는 원본 텐서에서 0번 배치에서는 채널 0만 3개 가져오고, 1번 배치에서는 채널 1만 3개를 가져왔습니다.
- 원하는 채널에 대한 인덱스 번호만 설정하면 특정 채널만 추출 가능합니다.
- 단, 주의할 점은 각 배치마다 추출한 인덱싱의 길이가 동일해야 합니다.
- 보통 사용할 때는 각 배치마다 특정 조건을 만족하는 20%정도의 채널만 사용하고자 할때 유용합니다.
각 배치마다 특정 조건에 적한하는 채널들의 인덱스를 구하고 이를 모든 배치에 한 번에 인덱싱하여 가져오면 되니까요.
728x90
'Deep Learning (AI) > Pytorch skills' 카테고리의 다른 글
[Pytorch] Set seed (1) | 2023.11.05 |
---|---|
[Pytorch] nn.Parameter()로 grad 확인하기 (5) | 2023.08.06 |