일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- Gradient Tree Boosting
- Machine Learning
- lime
- Explainable AI
- XGBoost
- Back-propagation
- data science
- deep learning
- Gradient Boosting Machine
- Today
- Total
Kicarussays
[논문리뷰/설명] XGBoost: A Scalable Tree Boosting System (1) 본문
아무래도 EMR 데이터를 다루다보면 테이블 데이터에 사용하기 적합한 방법론을 많이 찾아보게 됩니다. 딥러닝이 많이 적용되는 영상이나 신호처럼 데이터 특성에 알맞은 메소드가 꽤 명확한 데이터들과는 달리(영상은 CNN, 신호는 RNN), 테이블 데이터는 데이터 특성에 알맞은 방법을 찾기가 어려운 편입니다. 특히 결측치를 처리하기가 어렵고, Feature Engineering이 까다로운 이유인 것 같습니다.
XGBoost는 이러한 테이블 데이터를 분석하는 사람들에겐 가뭄에 단비같은 방법론입니다. Random Forest, Logistic Regression 등 기성 머신러닝 방법론들과 비교했을 때, 확실히 유의미하게 우수한 성능을 보입니다. 부끄럽게도 그 동안 패키지만 다운받아서 쓰다가, 좀 더 깊게 이해해보고 싶은 욕심에 논문 리뷰를 하게 되었네요. 이제 시작하겠습니다!
논문 링크: https://arxiv.org/pdf/1603.02754.pdf
Abstract
- XGBoost는 확장성이 우수하다.
1. Introduction
- https://github.com/dmlc/xgboost (XGBoost 오픈 소스 패키지 링크)
- XGBoost의 우수성
- 탈중앙화(Out-of-core) 계산방식으로 빠름 (직접 사용해보았을 때는 잘 체감되지 않았고, LightGBM과 비교하였을 때 LightGBM이 엄청나게 빠르다는 것은 체감이 되었던 것 같습니다, LightGBM은 이후에 다뤄볼 생각입니다)
- 하드웨어적 성격의 장점도 쓰여있지만, 본 포스팅에서는 알고리즘적 내용을 위주로 다뤄보도록 하겠습니다.
2. Tree Boosting In a Nutshell
Tree Boosting 과정이 어떻게 이루어지는지 살펴볼 것입니다.
2.1 Regularized Learning Objective
위의 이미지는 논문에서 Tree Boosting을 설명하기 위해 쓴 것입니다. 각각의 트리는 Input(위 이미지에서는 사람)을 받아 Output으로 어떤 값을 내고, 각각의 트리로부터 받은 값들을 합친 것을 최종 예측값(Prediction)으로 결정하게 됩니다.
이제 위의 Tree들을 어떻게 수식으로 나타내는지 살펴보겠습니다.
주어진 $n \times m$ ($n$ samples, $m$ features) 데이터셋 $\mathcal{D} = \left\{ (\mathbb{x}_i, y_i)\right\}$ ($|\mathcal{D}| = n, \mathbb{x}_i \in \mathbb{R}^m, y_i \in \mathbb{R}$)에 대하여, Tree Ensemble 모델은 K개의 Tree를 사용하고, $$\hat{y}_i = \phi(\mathbb{x}_i) = \sum_{k=1}^{K} f_k(\mathbb{x}_i), f_k \in \mathcal{F}, \tag{1}$$
여기서 $\mathcal{F} = \left\{ f(\mathbb{x}) = w_{q(\mathbb{x})} \right\}(q: \mathbb{R}^m \to T, w \in \mathbb{R}^{T})$는 Regression Tree들을 나타냅니다. 여기서 $T$는 각 Tree의 Leaf의 개수입니다.
XGBoost는 이 트리들을 학습하는데 정규화된 손실 함수(Regularized Loss Function)를 사용합니다. 해당 손실함수 $\mathcal{L}$에 대하여, $$\mathcal{L}(\phi) = \sum_{i} l(\hat{y_i}, y_i) + \sum_{k} \Omega(f_k), \tag{2}$$ $$\text{where } \Omega(f) = \gamma T + \frac{1}{2} \lambda ||w||^2$$ 여기서 $l$은 실제 output $y_i$와 예측 output $\hat{y_i}$ 사이의 차이를 계산하는 미분가능한 Convex 손실 함수입니다. 여기에 각 Tree에 대하여 $\Omega$라는 Regularization 텀을 두었는데, 직관적으로 보면 T가 작고(leaf의 개수가 적고 -> 모델이 simple하고), $||w||^2$가 작은(각 leaf의 L2 norm이 작은 -> 모델의 오버피팅을 방지하는) 방향으로 학습이 되게 합니다. 전통적 Gradient Tree Boosting에서는 이러한 Regularization 텀이 없는데, XGBoost에서는 보다 간단하고, 오버피팅을 잘 방지하는 모델을 구성하기 위해 Regularization 텀을 추가하였습니다. 본 논문에서는 당연한 이야기지만 Regularization Parameter $\gamma$, $\lambda$ 값이 0이 되면 기본 Gradient Tree Boosting 모델로 돌아간다고 언급하고 있습니다.
2.2 Gradient Tree Boosting
이제 트리를 어떻게 수식으로 나타내고, 해당 트리로부터 어떻게 손실 함수를 계산하는지 알았으니, 트리를 어떻게 학습하는지 알아보겠습니다.
어떤 데이터로부터 분류 기능을 가장 잘 수행하는 트리를 찾는 문제는 NP-완전 문제로 알려져 있고, 가능한 모든 트리로부터 가장 좋은 트리를 찾는 것은 매우 어려운 문제입니다. (NP-완전에 대한 개념은 잘 정리되어 있는 문서들이 많으니, 한 번 읽어보시길 권장드립니다)
XGBoost는 이러한 트리를 찾는 문제를 해결하기 위해, 매 iteration마다 트리에 가지를 하나씩 늘려가는 방식(additive manner)을 제안합니다. 아래 식을 한 번 보시겠습니다. $$\mathcal{L}^{(t)}=\sum_{i=1}^{n} l(y_i, \hat{y_i}^{(t-1)} + f_t(\mathbb{x}_i)) + \Omega(f_t) + \text{Constant}, \tag{3}$$
이 식에서 $\mathcal{L}^{(t)}$는 $t$번 째 iteration에 계산할 손실 함수, 시그마에 $i=1$ 부터 $n$까지는 2.1에서 총 데이터 샘플 수가 n이었던 것을 떠올리시면 됩니다. (단순히 모든 데이터의 손실 함수를 더한다는 뜻입니다) $l$은 임의의 손실 함수이고, $y_i$는 $i$번 째 샘플의 실제 output, $\hat{y}^{(t-1)}$는 $t-1$번 째의 iteration에서의 트리로부터 나온 예측값입니다. $f_t (\mathbb{x}_i)$는 $t$번 째 iteration에서의 트리에 $i$번 째 샘플을 넣은 예측값이고, $\Omega(f_t)$는 2.1에서 설명한 regulation term입니다. 마지막 Constant는 $t-1$번째 iteration까지 regulation term의 합입니다. 현재 iteration에서 변할 수 없는 부분이니 상수로 둔 것입니다. (논문에는 생략되어 있습니다.)
(3)번 식에서 왜 이 방법론의 이름이 XGBoost인지 알 수 있습니다. Gradient Boosting은 앙상블 학습을 진행하는 과정에서 이전 iteration에서 잘 학습하지 못한 부분에 가중치를 두고 학습하는 방식입니다. (3)번 식을 직관적으로 해석해보면, 직전 iteration의 예측값($\hat{y}^{(t-1)}$)에 현재 iteration의 예측값($f_t (\mathbb{x}_i)$)을 더해서 손실함수를 최소화하는 형태입니다. 이것으로부터 이 방식이 Gradient Boosting 방식이라는 것을 알 수 있습니다.
이어서, (3)번 식을 일반적인 손실 함수 하에서 최적화하기 좋은 형태로 만들기 위해 테일러 확장(Taylor Expansion)을 통한 2차 근사(Second-order approximation)를 합니다. (2차 근사식이 최적화하기 좋은 형태인 이유는 아래에서 설명하겠습니다) $$\mathcal{L}^{(t)} \simeq \sum_{i=1}^{n}[ l(y_i, \hat{y}^{(t-1)}) + \frac{\partial}{\partial \hat{y_i}^{(t-1)}} l(y_i, \hat{y}^{(t-1)}) \cdot f_t(\mathbb{x}_i) + \frac{1}{2} \cdot \frac{\partial}{\partial^2 \hat{y_i}^{(t-1)}} l(y_i, \hat{y}^{(t-1)}) \cdot f_t^2 (\mathbb{x}_i) ] + \Omega(f_t), \tag{4}$$
이 2차 근사식은 $l'(y) = l(y_i, y + f_t(\mathbb{x}_i))$, $y = \hat{y_i}^{t-1} - f_t(\mathbb{x}_i)$로 두고 직접 계산해보시기 바랍니다.
이 식에서 본 논문은 계산의 편리함을 위해 $g_i = \frac{\partial}{\partial \hat{y_i}^{(t-1)}} l(y_i, \hat{y}^{(t-1)})$, $h_i = \frac{\partial}{\partial^2 \hat{y_i}^{(t-1)}} l(y_i, \hat{y}^{(t-1)})$로 치환하고, $$\mathcal{L}^{(t)} \simeq \sum_{i=1}^{n}[ l(y_i, \hat{y}^{(t-1)}) + g_i f_t(\mathbb{x}_i) + \frac{1}{2} h_i f_t^2 (\mathbb{x}_i) ] + \Omega(f_t), \tag{5}$$
와 같이 나타냅니다. 여기서 상수인 항들을 모두 제거하면 최종적으로 t번 째 iteration마다 최적화할 손실함수 식을 다음과 같이 얻을 수 있습니다. $$ \tilde{\mathcal{L}}^{(t)} = \sum_{i=1}^{n}[g_i f_t(\mathbb{x}_i) + \frac{1}{2} h_i f_t^2 (\mathbb{x}_i) ] + \Omega(f_t), \tag{6}$$
$l(y_i, \hat{y}^{(t-1)})$은 직전 iteration에서 이미 계산된 값이기 때문에 상수취급되어 제거됩니다.
2차 근사식이 좋은 이유는, t번 째 iteration에서 최종 손실함수가 $g_i, h_i$에 의존하여 최적화할 수 있기 때문에, 손실함수의 종류와 관계없이 간단하게 계산할 수 있다는 점입니다.
이제 모든 leaf $j$에 대해서, 새로운 집합 $I_j = \left\{ i | q(\mathbb{x}_i) = j\right\}$ 를 정의합니다. 이것으로 (6)번 식을 다음과 같이 쓸 수 있습니다. $$\begin{align*} \tilde{\mathcal{L}}^{(t)} &= \sum_{i=1}^{n}[g_i f_t(\mathbb{x}_i) + \frac{1}{2} h_i f_t^2 (\mathbb{x}_i) ] + \gamma T + \frac{1}{2}\lambda \sum_{j=1}^{T} w_{j}^{2} \\ &= \sum_{j=1}^{T} [ (\sum_{i \in I_j} g_i)w_j + \frac{1}{2} (\sum_{i \in I_j} h_i + \lambda) w_{j}^2 ] + \gamma T \end{align*} \tag{7}$$
이 식을 $w_j$에 대해 미분하여, $q(\mathbb{x})$에 대해 가장 손실함수 값을 최소화할 수 있는 최적의 $j$번 째 leaf의 weight인 $w_{j}^*$를 다음과 같이 구할 수 있습니다. $$w_{j}^* = -\frac{\sum_{i \in I_j} g_i}{\sum_{i \in I_j} h_i + \lambda}, \tag{8}$$ 그리고 이 $w_{j}^*$를 그대로 (7)번 식에 대입하면, $$\tilde{\mathcal{L}}^{(t)} (q) = -\frac{1}{2} \sum_{j=1}^{T} \frac{\sum_{i \in I_j} g_i}{\sum_{i \in I_j} h_i + \lambda} + \gamma T, \tag{9}$$로 손실함수의 최적값을 구할 수 있습니다.
논문의 Figure 2는 손실함수가 2차 근사식을 바탕으로 어떻게 계산되는지 보여줍니다.
모든 가능한 트리를 나열하여 최적 트리를 찾는 것은 거의 불가능하기 때문에, 2차 근사식을 바탕으로 한 손실함수를 토대로 매 iteration마다 하나의 leaf로부터 가지를 늘려나가는 것이 효율적입니다. 가지를 늘렸을 때 (9)번 식으로 나타나는 손실 함수가 최대한 감소하도록 하는 split point를 찾는 것이 XGBoost의 목표입니다. 가지를 늘렸을 때, 왼쪽과 오른쪽으로 분류되는 sample들의 집합을 각각 $I_L, I_R$이라 하고, $I = I_L \cup I_R$으로 정의합시다. 그러면 가지를 늘렸을 때 손실함수 값의 감소량 $\mathcal{L}_{split}$은 $$\mathcal{L}_{split} = \frac{1}{2} \left[ \frac{(\sum_{i \in I_L} g_i)^2}{\sum_{i \in I_L} h_i + \lambda } + \frac{(\sum_{i \in I_R} g_i)^2}{\sum_{i \in I_R} h_i + \lambda } - \frac{(\sum_{i \in I} g_i)^2}{\sum_{i \in I} h_i + \lambda } \right] - \gamma \tag{10}$$ 입니다. 이는 가지를 나누기 전의 손실함수에서 가지를 나눈 후의 손실함수를 뺀 식입니다.
이 식은 트리를 학습하며 가지를 늘릴 때마다 사용되는 식이 됩니다.
2.3 Shrinkage and Column Subsampling
트리 학습 알고리즘을 설명하기에 앞서, 2.1의 (2)번 식에서 설명한 $\Omega(f) = \gamma T + \frac{1}{2} \lambda ||w||^2$을 통해 트리 수준에서의 오버피팅 방지 테크닉을 상기해봅시다. 추가적으로, 2.3에서는 학습 알고리즘 수행 과정에서 오버피팅을 방지하는 두 가지 테크닉을 소개합니다.
(1) Shrinkage
- 매 iteration마다 새로 추가되는 Tree를 제외한, 직전 iteration까지 학습된 Tree들의 모든 leaf의 weight에 파라미터 $\eta \in (0, 1)$를 곱해주는 기법
- iteration이 지날수록 생성되는 트리들의 영향을 줄여서, 이후 생성될 트리들이 모델을 개선할 여지를 남기는 방식으로 오버피팅을 방지한다.
- Learning Rate(학습률)와 비슷한 개념
(2) Column Subsampling
- Random Forest에서 사용하는 기법
- Subsampling한 일부의 Column들만을 활용한 병렬학습을 통해 오버피팅을 방지한다.
3. Split Finding Algorithms
XGBoost에서 가지를 치기 위한 최적의 Split Point를 찾는 알고리즘에 대해 알아보겠습니다.
3.1 Basic Exact Greedy Algorithm
트리를 학습하는 과정에서 가장 좋은 Split Point를 찾기 위해 가능한 모든 Split Point를 나열해볼 수 있을 것입니다. 이렇게 가능한 모든 경우를 나열하여 Split Point를 찾는 과정이 Exact Greedy Algorithm입니다.
위 Algorithm 1에서 $I$는 Split 하기 전의 대상이 되는 모든 Sample들의 집합입니다. $d$도 feature dimension이라고 쓰여있긴 한데, $m$으로, gain도 score로 생각하시고 이어지는 순서대로 알고리즘을 파악하시면 좋을 것 같습니다.
- score에 0 입력
- (10)번 식으로 나타나는, Split 이후 손실함수 값의 감소량을 계산하기 위해, Split 이전의 그래디언트 값들을 $G, H$에 입력
- m개의 feature들에 대해서, 가능한 모든 Split point로 $G_L, H_L, G_R, H_R$을 계산하고, score에 손실함수 값의 감소량 입력
- score가 가장 높았던 Split point 선택
3.2 Approximate Algorithm
가능한 모든 Split point를 찾는 것이 아니라, Split point 후보군을 선정하여, 그 후보군 내에서 가장 좋은 Split point를 찾는 과정이 Approximate Algorithm입니다.
Exact Greedy Algorithm은 확실하게 가장 좋은 Split point를 찾을 수는 있지만, 데이터 전체가 메모리에 올라가지 않는다면 수행할 수 없고, 분산 환경에서도 수행할 수 없습니다. 이러한 이유로 Approximate Algorithm을 사용하게 되고, XGBoost는 두 가지 방식의 Approximate Algorithm을 제안합니다.
(1) Global Variant
한 번에 모든 Split point 후보군을 제시
(2) Local Variant
매 iteration마다 Split point 후보군을 제시
당연히 Local Variant 방식이 더 많은 계산량을 요구하게 될 것입니다. 본 논문은 Split point를 찾는 방식들의 성능 차이를 비교한 그래프를 제시합니다.
eps는 Approximate Algorithm의 파라미터 값이고, eps가 작을수록 더 많은 후보군을 제시한다고 생각하시면 됩니다. eps가 수식으로 어떻게 적용되는지는 본 포스팅 이후 (2)편에서 다루도록 하겠습니다.
당연히 Exact Greedy Algorithm이 가장 우수하지만, 낮은 eps(0.05)의 Global Variant 방식과 eps 0.3 수준의 Local Variant 방식이 비슷한 성능을 보이는 것을 확인할 수 있습니다. 따라서 Approximate Algorithm을 잘 설계하면 시간도 절약하고, Exact Greedy Algorithm의 성능에 크게 뒤처지지 않게 할 수 있습니다. XGBoost 오픈 소스 패키지에는 어떤 알고리즘을 사용할지 사용자가 자유롭게 선택할 수 있습니다.
3.3 Weighted Quantile Sketch
Quantile Sketch가 무엇인지 생소하신 분들도 많을 것 같습니다. 저도 이게 무슨 뜻인지 잘 몰라서 한참 찾아본 것 같습니다.
Quantile은 분위수라는 뜻입니다. 1분위수, 3분위수 등 데이터를 나누는 기준이 되는 것으로 이해하시면 됩니다. 그런 의미에서 Quantile Sketch는 데이터를 나누는(Split하는) 설계 과정으로 이해하시면 도움이 될 것 같습니다.
XGBoost에서는 각 Sample $\mathbb{x}_i$마다 Gradient값인 $g_i, h_i$가 있습니다. 따라서 샘플마다 서로 다른 Gradient값을 갖고 있고, 이것을 일종의 weight로 간주하고 이를 고려하여 Split point 후보군을 찾는 것이 Weighted Quantile Sketch입니다.
본 논문에서는 Appendix에 Weighted Quantile Sketch에 대해서 자세한 수식으로 다루고 있고, 이는 본 포스팅 이후 (2)편에서 다뤄볼 예정입니다.
3.4 Sparsity-aware Split Finding
XGBoost는 놀랍게도 결측치가 있는 데이터셋도 학습이 가능합니다. 이것을 가능케 하는 것이 Sparsity-aware Split Finding입니다. 먼저 알고리즘을 살펴보겠습니다.
이 알고리즘은 결측치가 있는 데이터들을 분류할 Default 방향을 결정하는 알고리즘입니다. 모든 결측치를 한 번은 전부 오른쪽에, 한 번은 전부 왼쪽에 배치하고, Split Point를 찾는 것입니다. 아래의 그림을 예로 들면, 모든 결측치를 왼쪽에 배치하였을 때 더 좋은 Split point를 찾을 수 있으므로, 해당 가지(branch)에서 결측치 데이터를 분류할 Default 방향을 왼쪽 leaf로 설정하게 됩니다.
사실 저는 이 알고리즘의 이름이 Missing-aware Split Finding이 아니라 Sparsity-aware Split Finding인지는 잘 모르겠습니다. 어쨌든 이 알고리즘의 목적은 트리로 들어온 결측치를 처리하는 Default 방향을 결정하는 데에 있습니다. 그리고 알고리즘 내부에서 결측치를 어떻게 처리할지 비교적 단순하게 결정하기 때문인지, 기존 알고리즘보다 더 빠른 속도를 보인다고 말합니다. (비교한 기존 알고리즘이 어떤 알고리즘인지 따로 레퍼런스가 달려있지는 않습니다,,, 개인적으로 본 논문에서 애매하다고 생각하는 부분 중 하나입니다)
여기까지 XGBoost가 학습하는 과정에서 어떤 손실함수를 토대로 학습하고, 손실함수를 줄이기 위해 가지(branch)를 어떻게 확장해나가는지 살펴보았습니다. 논문에서 뒤에 이어지는 내용은 하드웨어적 관점에서 XGBoost가 어떻게 우수한지 설명하고 있습니다. 트리를 학습하는 데에 시간이 가장 많이 소요되는 부분은 데이터를 정렬(sorting)하는 부분입니다. 데이터를 정렬할 때의 시간복잡도는 $O(n \log{n})$인데, XGBoost가 어떻게 이 시간복잡도를 다뤘는지 논문의 4장 "System Design"에서 살펴보실 수 있습니다. (사실 저는 이런 퓨어한 컴퓨터과학적 이론은 잘 알지 못해서 이 부분은 가볍게 읽고 넘어갔습니다)
4. System Design
XGBoost는 분산 처리 환경에서 학습 가능하고 CPU 캐시를 고려한 알고리즘이라 데이터의 크기가 커져도 빠른 속도로 학습할 수 있다는 점을 설명하고 있습니다.
5. Related Works
- Gradient Tree Boosting은 분류(Classification), 순위 결정(Learning to rank) 등에서 우수한 성능을 보임
- XGBoost는 오버피팅을 방지하기 위해 Regularized 모델 채택 -> (2)번 식 참고
- Column Subsampling, Shrinkage를 이용하여 추가적으로 오버피팅을 방지하는 테크닉 차용
- Sparsity-aware Split Finding을 원론적인(설명가능한) 측면에서 제시
- Weighted Quantile Sketch 방식을 제안 -> Appendix에 자세히 기술되어 있음
- 분산 처리 환경에서 학습 가능 등 컴퓨터 하드웨어적 측면에서 큰 데이터를 효율적으로 학습 가능
- XGBoost는 이런 요소들이 모두 결합됨
6. End To End Evaluations
6.1 System Implementation
XGBoost 패키지를 Python, R, Julia 등 많은 컴퓨터 언어 플랫폼에서 사용 가능하다고 소개합니다.
6.2 Dataset and Setup
메소드 검증에 사용한 데이터셋과 결과를 설명합니다.
My Opinion
XGBoost는 딥러닝이 득세하던 시기인 2016년에 등장한 설명가능한 머신러닝 방법론입니다. CNN, RNN을 비롯한 많은 딥러닝 기법들은 반도체의 발전과 더불어 수면 위로 떠오르게 되었고, 그 우수한 성능 또한 입증되었습니다. 그렇게 Support Vector Machine, Random Forest 등의 전통적인 머신러닝 기법들은 점차 영향력이 줄어드는가 했었지만, 성능이 아무리 좋아도 딥러닝 기법들에는 결과에 대해서 인간의 언어로 설명이 어렵다는 Blackbox의 한계가 있습니다.
이렇게 딥러닝과 머신러닝 사이에는 성능과 설명가능성 사이의 Trade-off가 있습니다. 그러던 중 XGBoost의 등장은 모델의 학습 결과를 명확하게 설명 가능하고, 그 성능 또한 정형 데이터를 다루는 데에 있어서 딥러닝 못지않게 좋다는 점에서 매우 의미가 있는 방법론인 것 같습니다.
XGBoost의 등장에 이어 LightGBM, CatBoost가 등장했고 이 방법론들 또한 다뤄볼 예정입니다.
다음 포스팅에서는 XGBoost에서 제안한 Weighted Quantile Sketch에 대해서 좀 더 자세히 알아보겠습니다.
Reference
https://xgboost.readthedocs.io/en/latest/tutorials/model.html
https://www.youtube.com/watch?v=VHky3d_qZ_E&t=1651s