Kicarussays

[논문리뷰/설명] wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations 본문

Deep Learning

[논문리뷰/설명] wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations

Kicarus 2022. 4. 14. 13:12

안녕하세요!

 

이번 논문은 발화 데이터 representation을 학습하는 self-supervised 기반의 방법을 제시한 논문입니다.

바로 시작해보겠습니다!

 

논문링크: https://arxiv.org/abs/2006.11477

 

wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations

We show for the first time that learning powerful representations from speech audio alone followed by fine-tuning on transcribed speech can outperform the best semi-supervised methods while being conceptually simpler. wav2vec 2.0 masks the speech input in

arxiv.org

 

 

 

 Introduction

 

발화 데이터를 학습하기 위해서는 번역된 labeled data가 필요합니다. 하지만 전세계의 수많은 언어들에 대한 양질의 transcribe 자료(음성 -> 텍스트)들을 얻기는 어렵고, labeled data만을 활용하여 학습하는 것은 유아기에 부모가 말하는 것을 듣는 것만으로 언어를 습득하는 인간의 방식과는 다릅니다.

 

self-supervised learning (SSL)은 데이터 자체의 representation을 잘 학습하기 위한 방법입니다. 마치 유아기에 언어를 습득하는 방법과 비슷합니다. SSL은 unlabeled data만을 활용하여 데이터 자체의 representation을 학습하게 하고, 이후에 labeled data를 활용하여 최종적으로 원하는 task를 수행할 수 있도록 fine-tuning하는 방법입니다.

 

이 논문은 원(原) 발화 데이터로부터 적절한 representation을 학습하는 SSL 프레임워크를 제안합니다. 먼저 발화데이터를 multi-layered CNN으로 인코딩하고, 해당 latent representation을 마스킹합니다. 마스킹된 latent representation을 transformer에 진입시켜 contextualized representation을 얻고, 적절한 레이블을 얻을 수 있도록 학습됩니다.

 

이어서 구체적인 방법이 어떻게 수행되는지 살펴봅시다.

 

 

 

 

 Model

 

 

위 Framework의 부분들의 하나씩 살펴봅시다.

 

1. Multi-layer convolutional feature encoder $f : \mathcal{X} \mapsto \mathcal{Z}$

 

  • 원 발화 데이터 $\mathcal{X}$를 $T$ step의 latent representation $\mathbf{z}_1, \cdots, \mathbf{z}_T$로 임베딩
  • layer normalization, GELU 활용
  • stride가 time step $T$를 결정하게 됨

 

2. Transformer $g : \mathcal{Z} \mapsto \mathcal{C}$

 

  • representation $\mathbf{c}_1, \cdots, \mathbf{c}_T$로 임베딩
  • positional embedding에 convolutional layer 활용 -> 해당 방법론 관련 논문 
  • Contextualized representation 산출

 

3. Quantization module $\mathcal{Z} \mapsto \mathcal{Q}$

 

  • Self-supervised training을 위해 feature encoder $f$를 통과한 $\mathbf{z}$에 product quantization 수행
  • Product quantization (PQ)의 서브벡터(codebook)의 개수를 $G$라고 하고, CNN을 통과한 임베딩 벡터 $\mathbf{z} \in \mathbb{R}^{V \times d}$ 라고 가정 (여기서 $V$는 time step, $d$는 CNN을 통과한 embedding 벡터의 크기)
  • $z$로부터 크기가 $\mathbb{R}^{V \times d/G}$인 $G$개의 서브벡터를 만들고, PQ를 수행하여 각 서브벡터의 데이터들을 클러스터링함
  • 각 서브벡터에 PQ를 수행할 때, 서브벡터의 entry들이 codebook 내에서 정의된 centroid중 가장 가까운 centroid의 data로 변환됨. 가장 가까운 centroid를 찾는 과정은 gumbel softmax를 통해서 수행함
  • 클러스터링 후 변환된 $G$개의 데이터들을 concat하여 크기가 $d$인 벡터로 변환하고, 단순 linear transformation $\mathbb{R}^d \mapsto \mathbb{R}^f$를 적용하여 $\mathbf{q}$를 얻음
  • Quantization을 거친 $\mathbf{q}$를 최종적으로 loss 계산에 활용해야하는데, codebook 내의 가장 가까운 centroid의 index로 변환하는 과정이 discrete하고, 미분불가능함. gumbel softmax는 이러한 미분불가능한 작업을 역전파가 가능하게끔 미분가능하게 만들어주는 activation function임. (Gumbel softmax 설명)
  • forward propagation에서는 가장 가까운 centroid를 찾는 argmax 함수를 활용하고, backward propagation에서는 gumbel softmax로 진행

 

 

 

 Training

Masking

BERT와 비슷한 방식으로 진행되는데, 일부 time step에 해당하는 부분의 Latent speech representation $\mathbf{z}$를 마스킹하고, 남은 부분으로 마스킹한 부분의 Quantized representation을 유추하는 방식으로 사전학습을 진행합니다.

 

Objective

손실함수는 다음과 같은 식을 사용합니다.

$$\mathcal{L} = \mathcal{L}_m + \alpha * \mathcal{L}_d$$

여기서 $\mathcal{L}_m$은 contrastive loss, $\mathcal{L}_d$는 diversity loss, $\alpha$는 하이퍼파라미터 입니다. 각 Loss가 의미하는 것이 무엇인지 살펴봅시다.

 

 

1. Contrastive Loss $\mathcal{L}_m$

 

$\mathcal{L}_m$은 마스킹된 time step의 Quantized representation을 유추할 때 활용하는 Loss입니다. Transformer를 통과한 벡터 $c$가 Quantized vector $q$와 잘 일치하면 Loss 함수값이 감소하는 형태입니다.

여기서 $sim$은 cosine similarity 함수를 의미합니다.

 

 

2. Diversity Loss $\mathcal{L}_d$

 

Codebook의 entry들이 균일하게 활용될수록 Loss 함수값이 감소하는 형태입니다. Quantized representation의 softmax를 씌운 것을 $\bar{p}_{g,v}$로 정의하고, 이를 활용하여 $\mathcal{L}_d$를 다음과 같이 정의합니다.

Fine-tuning

위의 과정으로 pre-train된 모델에 추가적인 Linear layer를 Context network위에 붙이는 등 원하는 작업을 수행하도록 모델을 구성합니다.

 

 


 

 

여기까지 wav2vec 2.0의 작동방식을 살펴보았습니다. 발화데이터 뿐만 아니라 생체신호 데이터에도 적용하여 의료현장에 활용할 수 있는 가능성이 많은 아주 멋진 방법론인 것 같습니다.

 

감사합니다!

 

 

 

 

 

 

 

Comments