TPU Pods Gemma3 파인튜닝 정리

최근 구글이 출시한 Gemma3를 TPU Pods에서 파인튜닝해보기 위해 삽질중이다.

이 포스트는 삽질을 정리하기 위해 작성하고 있다.

요약
2025-03-20 gemma-llm으로 시도해보았지만 실패했다.
2025-03-21 Google Deepmind gemma-3 report 팀에 이메일을 보냈다.
2025-03-22 🤗 Optimum TPU에서 파인튜닝이 되는 것을 확인했다. EasyDeL에 대해 알게 되었다.
2025-03-26 keras-hub에 gemma3가 merged됨!


우선, Gemma v1 모델은 https://github.com/deveworld/Gemma-EasyLM/ 을 통해 파인튜닝 했으며, Beomi님의 Gemma-EasyLM을 기반으로 하였다.

상세하게 보면 HuggingFace Transformers의 FlaxGemma Modeling을 이용한 것이다.

Gemma2 모델은 Keras3의 Keras-nlp, 현재 Keras-hub를 통해 파인튜닝 하였다.

파인튜닝이 완료되면 코드와 함께 공개할 예정이었지만, Gemma3가 출시되며 모두 엎어졌다.

이제 Gemma3가 출시되고 파인튜닝을 시도해보려 했지만, HF에 Flax Modeling도 없고, Keras-hub에서도 아직 추가중이다. (keras-hub#2152)

그래서 이번엔 Google DeepMind의 공식 구현인 (그러나, Google 공식 제품은 아닌) gemma-llm을 활용해 파인튜닝을 해보기로 했다.

그러나, 은탄환은 없는법이다. 단일 TPU VM에서는 (많은 종속성 에러와 함께)동작하지만, TPU Pods에서는 드라이버 단에서부터 오류가 난다.

해결 방법을 찾기 위해 Gemma3 팀에 [email protected] 메일로 직접 컨택을 해보았다.

대충 `Gemma3 자체가 TPU Pods에서 학습된거 같은데, 그때 쓴게 뭐고 코드가 있으신가요?` 정도의 내용이었다.

물론 아직까지 회신은 없다.

그래서 회신이 오기전까지 이것저것 알아보던중 torch/xla 위에서 돌아가는 🤗 Optimum TPU를 발견했다.

좀 유기된 프로젝트 같지만 이걸로 시도해보기로 했다.

물론 Gemma1까지만 지원하지만, Gemma3를 처음부터 Flax Modeling 하는 것 보단 나을 것이다. 필자가 Model 아키텍처를 Flax로 짜본 경험도 없고 말이다. 무엇보다 귀찮다.

우선 대충 코드를 보고 gemma3도 포함하게 바꿔준다. (알고보니 모델 타입이 gemma3_text 였다.)

이후 간단하게 exmaples 코드를 참조해서 작동하는지 확인해보았다.

import torch
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
from optimum.tpu import fsdp_v2
from datasets import load_dataset

fsdp_v2.use_fsdp_v2()

model_id = "google/gemma-3-4b-pt"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

def preprocess_function(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
    response = f"### Answer\n{sample['response']}"
    # join all the parts together
    prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
    prompt += tokenizer.eos_token
    sample["text"] = prompt
    return sample

dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
data = dataset.map(preprocess_function, remove_columns=list(dataset.features))

# Set up PEFT LoRA for fine-tuning.
lora_config = LoraConfig(
    r=8,
    target_modules=["k_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Set up the FSDP arguments
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)

# Set up the trainer
training_arguments = SFTConfig(
    run_name="gemma3-optimum-tpu-test",
    dataset_text_field="text",
    per_device_train_batch_size=64,
    num_train_epochs=32,
    max_steps=-1,
    output_dir="./output",
    optim="adafactor",
    logging_steps=1,
    dataloader_drop_last = True,  # Required for FSDPv2.
    max_seq_length=1024,
    packing=True,
    **fsdp_training_args,
)

trainer = SFTTrainer(
    model=model,
    args=training_arguments,

    train_dataset=data,
    peft_config=lora_config,
)
trainer.train()

결과는...

된다! 좀 느리지만 어쨌든 된다!

학습은 되지만 Optimum-tpu에서 `Optimum-TPU support and is optimized for v5e and v6e TPUs.` 라고 한 만큼 TRC에서 제공받은 v4 TPU에서는 아직 종속성 문제도 많다.

조금 더 살펴본 다음 Warmup Cosine Decay도 달고 Dataset도 불러오게 해서 코드를 깔끔하게 짜야겠다.

Gemma3는 아직 출시된지 얼마 안되서 하이퍼 파라미터도 건드리면서 나은걸 찾아봐야 한다.

메일 답장을 기다리는 동안 Optimum-tpu로 열심히 작업해봐야 겠다.

그런 와중에 EasyDel이라는 JAX 기반 프레임워크를 발견했다. Gemma3도 지원하고 있다!!

다른 분들이 이미 Flax로 Gemma3를 모델링 해놓은 것이​다. (감사합니다 ㅠㅜ...)

심지어는 Mistral, Phi3부터 Exaone까지 전부 있다..!!


알아보던 와중 keras-hub에 Gemma3 feat가 merge되었다.

깔끔하게 keras3를 써서 finetune 하는 것이 제일 나을 것 같다.

Read more

RTX5090 체험 후기 | Gcube 지큐브

RTX5090 체험 후기 | Gcube 지큐브

최근 gcube RTX 5090 체험 테스트에 선정되어 무상으로 체험해보게 되었다. 그 논란의 물량도 얼마 없어 돈이 있어도 구하기 어려운 RTX 5090을, 심지어 무료로 말이다! 게다가 5090뿐만 아니라 4090, 5080도 함께 제공받았다. 모두 현시점에서 가장 성능이 좋은 소비자용 그래픽카드 3종류이다. 메모리가 작고 대역폭 병목을 제외한 성능만 본다면 현존 최고 성능이다. 이들

By Dev. World

KorT: LLM이 평가하는 한국어 번역 벤치마크

한국어 번역 품질 벤치마크 KorT를 출시했습니다! 최근 한강 작가님의 작품이 노벨 문학상을 수상하며 전 세계의 주목을 받았던 일을 기억하시나요? 사실, 상을 수상한 배경 뒤에는 좋은 번역이 있었습니다. 좋은 번역은 단순한 언어 변환을 넘어, 우리의 문화와 이야기를 세계 무대에 성공적으로 선보이는 데 결정적인 역할을 합니다. 이처럼 번역의 중요성은 점점 더 커지고

By Dev. World

한영 번역기 Gemago 개발기

최근 AI가 화제입니다. 그러나 온디바이스 AI는 아직까지 많지 않은 편입니다. 외부로 유출되면 안되는 민감한 내용을 포함하면 API를 이용하기 힘들기 때문에 수요는 매우 많음에도 불구하고 말이죠. 특히 제 경우에는 개인정보가 포함된 내용을 번역하는데에서 어려움이 많았습니다. 그래서 저는 Google에서 새로 공개한 Gemma 모델을 활용해 소형 언어 모델의 장점을 최대한 살린 한-영 번역기

By Dev. World