본문 바로가기

Machine Learning/Worklog

Whisper 한국어 Fine-tuning 2주차

반응형

지난주에 H100 돈이 감당 안되가지고 TRC에 지원해서 tpu를 써야겠다고 마음 먹은 뒤로 한 짓은 대충 세가지다.

1. whisper jax finetunig 코드 작업

이건 일단 kaggle notebook으로 진행했다. 솔직히 jax 공부까지 해볼까 했지만 그냥 https://github.com/huggingface/transformers/blob/main/examples/flax/speech-recognition/run_flax_speech_recognition_seq2seq.py 이 코드랑 chatgpt로 얼추 맞춰서 돌아가는건 확인했다.

그 다음엔 TRC에서 얻은 TPU로 포팅, v4 32장을 지원해줬다. 파이토치 데이터로더를 그대로 사용하긴 했는데 중간중간에 hf에서 터지는 경우도 있고 해서 시간이 꽤 걸렸으나 결국 turbo 모델에 batch size 256으로 2만스텝 학습을 진행하긴 했다.

2. large-v3-turbo lora tuning

생각해보니까 굳이 풀 파인튜닝을 할 필요는 없겠다 싶어서 H100 한장으로 lora 먹여서 튜닝을 진행했다.

3. 데이터 추가

지난 글에서 추가하고 싶다 했던 두가지 데이터셋을 추가 처리 완료했다.

그리고 한국어 테스트셋을 완성했다. https://github.com/rtzr/Awesome-Korean-Speech-Recognition 여기서 주요 영역별 회의 음성을 제외한 6개의 테스트셋과 common voice, fleurs를 추가하여 총 8개 split이다.

결과

whisper base를 튜닝한 것은 확실히 효과가 좋았다. 특히 komix 데이터셋 + batch size 키운 효과가 확실해보였다.

그런데 turbo 모델은 결과가 상당히 안좋다. ksponspeech 같이 filler word가 많은 테스트셋의 경우는 lora 세팅에서 학습한 효과를 많이 보긴 했지만 그 외 세팅에서는 전부 소폭 안좋아졌다. 더군다나 full tuning에서는 4k step vs 20k step에서 보듯이 학습이 진행될수록 안 좋아졌다.

파라미터 튜닝이 많이 필요해보인다. 다행인 점은 cer 말고 val loss 만으로도 어느정도는 성능 유추가 가능해보인다.

 

다음 할 것들

1. hyper parameter search

jax finetuning 코드가 틀리지 않았다는 가정하에서 parameter tuning이 필요해보인다. 

2. torch_xla로 코드 전환

생각보다 스루풋이 굉장히 안나오고 있다. MXU가 17% 정도 나오고 있는데 이게 정상적인건지 잘 모르는 상황에서 그나마 의심스러운 것은 pytorch dataloader + jax 다 보니까 cpu -> tpu device 전환이 문제가 아닐까 하는 생각이다. 그래서 좀 더 pytorch 친화적일것이라 생각되는 torch_xla로 코드 전환을 해볼 예정이다.

3. Splah Attention 적용

jax든 torch_xla든 위스퍼 코드에는 naive attention을 쓰고 있다. head만 128로 패딩해주면 splash attention을 쓸 수 있을걸로 보인다. 

반응형

'Machine Learning > Worklog' 카테고리의 다른 글

Whisper 한국어 Fine-tuning 1주차  (1) 2025.02.16