티스토리 뷰

AI

[CS 224N] AdamW optimizer 구현

꼬꼬팜 2025. 6. 18. 00:41

GPT-2 layer를 구현한 이후에, 실제 학습을 위한 optimizer 구현

You will further implement the step() function of the Adam Optimizer based on Decoupled Weight Decay Regularization and Adam: A Method for Stochastic Optimization in order to train a sentiment classifier.

 

Decoupled Weight Decay RegularizationAdam: A Method for Stochastic Optimization에 기반해서 AdamW optimizer를 직접 구현

Adam Optimzer

 

Adam은 SGD 기반의 옵티마이저로, 각각의 파라미터에 대해 1차 모멘트(mean)와 2차 모멘트(variance) 를 유지하면서 learning rate을 adapctive하게 조절해주는 방식
RMSProp과 Momentum의 장점을 모두 가져온 방식으로 널리 사용

 

  • 1차 모멘트: m_t ← gradient의 지수이동평균
  • 2차 모멘트: v_t ← gradient 제곱의 지수이동평균
  • 이 두 가지를 통해 parameter update 시 안정성 향상

 

 

SGD(Stochastic Gradient Descent)

전체 데이터 대신, 미니배치 또는 한 샘플을 기반으로 매번 파라미터를 업데이트
-> 실제 학습할 corpus는 너무 많아서 언어모델에선 샘플링 방식으로 해결함

 

 

AdamW 

AdamW는 기존 Adam과 달리 weight decay를 gradient에 포함시키지 않고, 파라미터 업데이트 이후 직접 decay시키는 구조

  • Adam: grad ← grad + weight_decay * param 방식 → L2 Regularization처럼 작동
  • AdamW: param ← param - lr * weight_decay * param 방식 → Decoupled 방식

=> weight decay를 옵티마이저 내부적으로 “분리해서 처리”함으로써 성능 안정성을 향상

구현

def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # State should be stored in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                alpha = group["lr"]


                ### TODO: Complete the implementation of AdamW here, reading and saving
                ###       your state in the `state` dictionary above.
                ###       The hyperparameters can be read from the `group` dictionary
                ###       (they are lr, betas, eps, weight_decay, as saved in the constructor).
                ###
                ###       To complete this implementation:
                ###       1. Update the first and second moments of the gradients.
                ###       2. Apply bias correction
                ###          (using the "efficient version" given in https://arxiv.org/abs/1412.6980;
                ###          also given in the pseudo-code in the project description).
                ###       3. Update parameters (p.data).
                ###       4. Apply weight decay after the main gradient-based updates.
                ###
                ###       Refer to the default project handout for more details.
                ### YOUR CODE HERE

                beta1, beta2 = group["betas"]
                eps = group["eps"]
                weight_decay = group["weight_decay"]
                correct_bias = group["correct_bias"]

                # 시작 단계 상태 초기화
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data) 
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg = state["exp_avg"] # m_t
                exp_avg_sq = state["exp_avg_sq"] # v_t

                state["step"] += 1
                step = state["step"]

                # 1차 모멘트
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # 2차 모멘트
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Bias correction
                if correct_bias:
                    bias_correction1 = 1 - beta1 ** step
                    bias_correction2 = 1 - beta2 ** step
                    step_size = alpha * math.sqrt(bias_correction2) / bias_correction1
                else:
                    step_size = alpha

                denom = exp_avg_sq.sqrt().add_(eps)

                # Parameter update
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Weight decay 를 gradient와 분리해서 적용
                if weight_decay > 0.0:
                    p.data.add_(p.data, alpha=-alpha * weight_decay)

        return loss

Our reference uses the “efficient” method of computing the bias correction mentioned at the end of section 2 “Algorithm” of in Kigma and (and at the end of the algorithm above) in place of the intermediate m_hat and v_hat method. Similarly, the learning rate should be incorporated into the weight decay update

 

위에 언급된 기존 알고리즘과 다르게 "efficient method" 를 적용한 방법

 

기존 알고리즘

m̂_t = m_t / (1 - β1^t)  
v̂_t = v_t / (1 - β2^t)
θ_t = θ_t - α * m̂_t / (sqrt(v̂_t) + ε)

 

efficient method

step_size = α * sqrt(1 - β2^t) / (1 - β1^t)
θ_t = θ_t - step_size * m_t / (sqrt(v_t) + ε)

 

m_t, v_t는 그대로 사용하고, 보정 계수를 step_size 계산 시 learning rate에 함께 곱해서 처리

=> m_hat, v_hat을 별도로 계산해서 쓰는 방식 대신, 바이어스 보정과 학습률 적용을 한꺼번에 처리

 

추가로 weight decay를 적용할 때, Adam에서는 gradient에 더해졌지만

p.data += - alpha * weight_decay * p.data

 

AdamW는 따로 적용해서 최적화가 gradient 방향과 독립적으로 진행

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2026/02   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
글 보관함