일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- Machine Learning
- Back-propagation
- Explainable AI
- XGBoost
- Gradient Boosting Machine
- lime
- data science
- deep learning
- Gradient Tree Boosting
- Today
- Total
Kicarussays
[논문리뷰/설명] DeepHit: A Deep Learning Approach to Survival Analysis with Competing Risks : 생존분석을 위한 딥러닝 본문
[논문리뷰/설명] DeepHit: A Deep Learning Approach to Survival Analysis with Competing Risks : 생존분석을 위한 딥러닝
Kicarus 2021. 6. 30. 16:21해당 링크를 참조하여 포스팅하였습니다.
humboldt-wi.github.io/blog/research/information_systems_1920/group2_survivalanalysis/#motivation
생존분석에 딥러닝을 적용한 여러 케이스를 찾던 중 잘 정리된 포스트가 있어, 읽어보고 직접 정리해보려 합니다.
1. 생존 분석
먼저 생존분석에 대해 정리해보겠습니다.
1.1 용어 정리
- Birth Event: 첫 번째 관측이 시작되는 이벤트
- Death Event: 분석을 하고자 하는 이벤트(General Event, 예를 들면 환자의 사망, 고객의 이탈 등)
- Censorship: 중도절단 데이터(샘플에 아직 Death Event가 생기지 않았거나, 다른 Event가 관측된 경우)
1.2 생존 함수(Survival Function)
Definition. 생존 함수 $S(t)$는 샘플이 시간 $t$ 이후에 생존할 확률로, 샘플의 생존 시간에 대한 확률변수 $T$에 대하여 다음과 같이 정의한다. $$S(t) = Pr(T>t)$$
예를 들면 고객이 $t$ 시간 이후에도 파산하지 않을 확률을 나타낼 수 있겠죠.
1.3 위험함수 (Hazard Function)
Definition. 위험 함수 $h(t)$는 $t$시점에 Death Event가 발생할 확률로, 샘플의 생존 시간에 대한 확률변수 $T$에 대하여 다음과 같이 정의한다. $$h(t) = \lim_{\delta t \to 0} \frac{Pr(t \le T \le t + \delta t | T > t)}{\delta t}$$
Death Event가 환자의 사망이라고 한다면 $t$가 커질수록 $h(t)$는 증가하게 되겠지만, Death Event가 은행 고객의 파산이라면 아래 그림 1.과 같이 $t$가 커질수록 처음에는 $h(t)$가 감소하다가 증가하는 양상을 보이게 될 것입니다. 이처럼 Death Event가 무엇이냐에 따라 위험함수의 모양이 완전히 달라지게 됩니다.
위험함수를 추정할 수 있다면, 어느 시점에 어떤 결정을 내릴지에 대해 보다 많은 정보를 알 수 있게 될 것입니다. 예를 들어, Death Event가 고객의 이탈일 때, $h(t)$의 모양을 보고 고객에게 이벤트나 혜택을 제공할 타이밍을 결정할 수 있겠죠.
이처럼 생존 분석의 주목적은 데이터로부터 생존함수, 위험함수를 추정하고 해석하는 것입니다.
2. 생존분석 기법(Standard methods in Survival Analysis)
생존분석 방법은 세 그룹으로 나눌 수 있다고 합니다. Non-parametric, Semi-parametric, Parametric
어떤 방법을 사용할지는 데이터셋의 모양이나 연구 목적에 따라 달라질 것입니다. 한 가지 기법만 사용하는 것보다는 여러 기법을 사용해보고 결과를 비교하는 것이 좋겠죠.
- Parametric 기법은 데이터의 생존시간에 대한 분포가 특정 확률분포를 잘 따른다는 가정이 필요합니다. 데이터의 확률분포를 추정하고 MLE를 계산하여 해당 분포의 Parameter를 추정합니다.
- Non-parametric 기법은 데이터의 확률분포보다는 생존함수와 위험함수를 시간에 대한 함수로 설명하는 데 사용됩니다. 대표적인 일변량 기법으로 카플란 마이어 추정량(Kaplan-Meier estimator)이 있고, 생존 분석의 첫 번째 단계로 주로 사용합니다.
- Semi-parametric 기법은 Parametric, Non-parametric 모두 기반으로 하는 Cox regression model이 대표적입니다.
2.1 카플란 마이어 추정량(Kaplan-Meier estimator)
본 포스트 서두의 링크의 내용을 그대로 번역하면, 카플란 마이어 추정량은 생존함수 $S(t)$의 추정에 대해 사건 발생 시간에 따라 더 작은 스텝으로 나눈다... 라고 하네요. 말이 좀 어렵긴 한데 아래 그림 2.로 이해가 되실 것 같습니다.
각 구간에 대해서 생존 확률은 다음 식으로 설명됩니다. $$\hat{S(t)} = \prod_{i:t_i \le t}\frac{n_i - d_i}{n_i}$$
$n_i$는 $t_i$시점에 남아있는 대상자 수, $d_i$는 $t_i$시점에 Death Event가 생긴 대상자 수입니다. $\frac{n_i - d_i}{n_i}$ 가 일단 1보다 작기 때문에 그림 2.처럼 점점 감소하는 형태의 그래프가 나타나긴 하겠네요. 또한 수렴하는 부분도 $t$가 충분히 커졌을 때 $d_i$가 0에 가까워지므로 $\frac{n_i - d_i}{n_i}$가 1에 가까운 값을 가지면서 수렴하게 될 것입니다.
카플란 마이어 추정량을 사용할 때 다음 네 가지 가정이 필요합니다.
- 모든 관측치는(중도절단 or not) 추정에 사용된다.
- 모든 대상자는 각자의 환경이나 연구에 참여한 시점은 생존확률에 영향을 미치지 않는다.
- 중도절단된 데이터에 대해서 그렇지 않은 데이터들과 같은 생존확률을 적용한다.
- 모든 대상자에 대해 생존확률은 같다.
가정들이 모두 매우 Strong한 가정들로 보입니다... 대상자들의 생존에 영향을 미치는 다른 많은 변수들이 있을텐데 그런 변수들의 영향을 무시한 추정량이라는 단점이 있습니다. 하지만 전체적인 데이터의 모양을 살펴보는 데는 좋을 것 같네요. 이어서 카플란 마이어 추정량의 단점을 일부 상쇄하는 콕스 비례위험모형에 대해서 살펴봅시다.
2.2 콕스 비례위험모형(Cox proportional hazard model)
콕스 비례위험모형은 기존에 카플란 마이어 추정량을 그릴 때 썼던 사건 발생 시간, 중도절단 여부 데이터에 나이, 성별 등 대상자의 특징 데이터가 포함됩니다. 이 모형은 1972년 콕스 선생님이 개발하신 모형인데 지금까지도 많이 쓰이고 있는 모형입니다. 생존 분포를 각 대상자들의 정보를 통해서 나타낸다는 특징이 있습니다. 위험함수에 대상자의 특징 데이터를 적용한 다음 식을 한번 볼까요? $$\lambda(t|\mathbf{x}) = \lambda_0(t)exp(\beta_1x_1 + \cdots + \beta_nx_n)$$
여기서 $x$는 대상자의 특징 데이터 벡터, $x_1, \cdots, x_n$은 각 특징 데이터의 값, $\beta_1, \cdots, \beta_n$은 각 특징 데이터값에 곱해질 계수(coefficient)입니다. $\lambda_0(t)$는 baseline hazard function이라고 합니다. baseline hazard function은 대상자의 모든 특징 데이터 값이 0일 때의 위험 정도를 나타내는 함수라고 하네요.
그렇다면 $\beta_1, \cdots, \beta_n$에 해당하는 계수는 어떻게 구해야 하는 것일까요? 이 방법에 대해서는 링크를 참조하시기 바랍니다. 해당 링크의 요지는 각 특징에 해당하는 계수 $\beta_1, \cdots, \beta_n$이 0인지 아닌지 가설검정을 통해 밝히는 것입니다.
cf. 특징 데이터 벡터 $\mathbf{x}$의 모든 값이 1을 가질 때의 식을 다음과 같이 나타낼 수 있다고 합니다. $$\lambda_1(t) = \lambda(t|\mathbf{x}=1) = \lambda_0(t)\text{exp}(\beta \mathbf{x}) = \lambda_0(t)\text{exp}(\beta)$$
이 식에서 $\lambda_1(t), \lambda_0(t)$는 모두 상수이므로, $e^{\beta}$값이 상수인건 자명합니다. 이 $e^{\beta}$값을 위험비(hazard ratio; HR)라고 합니다. 의학논문에 실린 표에서 자주 볼 수 있는 값이죠. 이번 포스팅을 하며 저도 처음 알게 되었습니다.
이 콕스 비례위험모형에는 비례위험가정이 필요합니다. 비례위험가정은 시간이 지남에 따라 위험비가 일정해야한다는 것입니다. 예를 들어 A 시점에서 X 대상자의 사망 확률이 Y 대상자보다 두 배 크다면, 모든 다른 시점에서도 사망 확률이 일정하게 두 배 커야 한다는 것입니다. 결국 콕스 비례위험모형은 다음 사항들을 따라야 합니다.
- 각 대상자의 Death Event는 독립적으로 발생한다.
- 각 대상자들의 생존곡선이 서로 교차해서는 안된다.
- 변수들의 계수 추정치는 위험함수와 선형곱 관계에 있다.
하지만 데이터셋의 종류에 따라 이러한 사항들을 만족하지 못하는 경우도 있겠죠. 하지만 이러한 제약을 극복할 수 있는 여러 방법들이 있습니다.
- 비례위험가정을 위반하지 않도록 해당 변수(연속변수)를 계층화하여 분류한다(정보의 손실이 생길 수 있다).
- 시간에 따라 변하는 데이터에 적용할 수 있는 콕스 회귀분석(Cox Regression)을 적용한다.
- Random Survival Forest
- Extension with neural networks!!
이어서 Deephit 논문을 직접 리뷰해보겠습니다.
이전 포스팅에서 생존분석에 대한 내용과, 인공신경망(Neural Network; Deep Learning)을 왜 생존분석에 적용하게 되었는지에 대한 배경을 설명했으니, 바로 방법론에 대한 설명을 시작해보도록 하겠습니다.
Survival Data
DeepHit에서 사용하는 Survival Data는 세 가지 요소로 구성되는데, 기간(Time), 데이터가 관측되어온 이벤트의 종류 및 발생 여부(Event), 대상 데이터의 특성(Feature)이 그 요소들입니다.
위 예시에서 $s$는 Time, $k$는 Event, $\mathbb{x}$는 Feature가 될 것입니다.
Time
- DeepHit는 Time을 불연속적(Discrete)이고 상한(Upper Limit)이 있는 데이터로 간주
(예컨대 사람이 300살 이상 사는 경우는 없으니 Time의 상한을 300년으로 설정할 수 있음) - Time Set을 $\mathcal{T} = \left\{ 0, \cdots, T_{\text{max}} \right\}$로 정의합니다.
Event
- $K$는 이벤트 종류의 개수를 의미
- DeepHit는 기존의 생존분석 방법들과는 달리 여러 이벤트에 대해서 분석할 수 있다는 점을 강점으로 제시함
- 위에서 보는 데이터 예시에서는 $K = 2$
- 한 샘플 당 하나의 이벤트만 발생한다는 것을 가정하고 있음
Feature
- 나이, 성별 등 샘플이 가지고 있는 특성들
최종적으로 $\mathcal{D} = \left\{ ( \mathbb{x}^{(i)}, s^{(i)}, k^{(i)} ) \right\}_{i=1}^{N}$ 형태의 데이터셋을 사용하게 됩니다. 그리고 타겟팅하고 있는 확률은 $P(s = s^*, k=k^* | \mathbb{x} = \mathbb{x}^*)$ 입니다.
Model Description
DeepHit에서 사용하는 Survival Data는 세 가지 요소로 구성되는데, 기간(Time), 데이터가 관측되어온 이벤트의 종류 및 발생 여부(Event), 대상 데이터의 특성(Feature)이 그 요소들입니다. 타겟팅하고 있는 확률 $\hat{P}$를 학습하기 위해 DeepHit는 아래와 같은 구조를 갖습니다.
이 구조는 크게 세 부분으로 나눌 수 있습니다.
1. Shared Sub-network
- $\mathbb{x}$로 표현되는 한 샘플의 Feature가 Fully-Connected Layer들로 구성된 MLP(Multi Layer Perceptron)를 통과
- MLP를 통과하고 나온 output에 샘플의 Feature 원본 ${\mathbb{x}}$를 더함
- DeepHit 구조 이미지에서 보이는 $f_s(\mathbb{x})$가 MLP를 통과한 후 나온 output 벡터가 됨
- $\mathbb{z} = (f_s(\mathbb{x}), \mathbb{x})$가 output 벡터와 기존 Feature 벡터를 더한 벡터가 됨
(벡터의 길이는 Feature의 개수)
2. Cause-Specific Sub-networks
- Shared Sub-network를 통과하고 나온 벡터 $\mathbb{z}$가 input으로 들어감
- 대상이 되는 Event의 개수만큼 Cause-Specific Sub-network를 구성
(기존의 생존분석 방법들과는 달리 여러 이벤트에 대해서 분석할 수 있다는 점이 DeepHit의 강점) - 각 Cause-Specific Sub-network를 통과하면 Event 별 Output Layer 벡터를 가져오게 됨
3. Output(Softmax) Layer
- Event 별 Output Layer 벡터들을 모두 이어붙이고, Softmax 함수를 통과
($y_{i, j}$는 Event i가 j 시점에 발생활 확률의 추정값 $\hat{P}(s, k|\mathbb{x})$) - $\sum y_{i, j} = 1$
이어서, 이러한 구조를 가진 DeepHit가 어떤 Loss 함수를 기준으로 학습하는지 살펴봅시다.
Loss Function
DeepHit에서 기준으로 하는 Loss 함수를 살펴보기에 앞서 저자는 CIF(Cumulative Incidence Function)를 정의합니다. 이벤트 $k^*$에 대한 CIF는 $$\begin{align*} F_{k^*}(t^*|\mathbb{x}^*) &= P(s \le t^*, k = k^* | \mathbb{x} = \mathbb{x}^*) \\ &= \sum_{s^* = 0}^{t^*} P(s = s^*, k = k^* | \mathbb{x} = \mathbb{x}^*) \end{align*}$$
으로 정의할 수 있는데, 직관적으로 보면 $\mathbb{x}$ 특성을 가진 샘플이, $t^*$ 시점 내에 이벤트 $k^*$가 발생할 확률입니다. 그런데 이 확률에 대한 정확한 값은 알 수 없기 때문에 추정치를 계산하게 되는데, CIF 추정치(Estimated CIF)를 다음과 같이 나타낼 수 있습니다. $$\hat{F}_{k^*}(s^*|\mathbb{x}^*) = \sum_{m=0}^{s^*} y_{k, m}^{*}$$
여기서 $y_{k, m}$는 DeepHit 모델을 통과하고 나온 Output Layer에서 가져온 값이 됩니다.
지금 살펴본 추정 CIF를 활용하여 DeepHit가 타겟팅하고 있는 Loss 함수를 살펴봅시다. DeepHit의 Loss 함수는 두 개의 Loss 함수의 합으로 나타납니다. $\mathcal{L}_{\text{Total}} = \mathcal{L}_1 + \mathcal{L}_2$로 쓸 수 있고, $\mathcal{L}_1$은 이벤트 발생 시간에 대한 Loss 함수, $\mathcal{L}_2$는 샘플들의 이벤트 발생 순서에 대한 Loss 함수입니다.
이벤트 발생 시간에 대한 Loss 함수 $\mathcal{L}_1$ 은 다음과 같이 정의됩니다.
$$\mathcal{L}_1 = - \sum_{i=1}^{N} \left[ \mathbf{1}(k^{(i)} \ne \varnothing) \cdot \log \left( y_{ k^{ (i) }, s^{ (i) } }^{ (i) } \right) + \mathbf{1} ( k^{(i)} = \varnothing ) \cdot \log \left( 1 - \sum_{k=1}^{K} \hat{F}_{k^{(i)}}(s^{(i)}|\mathbb{x}^{(i)}) \right) \right]$$
여기서 $\mathbf{1} (x)$는 $x$가 True일 때 1, False일 때 0을 반환하는 함수입니다. 이 식을 직관적으로 살펴보면,
앞의 부분은 $k^{(i)} \ne \varnothing$일 때 ($i$번 째 샘플이 이벤트 $k$가 발생한 경우), $\log ( y_{ k^{ (i) }, s^{ (i) } }^{ (i) } )$를 반환합니다. 이 로그 값은 1에 가까울수록 0에 가까운 음수를 반환하고, $\mathcal{L}_1$ 식의 시그마 앞부분에는 -가 붙어있으므로, 결국 Output(Softmax) Layer를 통과하고 나온 확률값이 클 수록(1에 가까울수록) Loss 함수가 작아지게 됩니다. 즉, 이벤트가 발생하는 시간을 잘 맞출수록 Loss 함수가 감소하는 것입니다.
뒤의 부분은 $k^{(i)} = \varnothing$일 때 ($i$번 째 샘플이 이벤트 $k$가 발생하지 않은 경우), $\log \left( 1 - \sum_{k=1}^{K} \hat{F}_{k^{(i)}}(s^{(i)}|\mathbb{x}^*) \right)$를 반환하는데, $s^{(i)}$시점까지 어떤 이벤트도 발생하지 않을 확률 값의 log 값을 의미합니다. 이것은 Censoring 데이터(이벤트 발생 여부를 모르는 데이터)에 대하여, 관측 시점 이전까지 아무 이벤트도 발생하면 안되게 하는 Loss 함수입니다.
이벤트 발생 순서에 대한 Loss 함수 $\mathcal{L}_2$ 는 다음과 같이 정의됩니다.
$$\mathcal{L}_2 = \sum_{k=1}^{K} \alpha_k \cdot \sum_{i \ne j} A_{k,i,j} \cdot \eta \left( \hat{F}_{k^{(i)}}(s^{(i)}|\mathbb{x}^{(i)}), \hat{F}_{k^{(i)}}(s^{(i)}|\mathbb{x}^{(j)}) \right)$$ $$\text{where } A_{k,i,j} \triangleq \mathbf{1} (k^{(i)} = k, s^{(i)} < s^{(j)}), \eta (x, y) = \exp (\frac{-(x-y)}{\sigma})$$
여기서 $\triangleq$는 "정의한다"는 뜻으로 이해해주시면 되겠습니다.
결국 $\mathcal{L}_2$는 이벤트가 발생한 시점이 다른 두 샘플로부터 이벤트 발생 순서를 맞히는 Loss 함수입니다.
$\alpha_k$는 이벤트 별로 가중치를 둔 부분입니다. $A_{k,i,j}$는 직관적으로 보면, 이벤트 $k$가 발생한 두 $i, j$번 째 샘플에 대하여, $i$번 째 샘플이 $j$번 째 샘플보다 이벤트 $k$가 먼저 발생했을 때 1을 반환하는 함수입니다. 이 함수가 있는 이유는 $\mathcal{L}_2$ 계산과정에서 값이 중복되어 더해지는 것을 방지하기 위해서입니다. ($i, j$번 째 샘플의 이벤트 발생 순서는 한 번만 구하면 됩니다)
$\eta (x, y)$를 살펴보면, 두 input $x, y$에 대하여 $x-y$ 값이 크면 클수록 작아지는 함수입니다. (분모에 있는 $\sigma$는 논문에서 별다른 설명을 찾지는 못했습니다. 파라미터라고 생각하면 될 것 같습니다) 결국 $i, j$번 째의 두 샘플에 대하여, $A_{k,i,j} \cdot \eta \left( \hat{F}_{k^{(i)}}(s^{(i)}|\mathbb{x}^{(i)}), \hat{F}_{k^{(i)}}(s^{(i)}|\mathbb{x}^{(j)}) \right)$는,
$i$번 째 샘플이 $j$번 째 샘플보다 이벤트 $k$가 먼저 발생했을 때에 $A_{k,i,j}$는 1을 반환하고, 두 샘플의 CIF 추정치의 차이가 클수록 $\mathcal{L}_2$가 작아지는 양상을 보입니다. 즉, 모든 샘플 쌍들의 순서를 맞히는데, CIF 추정치의 차이가 최대한 커지도록 Loss 함수가 설계된 것입니다.
여기까지 DeepHit의 Loss 함수들을 살펴보았습니다. 이제 DeepHit의 결과를 평가하는 방식에 대해 살펴보겠습니다.
Discriminative Performance
순서를 맞추는 문제에서 사용하는 전통적인 $C$-index라는 지표가 있습니다. 이 지표로부터 아이디어를 가져와, 저자는 DeepHit를 평가하기 위해 다음과 같은 $C^{td}$-index를 제시합니다.
$$\begin{align*} C^{td} &= P\left( \hat{F}_{k}(s^{(i)}|\mathbb{x}^{(i)}) > \hat{F}_{k}(s^{(i)}|\mathbb{x}^{(j)}) | s^{(i)} < s^{(j)}\right) \\ &\approx \frac{\sum_{i \ne j} A_{k,i,j} \cdot \mathbf{1} \left( \hat{F}_{k}(s^{(i)}|\mathbb{x}^{(i)}) > \hat{F}_{k}(s^{(i)}|\mathbb{x}^{(j)}) | s^{(i)} < s^{(j)}\right)}{\sum_{i \ne j} A_{k,i,j}} \end{align*}$$
이 식은 직관적으로, 이벤트 $k$가 발생한 모든 샘플들의 모든 쌍이 순서가 잘 맞춰져 있는지 확인하는 것입니다. 만약 실제로 $i$가 $j$보다 먼저 발생했다면, $A_{k,i,j}$는 1을 반환하게 됩니다. 여기서 두 CIF 추정치의 차가 0보다 크면, 이것은 DeepHit로부터 계산한 추정치가 올바른 값이라는 결론이 됩니다. ($i$번 째 CIF는 $s^{(i)}$시점 이전에 이벤트가 발생할 확률을 의미하고, 실제로는 $j$번 째 샘플이 더 늦게 이벤트가 발생했으므로 $j$번 째 CIF 추정치가 $i$번 째의 것보다 더 작아야 잘 추정되었다고 볼 수 있습니다)
DeepHit는 이 $C^{td}$를 활용하여, 기존의 다른 생존분석들, DeepSurv라는 먼저 나왔던 딥러닝 생존분석 기법과 성능을 비교합니다. $C^{td}$ 메트릭 하에서는 DeepHit가 가장 좋은 결과를 내고 있다고 말합니다.
My Opinion
생존분석을 딥러닝에 적용하여, 기존 생존분석의 한계(다양한 가정이 필요하고, 하나의 이벤트에 대해서만 분석 가능)를 극복한 것이 참 인상 깊었던 것 같습니다. 또한 간단한 Fully-connected Layer를 적절히 배치하여 생존분석에 필요한 형태로 Output을 내도록 모델을 설계했다는 점, Loss 함수를 기존에 없던 방식으로 커스텀하여 샘플들의 이벤트 발생 순서를 맞춘다는 목적을 달성하게끔 했다는 점이 정말 참신했던 것 같습니다. 더해서 $C^{td}$라는 새로운 메트릭을 제시하여 본인들의 연구에 타당성을 부여한 점도 좋았습니다.
다만, 컴퓨터 비전이나 신호처리 문제처럼 결과의 우수성이 변수들의 중요도보다 우선시되는 문제와는 달리, 생존분석 문제는 변수들의 중요도를 파악하는 것이 꼭 필요합니다. 그런 점에서는 딥러닝의 한계(Blackbox)를 벗어나지 못한 것으로 보입니다. 그럼에도 통계 분석가들의 영역으로 보였던 생존분석을 딥러닝에 적용했다는 것이 큰 의미를 갖는 것 같습니다.