일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 | 31 |
- data science
- Explainable AI
- Gradient Boosting Machine
- Gradient Tree Boosting
- XGBoost
- deep learning
- Machine Learning
- lime
- Back-propagation
- Today
- Total
Kicarussays
[논문리뷰/설명] wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations 본문
[논문리뷰/설명] wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
Kicarus 2022. 4. 14. 13:12안녕하세요!
이번 논문은 발화 데이터 representation을 학습하는 self-supervised 기반의 방법을 제시한 논문입니다.
바로 시작해보겠습니다!
논문링크: https://arxiv.org/abs/2006.11477
Introduction
발화 데이터를 학습하기 위해서는 번역된 labeled data가 필요합니다. 하지만 전세계의 수많은 언어들에 대한 양질의 transcribe 자료(음성 -> 텍스트)들을 얻기는 어렵고, labeled data만을 활용하여 학습하는 것은 유아기에 부모가 말하는 것을 듣는 것만으로 언어를 습득하는 인간의 방식과는 다릅니다.
self-supervised learning (SSL)은 데이터 자체의 representation을 잘 학습하기 위한 방법입니다. 마치 유아기에 언어를 습득하는 방법과 비슷합니다. SSL은 unlabeled data만을 활용하여 데이터 자체의 representation을 학습하게 하고, 이후에 labeled data를 활용하여 최종적으로 원하는 task를 수행할 수 있도록 fine-tuning하는 방법입니다.
이 논문은 원(原) 발화 데이터로부터 적절한 representation을 학습하는 SSL 프레임워크를 제안합니다. 먼저 발화데이터를 multi-layered CNN으로 인코딩하고, 해당 latent representation을 마스킹합니다. 마스킹된 latent representation을 transformer에 진입시켜 contextualized representation을 얻고, 적절한 레이블을 얻을 수 있도록 학습됩니다.
이어서 구체적인 방법이 어떻게 수행되는지 살펴봅시다.
Model
위 Framework의 부분들의 하나씩 살펴봅시다.
1. Multi-layer convolutional feature encoder $f : \mathcal{X} \mapsto \mathcal{Z}$
- 원 발화 데이터 $\mathcal{X}$를 $T$ step의 latent representation $\mathbf{z}_1, \cdots, \mathbf{z}_T$로 임베딩
- layer normalization, GELU 활용
- stride가 time step $T$를 결정하게 됨
2. Transformer $g : \mathcal{Z} \mapsto \mathcal{C}$
- representation $\mathbf{c}_1, \cdots, \mathbf{c}_T$로 임베딩
- positional embedding에 convolutional layer 활용 -> 해당 방법론 관련 논문
- Contextualized representation 산출
3. Quantization module $\mathcal{Z} \mapsto \mathcal{Q}$
- Self-supervised training을 위해 feature encoder $f$를 통과한 $\mathbf{z}$에 product quantization 수행
- Product quantization (PQ)의 서브벡터(codebook)의 개수를 $G$라고 하고, CNN을 통과한 임베딩 벡터 $\mathbf{z} \in \mathbb{R}^{V \times d}$ 라고 가정 (여기서 $V$는 time step, $d$는 CNN을 통과한 embedding 벡터의 크기)
- $z$로부터 크기가 $\mathbb{R}^{V \times d/G}$인 $G$개의 서브벡터를 만들고, PQ를 수행하여 각 서브벡터의 데이터들을 클러스터링함
- 각 서브벡터에 PQ를 수행할 때, 서브벡터의 entry들이 codebook 내에서 정의된 centroid중 가장 가까운 centroid의 data로 변환됨. 가장 가까운 centroid를 찾는 과정은 gumbel softmax를 통해서 수행함
- 클러스터링 후 변환된 $G$개의 데이터들을 concat하여 크기가 $d$인 벡터로 변환하고, 단순 linear transformation $\mathbb{R}^d \mapsto \mathbb{R}^f$를 적용하여 $\mathbf{q}$를 얻음
- Quantization을 거친 $\mathbf{q}$를 최종적으로 loss 계산에 활용해야하는데, codebook 내의 가장 가까운 centroid의 index로 변환하는 과정이 discrete하고, 미분불가능함. gumbel softmax는 이러한 미분불가능한 작업을 역전파가 가능하게끔 미분가능하게 만들어주는 activation function임. (Gumbel softmax 설명)
- forward propagation에서는 가장 가까운 centroid를 찾는 argmax 함수를 활용하고, backward propagation에서는 gumbel softmax로 진행
Training
Masking
BERT와 비슷한 방식으로 진행되는데, 일부 time step에 해당하는 부분의 Latent speech representation $\mathbf{z}$를 마스킹하고, 남은 부분으로 마스킹한 부분의 Quantized representation을 유추하는 방식으로 사전학습을 진행합니다.
Objective
손실함수는 다음과 같은 식을 사용합니다.
$$\mathcal{L} = \mathcal{L}_m + \alpha * \mathcal{L}_d$$
여기서 $\mathcal{L}_m$은 contrastive loss, $\mathcal{L}_d$는 diversity loss, $\alpha$는 하이퍼파라미터 입니다. 각 Loss가 의미하는 것이 무엇인지 살펴봅시다.
1. Contrastive Loss $\mathcal{L}_m$
$\mathcal{L}_m$은 마스킹된 time step의 Quantized representation을 유추할 때 활용하는 Loss입니다. Transformer를 통과한 벡터 $c$가 Quantized vector $q$와 잘 일치하면 Loss 함수값이 감소하는 형태입니다.
여기서 $sim$은 cosine similarity 함수를 의미합니다.
2. Diversity Loss $\mathcal{L}_d$
Codebook의 entry들이 균일하게 활용될수록 Loss 함수값이 감소하는 형태입니다. Quantized representation의 softmax를 씌운 것을 $\bar{p}_{g,v}$로 정의하고, 이를 활용하여 $\mathcal{L}_d$를 다음과 같이 정의합니다.
Fine-tuning
위의 과정으로 pre-train된 모델에 추가적인 Linear layer를 Context network위에 붙이는 등 원하는 작업을 수행하도록 모델을 구성합니다.
여기까지 wav2vec 2.0의 작동방식을 살펴보았습니다. 발화데이터 뿐만 아니라 생체신호 데이터에도 적용하여 의료현장에 활용할 수 있는 가능성이 많은 아주 멋진 방법론인 것 같습니다.
감사합니다!