NLP/AI/Statistics

[cs231n] Note 3: Optimization (Gradient Descent) 본문

Stanford Lectures : AI/CS231n

[cs231n] Note 3: Optimization (Gradient Descent)

Danbi Cho 2021. 1. 12. 20:47

이전에 딥러닝 모델에서 최적의 성능을 얻기 위한 최적화(Optimization) 방법에 대하여 설명하였다. 

 

이번 글에서는 최적화 방법 중에 하나인 gradient desent 방법에 대하여 소개하고자 한다. 

 

모델을 학습할 때 loss function의 gradient를 계산하여 파라미터를 업데이트하는 방법을 gradient descent라고 하며, 

 

아래의 코드는 가장 기본적인 vanila 버전의 코드이다. 

$ while True:
$	weights_grad = evaluate_gradient(loss_function, data, weights)
$	weights += step_size * weights_grad  # parameter update

 

Gradient Descent 방법은 최근 neural network에서 가장 흔히 사용되고 있으며 loss function을 최소화시켜 최적화시키는 방법으로 연구되었다. Gradient Descent 방법은 크게 두 가지 방법론으로 설명된다. 

 

Mini-batch Gradient Descent

대용량 데이터를 처리하는 경우 모든 학습 데이터에 대한 파라미터를 단일적으로 업데이트 하는 것은 비효율적이다. 그래서 소개된 방법은 데이터에 batch를 구성하는 것이다. 이를 mini-batch gradient descent라고 하며, 학습 데이터를 batch 사이즈에 따라 나눈 후 각각의 batch 별로 gradient를 계산하는 것이다. 

 

예를 들어, ConvNet에서는 120만 개의 학습 데이터를 256 batch로 나누어 파라미터를 업데이트하였다. 

$ while True:
$	data_batch = sample_training_data(data, 256) # 256 is batch size
$    weights_grad = evaluate_gradient(loss_fun, data_batch, weihts)
$    weights += - step_size*weights_grad  # parameter update

mini-batch gradient descent는 batch 별로 데이터를 나누어 gradient를 계산하고 파라미터를 업데이트하기 때문에 연산이 빠르고 최적의 loss 값으로 빠르게 수렴하는 장점이 있다. 

 

Stochastic Gradient Descent

Stochastic Gradient Descent (SGD) 방법은 mini-batch gradient descent 방식의 일종이라고 생각할 수 있다. 

 

mini-batch gradient descent에서 batch의 크기가 1인 경우를 의미하며, 모든 학습 데이터에 대하여 하나의 데이터씩 gradient를 계산하여 파라미터를 업데이트하는 방식이다. 

 

SGD는 gradient를 평가할 때 하나의 데이터만을 처리하기 때문에 최적의 파라미터를 구성하는 데에 시간이 조금 더 오래걸리는 단점이 있지만 그만큼 섬세하게 접근한다고 할 수 있다.

 

참고) https://cs231n.github.io/optimization-1/#gd

 

CS231n Convolutional Neural Networks for Visual Recognition

Table of Contents: Introduction In the previous section we introduced two key components in context of the image classification task: A (parameterized) score function mapping the raw image pixels to class scores (e.g. a linear function) A loss function tha

cs231n.github.io

 

Comments