Kicarussays

[논문리뷰/설명] XGBoost: A Scalable Tree Boosting System (2) Weighted Quantile Sketch 본문

Machine Learning

[논문리뷰/설명] XGBoost: A Scalable Tree Boosting System (2) Weighted Quantile Sketch

Kicarus 2021. 6. 30. 17:19

이전 포스팅 링크입니다.

 

[논문리뷰] XGBoost: A Scalable Tree Boosting System (1)

아무래도 EMR 데이터를 다루다보면 테이블 데이터에 사용하기 적합한 방법론을 많이 찾아보게 됩니다. 딥러닝이 많이 적용되는 영상이나 신호처럼 데이터 특성에 알맞은 메소드가 꽤 명확한 데

kicarussays.tistory.com

 

이번 포스팅에서는 XGBoost 논문의 Appendix에 자세히 설명이 되어있는 Weighted Quantile Sketch에 대해서 알아보겠습니다.

 

Weighted Quantile Sketch가 XGBoost에서 등장한 배경을 설명하기 위해 몇 가지를 기억해봅시다.

우리가 지금 하고자 하는 작업은 데이터로부터 유용한 Split point 후보군을 찾는 것입니다.

유용한 Split point 후보군은 데이터를 균등하게 잘 나누는 point들의 집합이 될 것입니다.

 

그런데 우리가 가진 데이터는 모든 데이터가 동일한 취급을 받는 데이터가 아닌, 샘플 하나하나에 서로 다른 Weight가 부여된 데이터입니다. 그렇기 때문에 일반적인 경우처럼 정렬하고 간격대로 잘라서 Split point를 찾는 방식으로 접근하는 것보다 Weight가 반영된 설명가능한 구체적인 방법이 필요합니다.

 

 

 

Weighted Dataset 살펴보기

 

데이터셋 $\mathcal{D} = \left\{ ( \mathbb{x}_i, y_i ) \right\}$에 대하여, $|\mathcal{D}| = n, \mathbb{x}_i \in \mathbb{R}^m, y_i \in \mathbb{R}$ 입니다. $k$번 째 Feature에 대하여 집합 $\mathcal{D}_k = \left\{ (x_{1k}, h_1), (x_{2k}, h_2), \cdots, (x_{nk}, h_n) \right\}$을 정의하겠습니다. $h_i$는 손실함수의 2차 그래디언트 값인 것을 기억해야 합니다.

이렇게 데이터$x_{ik}$와 2차 그래디언트 값 $h_i$을 튜플로 묶은 이유는 $h_i$가 Weighted Quantile Sketch에서 Weight의 역할을 하기 때문입니다. 어떻게 그렇게 되는지 살펴봅시다.

 

 

 

2차 그래디언트 값이 Weight가 되는 이유

 

이전 포스팅의 (6)번식을 다시 살펴봅시다. $$\tilde{\mathcal{L}}^{(t)} = \sum_{i=1}^{n}[g_i  f_t(\mathbb{x}_i) + \frac{1}{2}  h_i  f_t^2 (\mathbb{x}_i) ] + \Omega(f_t), \tag{6}$$ 이 식은 XGBoost 학습 과정에서 $t$번 째 iteration마다 최적화할 손실함수 식이었습니다.

 

이 식을 $f_t(\mathbb{x}_i)$에 대한 2차식으로 다시 쓰면, $$\frac{1}{2} h_i(f_t(\mathbb{x}_i) + g_i / h_i )^2 + \Omega(f_t) + \text{Constant}$$ 입니다. (논문에는 $\frac{1}{2} h_i(f_t(\mathbb{x}_i) - g_i / h_i )^2 + \Omega(f_t) + \text{Constant}$ 로 부호가 반대로 되어있긴 한데 아마 오타이지 않을까 싶습니다) 

이 식은 $f_t(\mathbb{x}_i)$에 대한 2차식이고, 그 계수는 $\frac{1}{2} h_i$입니다. 모든 샘플의 손실함수 값에는 $\frac{1}{2} h_i$만큼의 Weight가 부여되어 있는 것이라고 해석할 수 있습니다. 따라서 2차 그래디언트 값이 Weight가 되는 것입니다.

(헷갈리실 수 있을 것 같아서 말씀드리자면, (6)번 식은 트리 학습 과정에서 사용하는 손실함수이고, 2차 그래디언트 값이 계산된 손실함수는 원래 결과값 $y_i$와 예측값 $\hat{y_i}$ 사이의 손실함수입니다.)

 

지금까지 2차 그래디언트 값을 함께 튜플로 설정한 배경을 살펴보았습니다. 이어서 Weighted Quantile Sketch를 하기 위한 rank 함수를 정의해봅시다.

 

 

 

Rank 함수 정의

 

Feature $k$에 대한 Rank 함수 $r_k : \mathbb{R} \to [0, +\infty)$를 다음과 같이 정의해봅시다.

$$r_k(z) = \frac{1}{\sum_{ ( x, h ) \in \mathcal{D}_k} h} \sum_{ ( x, h ) \in \mathcal{D}_k, x < z } h, \tag{6-1}$$

