TRC와 함께한 MaxText 후기
TPU v4-64에서 Gemma3 파인튜닝을 위해서 Keras3 대신 MaxText를 새롭게 사용해본 경험을 공유해보고자 한다.
GCP에서 TPU Queued Resources를 생성한다.
이후 worker 0 (GCP에서 나오는 TPU ip address)에 ssh를 연결해준다.
(앞으로 특별한 말이 없다면 tpu node worker 0에서 실행한다)
maxtext를 git에서 clone 해준다.
git clone https://github.com/AI-Hypercomputer/maxtext && cd maxtext
이후에는 공식문서의 내용에 따라준다.
ZONE=<zone>
gcloud config set compute/zone $ZONE
ssh-keygen -f ~/.ssh/google_compute_engine
TPU_PREFIX=$YOUR_TPU_NAME
python3 multihost_runner.py --TPU_PREFIX=$TPU_PREFIX --COMMAND="bash setup.sh" --INTERNAL_IP=true
MaxText는 기본적으로 Gemma 1-3 Series를 지원한다.
Gemma3를 사용하기 위해서는 이 파일을 참고하면 된다.
Kaggle에 올라와 있는 Flax모델을 MaxText를 위해서 Checkpoint Conversion을 진행해줘야 한다. Kagglehub를 이용해서 Gemma3 Flax 체크포인트 파일을 받아준다.
import kagglehub
path = kagglehub.model_download("google/gemma-3/flax/gemma3-4b-it")
print("Path to model files:", path)
모델은 ~/.cache/kagglehub/models/google/gemma-3/flax/gemma3-4b-it/1
폴더에 다운로드가 된다.
이후 convert_gemma3_chkpt.py를 이용해 Conversion해준다.
다만, TPU multihost 환경이기에 cpu 플랫폼으로 제한해야 한다. (아니면, 그냥 다른 호스트에서 해도 된다)
JAX_PLATFORMS=cpu python3 -m MaxText.convert_gemma3_chkpt --base_model_path ~/.cache/kagglehub/models/google/gemma-3/flax/gemma3-4b-it/1/gemma3-4b-it/ --maxtext_model_path gs://<bucket-name-1> --model_size 4b
마지막으로 MaxText/configs/base.yml를 수정해야 한다.
model_name: "gemma3-4b"
base_output_directory: "gs://<bucket-name-2>"
load_parameters_path: "gs://<bucket-name-1>/0/items"
tokenizer_path: "google/gemma-3-4b-it"
dataset_type: hf
hf_path: '<hf-dataset-path>'
이제 시작해보자
RUN_NAME="<unique-run-name>"
python3 multihost_runner.py --TPU_PREFIX=$TPU_PREFIX --COMMAND="bash preflight.sh PLATFORM=GCE && numactl --membind 0 --cpunodebind=0 python3 -m MaxText.train MaxText/configs/base.yml run_name=$RUN_NAME" --INTERNAL_IP=true
... (작성중)