モデル学習:SageMaker Trainingとハイパーパラメータ最適化

モデル学習はパイプラインの要

製薬・医療分野での機械学習活用は年々進展しています。

これらのプロジェクトで中核を担うのが モデル学習(training) です。

AWS SageMakerを使えば、クラウド上でスケーラブルかつ再現性のある学習が可能です。さらに、ハイパーパラメータ最適化(HPO) によって、AUCやRMSEといった評価指標を自動で改善できます。

モデル学習の選択肢

(1) XGBoost





from sagemaker.estimator import Estimator
import sagemaker

role = sagemaker.get_execution_role()

xgb_estimator = Estimator(
    image_uri=sagemaker.image_uris.retrieve("xgboost", "ap-northeast-1", version="1.7-1"),
    role=role,
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://your-bucket/models/admet/"
)

xgb_estimator.fit({
    "train": "s3://your-bucket/processed/train/",
    "validation": "s3://your-bucket/processed/validation/"
})

(2) PyTorch





from sagemaker.pytorch import PyTorch

pytorch_estimator = PyTorch(
    entry_point="train.py",
    role=role,
    instance_count=1,
    instance_type="ml.g4dn.xlarge",  # GPUインスタンス
    framework_version="1.12",
    py_version="py38",
    output_path="s3://your-bucket/models/image/"
)

pytorch_estimator.fit({
    "train": "s3://your-bucket/processed/images/train/",
    "validation": "s3://your-bucket/processed/images/val/"
})

(3) 自作Dockerイメージ

FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime
RUN pip install rdkit-pypi deepchem
COPY train.py /opt/ml/code/train.py
ENV SAGEMAKER_PROGRAM train.py

このイメージをECRに登録し、Estimatorで image_uri として指定します。

ハイパーパラメータ最適化(HPO)

機械学習モデルの性能は、学習率・木の深さ・正則化パラメータなどの設定によって大きく変わります。
SageMakerの HyperparameterTuner を使えば、最適な組み合わせを自動探索できます。

例:XGBoostでAUC最大化を目標に探索

from sagemaker.tuner import HyperparameterTuner, ContinuousParameter, IntegerParameter

hyperparameter_ranges = {
    "eta": ContinuousParameter(0.01, 0.2),
    "max_depth": IntegerParameter(3, 10),
    "subsample": ContinuousParameter(0.5, 1.0)
}

objective_metric_name = "validation:auc"

tuner = HyperparameterTuner(
    estimator=xgb_estimator,
    objective_metric_name=objective_metric_name,
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=[{"Name": "validation:auc", "Regex": "auc:([0-9\\.]+)"}],
    max_jobs=20,
    max_parallel_jobs=3
)

tuner.fit({
    "train": "s3://your-bucket/processed/train/",
    "validation": "s3://your-bucket/processed/validation/"
})

製薬・医療分野での応用例

まとめ

弊社では、製薬・医療研究向けに モデル学習からデプロイまで一貫したMLOps基盤構築 を支援しています。

👉 お問い合わせはこちら