직관적으로 $r_k(z)$ Feature $k$에서 $z$보다 작은 값을 가진 데이터들의 2차 그래디언트의 합을, 전체 2차 그래디언트의 합으로 나눈 값입니다.

 

우리의 목표는 이 Rank함수를 활용하여 Feature $k$에 대한 우수한 Split Point 후보군 $\mathcal{S_k} = \left\{ s_{k, 1}, s_{k, 2}, \cdots, s_{k, l} \right\}$을 찾는 것이고, 각 $s_{k, j}$들은 다음을 만족해야 합니다.

$$| r_k(s_{k, j}) - r_k(s_{k, j+1}) | < \epsilon, \text{     } s_{k, 1} = \min_{i} \mathbb{x}_{ik}, \text{     } s_{k, l} = \max_{i} \mathbb{x}_{ik}, \tag{6-2}$$

 

여기서 $\epsilon$은 0과 1 사이의 값을 갖는파라미터입니다. $\epsilon$이 클수록 Split point 간의 간격이 넓어지고, 후보군의 숫자는 줄어들 것입니다. 후보군의 숫자가 줄면 알고리즘 수행 속도는 빨라지겠지만 최적 Split point를 찾을 확률이 줄어들겠죠. 직관적으로 Split point 후보군의 수는 $1/\epsilon$이 됩니다. 

 

XGBoost는 이러한 방식으로 Split point 후보군을 찾는 데에 증명 가능한 이론적 근거(provable theoretical guarantee)를 바탕으로 한 알고리즘을 제시하고자 합니다. 앞으로 이어지는 내용은 XGBoost가 Rank함수를 바탕으로 Weighted Quantile Sketch를 수행하는 알고리즘을 어떻게 이론적으로 제시하는지 살펴볼 것입니다.

 

 

 

Appendix. Weighted Quantile Sketch

 

먼저 Quantile query가 무엇인지 알아봅시다. Quantile은 아주 간단한 개념입니다. 0과 1 사이의 값을 갖는 $\epsilon$에 대하여(approximation error로 정의합니다), ordered metric이 정의된 공간 상에서, 크기가 $N$인 집합 $S$의 $\epsilon$-quantile은 $S$를 오름차순으로 정렬하였을 때, $\epsilon N$번 째의 데이터가 됩니다. 예를 들어, 집합 $\left\{ 1, 2, 3, 4, 5 \right\}$의 0.5-quantile은 3, 0.25-quantile은 1.25입니다(3보다 작은 1, 2 사이의 25% 지점). Quantile query는 바로 어떤 데이터에 대하여 $\epsilon$-quantile이 무엇인지에 대한 질문입니다. Quantile Summary는 이 질문에 대한 summary로 이해하면 되겠습니다.

 

Quantile Summary에는 "merge(합치기)"와 "prune(쳐내기)" 두가지 operation이 정의되어 있습니다. 

 

  • Merge: Approximation error $\epsilon_1$, $\epsilon_2$ 수준에서의 두 Summary를 합치는 작업입니다. 합쳐진 Summary는 $\max(\epsilon_1$, $\epsilon_2)$의 Approximation error를 갖습니다.
  • Prune: Summary에 있는 element를 삭제하고, 그에 맞춰 Approximation error를 $\epsilon \to \epsilon + \frac{1}{b}$로 변경합니다. 

이렇게 보면 Quantile query, summary는 굉장히 간단해 보이지만, 데이터의 크기가 커질수록 해당 데이터를 정렬하는 데 소요되는 비용이 매우 커지고, 메모리의 제한이 있기 때문에, 분산 환경에서 Quantile query를 해결하기위한 알고리즘이 필요합니다. 논문에서는 Weight가 부여되지 않은 보통의 데이터들은 "Quantile Sketch"라는 이미 제안된 알고리즘으로 Quantile query를 해결할 수 있다고 합니다. 하지만 Weight가 부여된 데이터들에 쓸 수 있는 알고리즘은,  특히 "분산 환경에서" "수리적 기반 하에" "설명 가능한" 알고리즘은 아직 개발된 바가 없다고 언급되어 있습니다.

 

XGBoost에서는 분산 환경에서 Weight가 부여된 데이터들의 Split point 후보군을 찾는 문제를 해결하고자, 새로운 Weighted Quantile Sketch 알고리즘을 제안합니다.

 

다시 우리가 할 작업을 상기해보도록 합시다.

정해진 parameter인 Approximation error $\epsilon$에 대하여, (6-1), (6-2)번 식을 만족하는 Split point 후보군을 찾는 것입니다.

 

메모리의 한계가 있다면 데이터를 쪼개서 각각에 대한 Quantile Summary를 산출해야 하고, 이를 Merge 하는 작업이 필요합니다. 또한 최적의 Split point 후보군으로 Element를 줄이기 위해서 Prune 하는 작업도 해야 합니다.

 

Appendix에는 분산 환경에서 Merge와 Prune이 어떻게 이루어지는지 수식으로 자세하게 설명되어 있습니다.

 

 

 

Reference

https://arxiv.org/pdf/1603.02754.pdf

https://www.researchgate.net/publication/2854033_Space-Efficient_Online_Computation_of_Quantile_Summaries

 

 

 

Comments