본문 바로가기
Deep Learning (AI)

torch model parameter(params) 개수 구하기

by 스프링섬머 2023. 7. 16.
728x90
from torchvision import models


def get_n_params(model):
    pp = 0
    for p in list(model.parameters()):
        nn = 1
        for s in list(p.size()):
            nn = nn * s
        pp += nn
    return pp

resnet50_pretrained = models.resnet50(pretrained=True)
parms_num = get_n_params(resnet50_pretrained)  # 25557032 = 25.55M
728x90

'Deep Learning (AI)' 카테고리의 다른 글

torch seed 고정 코드  (3) 2023.07.15