Tags: Ray, Distributed ML, Machine Learning, Ray Train, Ray Tune, Ray Serve, Python, Kubernetes, MLOps, GPU Training

Ray for Distributed ML — Train, Tune, Serve, and Scale Across Clusters

A practical guide to Ray for distributed machine learning: Ray Core remote tasks and actors with @ray.remote, the plasma object store for zero-copy shared memory between workers, Ray Data for scalable dataset preprocessing with lazy map_batches transformations and direct Parquet reads from S3, Ray Train for multi-GPU and multi-node distributed PyTorch training with DistributedDataParallel wrapping, fault-tolerant checkpointing on S3 with FailureConfig max_failures for spot instance resilience, ray.train.report() for per-epoch metric and checkpoint reporting, Ray Tune for distributed hyperparameter optimization with the ASHA scheduler for aggressive early stopping, Optuna Bayesian search for smarter candidate generation, Population-Based Training for mid-training hyperparameter mutation, MLflowLoggerCallback for automatic experiment tracking across all trials, Ray Serve for scalable model serving with @serve.batch request batching for GPU efficiency, autoscaling_config with min/max replicas and target_ongoing_requests, multi-model deployment graphs with Router actors binding Preprocessor and Classifier deployments, KubeRay operator with RayCluster for persistent clusters, RayJob for ephemeral per-run clusters that auto-cleanup after job completion, RayService for zero-downtime rolling upgrades of Ray Serve applications, GPU worker node pools with spot instance tolerations and Karpenter NodePool integration, and a 10-point production checklist covering version pinning, head node CPU isolation, checkpoint storage, fault tolerance testing, and Prometheus metrics scraping.

2026-06-17

Why Ray for Distributed ML?

Ray is an open-source distributed computing framework, originally developed at UC Berkeley's RISELab and now developed primarily by Anyscale, that unifies the ML lifecycle under a single Python-native runtime. Unlike Spark, which is optimized for batch ETL and SQL analytics, Ray is purpose-built for heterogeneous ML workloads: it handles CPU-bound preprocessing, GPU-bound training, hyperparameter search, and real-time inference within a single cluster, sharing resources dynamically instead of running separate infrastructure for each stage.

The key insight behind Ray is that ML teams spend most of their time gluing together incompatible distributed systems: a Spark cluster for data prep, a separate training cluster for GPU jobs, a third system for HPO, and yet another for serving. Ray replaces this zoo with a unified compute layer. If you are already running MLflow-tracked training pipelines and hitting single-machine limits, Ray is the natural next step.

Ray Core

Remote tasks and actors over any Python function. The distributed computing primitive that powers all higher-level libraries.

Ray AI Libraries

Ray Data, Ray Train, Ray Tune, and Ray Serve — high-level abstractions for each phase of the ML lifecycle built on Ray Core.

KubeRay

Kubernetes operator for Ray clusters. RayCluster, RayJob, and RayService CRDs for production deployment with autoscaling and spot instance support.

Ray Core — Remote Tasks and Actors

Ray Core provides two primitives: remote tasks (stateless functions that run on any worker) and actors (stateful objects that live on a specific worker, maintain local state, and handle method calls asynchronously). Both are decorated with @ray.remote and return futures (ObjectRef) instead of values.

import ray
import time
import numpy as np

# Initialize Ray — auto-detects local CPUs, or connects to a cluster
ray.init()  # local: ray.init()  |  cluster: ray.init("ray://head-node:10001")

# ── Remote tasks: stateless parallel functions ────────────────────────
@ray.remote
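def preprocess_shard(shard_path: str) -> np.ndarray:
    """Runs on any available worker. No shared state."""
    import fsspec  # np.load cannot open s3:// URLs itself; fsspec needs s3fs installed for S3

    with fsspec.open(shard_path, "rb") as f:
        data = np.load(f)
    return (data - data.mean()) / data.std()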

# Submit 100 shards in parallel — returns futures immediately
shards = [f"s3://data/shard_{i:04d}.npy" for i in range(100)]
futures = [preprocess_shard.remote(path) for path in shards]

# Block on results (all 100 run in parallel across workers)
results = ray.get(futures)   # list of np.ndarray
combined = np.concatenate(results, axis=0)

# ── Resource hints: reserve GPU or custom resources ───────────────────
@ray.remote(num_gpus=1, num_cpus=4)
def gpu_inference(batch: np.ndarray) -> np.ndarray:
    import torch
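    # Loaded on every call here; cache the model in an actor for repeated inference
    model = torch.load("/model/checkpoint.pt", map_location="cuda")
    model.eval()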
    with torch.no_grad():
        return model(torch.from_numpy(batch).cuda()).cpu().numpy()

# ── Actors: stateful objects with persistent local state ──────────────
@ray.remote
class ParameterServer:
    """Central parameter store — actors serialize method calls automatically."""
    def __init__(self, model_size: int):
        self.params = np.zeros(model_size, dtype=np.float32)
        self.update_count = 0

    def push_gradient(self, grad: np.ndarray, lr: float = 0.001) -> None:
        self.params -= lr * grad
        self.update_count += 1

    def pull_params(self) -> np.ndarray:
        return self.params.copy()

    def get_update_count(self) -> int:
        return self.update_count

# Create a named actor. Add lifetime="detached" to .options() if it must outlive the driver.
ps = ParameterServer.options(name="param_server").remote(model_size=10_000)

# Workers push gradients asynchronously
@ray.remote
def worker_train(worker_id: int, ps_handle) -> None:
    for step in range(100):
        params = ray.get(ps_handle.pull_params.remote())
        grad = np.random.randn(10_000).astype(np.float32)  # simulated gradient
        ps_handle.push_gradient.remote(grad, lr=0.001)
        if step % 10 == 0:
            count = ray.get(ps_handle.get_update_count.remote())
            print(f"Worker {worker_id} step {step}, total updates: {count}")

# Launch 4 workers in parallel
ray.get([worker_train.remote(i, ps) for i in range(4)])

# ── Object store: zero-copy shared memory between tasks ───────────────
# Put large array in Ray's plasma object store once — tasks read it zero-copy
large_dataset = np.random.rand(10_000_000, 128).astype(np.float32)
dataset_ref = ray.put(large_dataset)   # stored in shared memory

@ray.remote
def process_slice(data_ref, start: int, end: int) -> float:
    data = ray.get(data_ref)  # zero-copy read from plasma
    return float(data[start:end].mean())

# All workers share the same plasma object — no copies, no serialization cost
chunk = 1_000_000
means = ray.get([
    process_slice.remote(dataset_ref, i * chunk, (i + 1) * chunk)
    for i in range(10)
])

Note

Ray's object store (Plasma) lives in shared memory, so large NumPy arrays passed between tasks on the same node are read zero-copy as read-only views. This is dramatically faster than serializing arrays through a message broker or network filesystem. For cross-node transfers, Ray copies the object into the destination node's object store; objects are serialized with pickle protocol 5 and out-of-band buffers, not Arrow IPC.
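
To make the shared-memory idea concrete, here is a minimal standard-library sketch. It is not Ray's implementation: `multiprocessing.shared_memory` stands in for the object store, and the helper names `put_floats` and `read_floats` are hypothetical. Two handles attach to the same named segment, and data written through one is visible through the other without being serialized or copied into a second buffer.

```python
import struct
from multiprocessing import shared_memory

FMT = "4f"  # four float32 values


def put_floats(values):
    """Create a shared segment and write the floats into it (loosely like ray.put)."""
    shm = shared_memory.SharedMemory(create=True, size=struct.calcsize(FMT))
    struct.pack_into(FMT, shm.buf, 0, *values)
    return shm


def read_floats(name):
    """Attach to the segment by name and decode values straight from the shared buffer."""
    reader = shared_memory.SharedMemory(name=name)
    try:
        return struct.unpack_from(FMT, reader.buf, 0)
    finally:
        reader.close()  # detach only; the owner is responsible for unlinking


owner = put_floats([1.0, 2.0, 3.0, 4.0])
try:
    first_read = read_floats(owner.name)
    # Overwrite the first value through the owner's handle. A later reader sees
    # the change because both handles map the same memory.
    struct.pack_into("f", owner.buf, 0, 10.0)
    second_read = read_floats(owner.name)
finally:
    owner.close()
    owner.unlink()

print(first_read, second_read)
```

Unlike this sketch, objects in Ray's store are immutable once written; the in-place write above is only there to show that both handles see the same memory.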

Ray Data — Scalable Dataset Preprocessing

Ray Data is a distributed dataset library that handles the preprocessing gap between raw storage and model training. It reads Parquet, CSV, JSON, images, and TFRecords from S3/GCS/HDFS, applies transformations in parallel across workers, and streams batches directly into the training loop — eliminating the need for a separate Spark preprocessing job. Ray Data integrates natively with Ray Train so that data loading and training share the same cluster resources.

import ray
from ray.data import read_parquet, read_images
import numpy as np

ray.init()

# ── Read Parquet from S3 — automatic parallelism based on file count ──
ds = ray.data.read_parquet(
    "s3://ml-data/features/",
    columns=["user_id", "feature_vec", "label"],
    parallelism=200,    # number of parallel read tasks
)

print(ds.schema())     # inspects schema without full materialization
print(ds.count())      # triggers execution, returns row count

# ── Transformations — lazy, fused into a single execution plan ────────
def normalize_features(batch: dict) -> dict:
    """batch is a dict of numpy arrays, one key per column.

    Note: mean and std are computed per batch, not over the full dataset.
    """
    features = batch["feature_vec"].astype(np.float32)
    mean = features.mean(axis=0)
    std  = features.std(axis=0) + 1e-8
    batch["feature_vec"] = (features - mean) / std
    return batch

def encode_labels(batch: dict) -> dict:
    label_map = {"positive": 1, "negative": 0, "neutral": 2}
    batch["label_int"] = np.array(
        [label_map[l] for l in batch["label"]], dtype=np.int64
    )
    return batch

# .map_batches() applies each transform to batches in parallel; batch_format="numpy" passes a dict of NumPy arrays
ds = (
    ds
    .map_batches(normalize_features, batch_size=1024, batch_format="numpy")
    .map_batches(encode_labels,      batch_size=1024, batch_format="numpy")
    .filter(lambda row: row["label"] != "neutral")   # row-level filter
    .drop_columns(["label"])                          # remove string label
)

# ── Train/test split ──────────────────────────────────────────────────
train_ds, test_ds = ds.train_test_split(test_size=0.1, shuffle=True, seed=42)

print(f"Train rows: {train_ds.count()}, Test rows: {test_ds.count()}")

# ── Materialize to disk for repeated access ───────────────────────────
# Save as Parquet — preserves column types, readable by downstream tools
train_ds.write_parquet("s3://ml-data/preprocessed/train/")
test_ds.write_parquet( "s3://ml-data/preprocessed/test/")

# ── Image preprocessing pipeline ─────────────────────────────────────
from torchvision import transforms
from PIL import Image
import io

image_ds = ray.data.read_images(
    "s3://ml-data/images/",
    include_paths=True,
    size=(224, 224),    # resize all images on read
)

def augment_image(batch: dict) -> dict:
    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    augmented = []
    for img_array in batch["image"]:  # read_images yields decoded arrays, not raw bytes
        img = Image.fromarray(img_array)
        augmented.append(transform(img).numpy())
    batch["tensor"] = np.stack(augmented)
    return batch

image_ds = image_ds.map_batches(
    augment_image,
    batch_size=64,
    num_gpus=0,         # augmentation runs on CPU workers
    concurrency=8,      # 8 parallel transform workers
)

Ray Train — Distributed Training with PyTorch and Lightning

Ray Train orchestrates distributed training across multiple GPUs and nodes. It handles process group initialization, gradient synchronization, checkpoint saving and restoring, and fault tolerance. For PyTorch, it wraps DistributedDataParallel automatically — you write a single-worker training function and Ray Train replicates it across the requested worker count. For PyTorch Lightning users, RayLightningEnvironment and RayTrainReportCallback plug in with minimal code changes.
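
The gradient synchronization that DistributedDataParallel performs can be illustrated with a dependency-free sketch. This is a toy all-reduce (mean) over per-worker gradients for a one-parameter linear model, not PyTorch's or Ray Train's implementation, and the function names are hypothetical. It shows why every worker ends up with the same weights: each computes a gradient on its own shard, the gradients are averaged, and all workers apply the identical update.

```python
def local_gradient(w, shard):
    """Gradient of mean squared error for y ~ w * x on one worker's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)


def all_reduce_mean(values):
    """Stand-in for the all-reduce step: every worker receives the mean."""
    return sum(values) / len(values)


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # generated by y = 2x
shards = [data[0:2], data[2:4]]  # two workers with equal-sized shards

w = 0.0
for step in range(200):
    grads = [local_gradient(w, shard) for shard in shards]  # computed per worker
    synced = all_reduce_mean(grads)                          # gradient synchronization
    w -= 0.05 * synced                                       # same update on every worker

print(round(w, 4))
```

With equal-sized shards, the averaged gradient equals the full-batch gradient, so the distributed run follows the same trajectory as single-worker training on all of the data.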

import ray
from ray.train import ScalingConfig, RunConfig, CheckpointConfig
from ray.train.torch import TorchTrainer
import ray.train.torch as ray_torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

ray.init()

# ── Define the per-worker training function ───────────────────────────
def train_func(config: dict):
    """This function runs on EVERY worker. Ray handles DDP setup."""
    import torch.distributed as dist

    # Hyperparameters from config dict (populated by Ray Tune or caller)
    lr          = config.get("lr", 1e-3)
    batch_size  = config.get("batch_size", 256)
    epochs      = config.get("epochs", 10)
    hidden_dim  = config.get("hidden_dim", 256)

    # ── Model ─────────────────────────────────────────────────────────
    class MLP(nn.Module):
        def __init__(self, input_dim, hidden_dim, output_dim):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(hidden_dim, hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(hidden_dim // 2, output_dim),
            )
        def forward(self, x):
            return self.net(x)

    model = MLP(input_dim=128, hidden_dim=hidden_dim, output_dim=3)

    # ray_torch.prepare_model() wraps with DDP and moves to correct GPU
    model = ray_torch.prepare_model(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss()

    # ── Data loading — each worker gets its own shard via Ray Data ────
    # In practice: load from Ray Data dataset passed via ray.train.get_dataset_shard()
    train_dataset = ray.train.get_dataset_shard("train")
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        for batch in train_dataset.iter_torch_batches(
            batch_size=batch_size,
            dtypes={"feature_vec": torch.float32, "label_int": torch.long},
        ):
            features = batch["feature_vec"]
            labels   = batch["label_int"]

            optimizer.zero_grad()
            logits = model(features)
            loss   = criterion(logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        scheduler.step()
        avg_loss = total_loss / max(num_batches, 1)

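        # Every worker must call ray.train.report(); only rank 0 attaches a checkpoint.
        # ray.train.Checkpoint has no from_dict() in Ray 2.7+, so save to a directory.
        import tempfile  # stdlib; imported here to keep this snippet self-contained
        with tempfile.TemporaryDirectory() as tmp_dir:
            checkpoint = None
            if ray.train.get_context().get_world_rank() == 0:
                # Unwrap DDP (applied when num_workers > 1) before saving weights
                model_state = getattr(model, "module", model).state_dict()
                torch.save(
                    {"model_state": model_state,
                     "optimizer_state": optimizer.state_dict(),
                     "epoch": epoch},
                    f"{tmp_dir}/checkpoint.pt",
                )
                checkpoint = ray.train.Checkpoint.from_directory(tmp_dir)
            ray.train.report(
                metrics={"loss": avg_loss, "epoch": epoch, "lr": scheduler.get_last_lr()[0]},
                checkpoint=checkpoint,
            )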

# ── Configure distributed training ───────────────────────────────────
scaling_config = ScalingConfig(
    num_workers=4,          # 4 worker processes (each gets 1 GPU)
    use_gpu=True,
    resources_per_worker={"CPU": 4, "GPU": 1},
    placement_strategy="SPREAD",  # spread workers across nodes
)

run_config = RunConfig(
    name="mlp_training_run",
    storage_path="s3://ml-runs/ray-train/",
    checkpoint_config=CheckpointConfig(
        num_to_keep=3,             # keep only 3 best checkpoints
        checkpoint_score_attribute="loss",
        checkpoint_score_order="min",
    ),
    failure_config=ray.train.FailureConfig(max_failures=2),  # retry on worker failure
)

# ── Load preprocessed Ray Data datasets ──────────────────────────────
train_ds = ray.data.read_parquet("s3://ml-data/preprocessed/train/")
test_ds  = ray.data.read_parquet("s3://ml-data/preprocessed/test/")

trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    train_loop_config={
        "lr": 5e-4,
        "batch_size": 512,
        "epochs": 20,
        "hidden_dim": 512,
    },
    scaling_config=scaling_config,
    run_config=run_config,
    datasets={"train": train_ds},   # shard automatically across workers
)

result = trainer.fit()
print(f"Best checkpoint: {result.best_checkpoints[0]}")
print(f"Metrics: {result.metrics}")

Note

Ray Train's fault tolerance is built on reported checkpoints. If a worker process crashes mid-epoch, Ray Train restarts the training workers and exposes the latest reported checkpoint through ray.train.get_checkpoint(); the training function has to load that checkpoint itself to resume from the last reported epoch instead of starting over. For spot instance training on Kubernetes, a small retry budget such as max_failures=3 in FailureConfig (the example above uses 2) is usually enough to survive typical spot preemption patterns without manual intervention.
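
The restart-and-resume pattern can be sketched without Ray. The following is a simplified stand-in, assuming a JSON file as the checkpoint and a hypothetical `fit_with_retries` helper that plays the role of FailureConfig(max_failures=...); it is not how Ray Train is implemented. The point is that the training loop must look for an existing checkpoint at startup, otherwise a retry would begin again at epoch 0.

```python
import json
import os
import tempfile


def train(ckpt_dir, epochs, fail_at=None):
    """Toy training loop that resumes from the latest checkpoint if one exists."""
    path = os.path.join(ckpt_dir, "latest.json")
    state = {"epoch": -1, "loss": 1.0}
    if os.path.exists(path):          # plays the role of ray.train.get_checkpoint()
        with open(path) as f:
            state = json.load(f)
    for epoch in range(state["epoch"] + 1, epochs):
        if epoch == fail_at:
            raise RuntimeError(f"simulated preemption at epoch {epoch}")
        state = {"epoch": epoch, "loss": state["loss"] * 0.5}
        with open(path, "w") as f:    # plays the role of ray.train.report(checkpoint=...)
            json.dump(state, f)
    return state


def fit_with_retries(epochs, max_failures, fail_at):
    """Re-run the loop after a failure, giving up once max_failures is exceeded."""
    failures = 0
    with tempfile.TemporaryDirectory() as ckpt_dir:
        while True:
            try:
                # Only the first attempt is preempted in this simulation.
                inject = fail_at if failures == 0 else None
                return train(ckpt_dir, epochs, inject), failures
            except RuntimeError:
                failures += 1
                if failures > max_failures:
                    raise


final_state, failures = fit_with_retries(epochs=5, max_failures=2, fail_at=3)
print(final_state, failures)
```

In this run the first attempt completes epochs 0 to 2, fails at epoch 3, and the retry continues from epoch 3 rather than repeating the finished epochs.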

Ray Tune — Distributed Hyperparameter Optimization

Ray Tune runs hyperparameter search across hundreds of trials in parallel, each trial running on a separate set of cluster resources. It supports trial schedulers such as ASHA (early stopping) and PBT (population-based training), Bayesian optimization via Optuna and Ax, and grid/random search as baselines. The ASHA (Asynchronous Successive Halving Algorithm) scheduler is a common first choice for deep learning: it aggressively terminates underperforming trials early, freeing GPU hours for promising configurations. It has to be selected explicitly, since Tune's built-in default scheduler is plain FIFO.
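
The rung arithmetic behind successive halving can be sketched in a few lines. This is a simplified synchronous variant over precomputed loss curves, not Ray Tune's ASHAScheduler: real ASHA promotes trials asynchronously as results arrive instead of waiting for a full rung. The function names and the synthetic loss curves are assumptions for illustration.

```python
def rung_milestones(grace_period, max_t, reduction_factor):
    """Epochs at which trials are compared: grace_period * reduction_factor**k."""
    milestones = []
    t = grace_period
    while t <= max_t:
        milestones.append(t)
        t *= reduction_factor
    return milestones


def successive_halving(trial_losses, grace_period, max_t, reduction_factor):
    """Keep the best 1/reduction_factor of the surviving trials at each rung.

    trial_losses maps a trial name to its loss curve, where index i holds the
    loss after epoch i + 1. Returns the trials alive after the last rung.
    """
    alive = list(trial_losses)
    for milestone in rung_milestones(grace_period, max_t, reduction_factor):
        ranked = sorted(alive, key=lambda name: trial_losses[name][milestone - 1])
        keep = max(1, len(ranked) // reduction_factor)
        alive = ranked[:keep]
    return alive


# Nine synthetic trials; trial_0 has the lowest loss at every epoch.
curves = {
    f"trial_{i}": [(i + 1) / epoch for epoch in range(1, 31)]
    for i in range(9)
}

milestones = rung_milestones(grace_period=5, max_t=30, reduction_factor=3)
survivors = successive_halving(curves, grace_period=5, max_t=30, reduction_factor=3)
print(milestones, survivors)
```

With grace_period=5 and reduction_factor=3, trials are compared at epochs 5 and 15, so two thirds of the candidates stop after only 5 of the 30 possible epochs.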

import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, RunConfig

ray.init()

# ── Search space definition ───────────────────────────────────────────
search_space = {
    "lr":           tune.loguniform(1e-5, 1e-2),   # log-uniform between 1e-5 and 1e-2
    "batch_size":   tune.choice([128, 256, 512, 1024]),
    "hidden_dim":   tune.choice([128, 256, 512]),
    "dropout":      tune.uniform(0.1, 0.5),
    "weight_decay": tune.loguniform(1e-6, 1e-3),
}

# ── ASHA: aggressive early stopping for deep learning HPO ────────────
asha_scheduler = ASHAScheduler(
    metric="loss",
    mode="min",
    max_t=30,            # maximum epochs per trial
    grace_period=5,      # minimum epochs before pruning
    reduction_factor=3,  # keep top 1/3 of trials at each rung
)

# ── Optuna Bayesian search for smarter candidate generation ──────────
optuna_search = OptunaSearch(
    metric="loss",
    mode="min",
    points_to_evaluate=[
        # Seed with known-good configurations
        {"lr": 1e-3, "batch_size": 256, "hidden_dim": 256,
         "dropout": 0.3, "weight_decay": 1e-4},
    ],
)

# ── Trainable: wraps the TorchTrainer for Tune ───────────────────────
def trainable(config):
    """Called once per trial. config is sampled from search_space."""
    trainer = TorchTrainer(
        train_loop_per_worker=train_func,   # same train_func as before
        train_loop_config=config,
        scaling_config=ScalingConfig(
            num_workers=2, use_gpu=True,
            resources_per_worker={"CPU": 2, "GPU": 1},
        ),
        run_config=RunConfig(storage_path="s3://ml-runs/ray-tune/"),
    )
    result = trainer.fit()
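    # trainer.fit() blocks until training finishes, so Tune only receives the final
    # metrics returned here and ASHA cannot stop this trial early. For per-epoch
    # pruning, pass a TorchTrainer to tune.Tuner directly and nest the search
    # space under "train_loop_config".
    return result.metrics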

tuner = tune.Tuner(
    trainable,
    param_space=search_space,
    tune_config=tune.TuneConfig(
        num_samples=50,           # total trials to run
        scheduler=asha_scheduler,
        search_alg=optuna_search,
        max_concurrent_trials=8,  # max parallel trials (limited by cluster GPUs)
    ),
    run_config=RunConfig(
        name="hpo_run_v1",
        storage_path="s3://ml-runs/ray-tune/",
    ),
)

results = tuner.fit()

# ── Analyze results ───────────────────────────────────────────────────
best_result = results.get_best_result(metric="loss", mode="min")
print(f"Best config:    {best_result.config}")
print(f"Best loss:      {best_result.metrics['loss']:.4f}")
print(f"Best checkpoint: {best_result.checkpoint}")

# ── Population-Based Training (PBT): mutate hyperparameters mid-training
from ray.tune.schedulers import PopulationBasedTraining

pbt_scheduler = PopulationBasedTraining(
    time_attr="training_iteration",
    metric="loss",
    mode="min",
    perturbation_interval=5,   # perturb every 5 epochs
    hyperparam_mutations={
        "lr":           tune.loguniform(1e-5, 1e-2),
        "dropout":      tune.uniform(0.1, 0.5),
    },
    quantile_fraction=0.25,    # bottom 25% clone a top-25% trial, then perturb its hyperparameters
)

Once the best hyperparameters are found with Ray Tune, tracking the final trained model in MLflow and promoting it through staging to production follows the same MLOps CI/CD champion-challenger promotion patterns — Ray Tune's best checkpoint can be loaded directly as the challenger model for AUC comparison.

Ray Serve — Production Model Serving

Ray Serve is an ML-native serving framework built on Ray. Unlike Flask or FastAPI wrappers around a model, Serve handles request batching, autoscaling, multi-model deployment graphs, and GPU resource allocation natively. Each deployment is a Python class decorated with @serve.deployment that runs as a Ray actor, giving it persistent GPU memory and request queue semantics.

import ray
from ray import serve
from ray.serve.handle import DeploymentHandle
import numpy as np
import torch

ray.init()
serve.start(detached=True, http_options={"host": "0.0.0.0", "port": 8000})

# ── Basic deployment: GPU model with request batching ─────────────────
@serve.deployment(
    num_replicas=2,
    ray_actor_options={"num_gpus": 1, "num_cpus": 2},
    max_ongoing_requests=50,    # queue depth per replica
)
class ClassifierDeployment:
    def __init__(self, model_path: str):
        self.model = torch.load(model_path, map_location="cuda")
        self.model.eval()
        print(f"Model loaded on GPU: {torch.cuda.get_device_name(0)}")

    @serve.batch(max_batch_size=64, batch_wait_timeout_s=0.05)
    async def handle_batch(self, requests: list[dict]) -> list[dict]:
        """Batching decorator aggregates concurrent requests automatically."""
        features = torch.tensor(
            [req["features"] for req in requests],
            dtype=torch.float32, device="cuda"
        )
        with torch.no_grad():
            logits = self.model(features)
            probs  = torch.softmax(logits, dim=-1).cpu().numpy()

        return [
            {"label": int(np.argmax(probs[i])),
             "confidence": float(probs[i].max())}
            for i in range(len(requests))
        ]

    async def __call__(self, request) -> dict:
        body = await request.json()
        return await self.handle_batch(body)

# ── Autoscaling configuration ─────────────────────────────────────────
@serve.deployment(
    autoscaling_config={
        "min_replicas": 1,
        "max_replicas": 8,
        "target_ongoing_requests": 20,   # scale up when queue > 20 per replica
        "upscale_delay_s": 10,
        "downscale_delay_s": 60,
    },
    ray_actor_options={"num_gpus": 1},
)
class AutoscaledClassifier:
    def __init__(self):
        self.model = torch.load("/models/classifier.pt", map_location="cuda")
        self.model.eval()

    async def __call__(self, request) -> dict:
        body = await request.json()
        features = torch.tensor(body["features"], dtype=torch.float32).unsqueeze(0).cuda()
        with torch.no_grad():
            probs = torch.softmax(self.model(features), dim=-1).cpu().numpy()[0]
        return {"label": int(np.argmax(probs)), "confidence": float(probs.max())}

# ── Deployment graph: multi-model pipeline ────────────────────────────
@serve.deployment(num_replicas=2, ray_actor_options={"num_cpus": 2})
class Preprocessor:
    def __init__(self):
        self.mean = np.load("/models/feature_mean.npy")
        self.std  = np.load("/models/feature_std.npy")

    def preprocess(self, raw_features: list[float]) -> list[float]:
        arr = (np.array(raw_features, dtype=np.float32) - self.mean) / (self.std + 1e-8)
        return arr.tolist()

@serve.deployment(num_replicas=1, ray_actor_options={"num_cpus": 1})
class Router:
    """Orchestrates multi-stage inference: preprocess → classify."""
    def __init__(self, preprocessor: DeploymentHandle, classifier: DeploymentHandle):
        self.preprocessor = preprocessor
        self.classifier   = classifier

    async def __call__(self, request) -> dict:
        body    = await request.json()
        cleaned = await self.preprocessor.preprocess.remote(body["raw_features"])
        result  = await self.classifier.handle_batch.remote({"features": cleaned})
        return {"prediction": result, "request_id": body.get("id")}

# ── Bind deployment graph and deploy ─────────────────────────────────
preprocessor = Preprocessor.bind()
classifier   = ClassifierDeployment.bind(model_path="/models/classifier.pt")
router       = Router.bind(preprocessor, classifier)

serve.run(router, route_prefix="/predict")
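
The effect of the autoscaling settings above can be approximated with a small calculation. This is a simplified model, not Ray Serve's actual policy: it ignores upscale_delay_s, downscale_delay_s, and any smoothing, and the function name is hypothetical. It only shows how target_ongoing_requests, min_replicas, and max_replicas interact.

```python
import math


def desired_replicas(total_ongoing_requests, target_ongoing_requests,
                     min_replicas, max_replicas):
    """Replicas needed so each handles about target_ongoing_requests, within bounds."""
    needed = math.ceil(total_ongoing_requests / target_ongoing_requests)
    return max(min_replicas, min(max_replicas, needed))


# Same bounds as the AutoscaledClassifier example: target 20, between 1 and 8 replicas.
for load in (0, 20, 45, 400):
    print(load, desired_replicas(load, 20, 1, 8))
```

Idle traffic stays at min_replicas, moderate load scales proportionally, and a large spike is capped at max_replicas, at which point requests queue instead.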

Note

Ray Serve's @serve.batch decorator is one of its most impactful features for GPU efficiency. Without batching, each request triggers a separate forward pass and most of the GPU's compute capacity sits idle. With batching, Serve holds incoming requests for up to batch_wait_timeout_s seconds (or until max_batch_size requests have arrived), then processes them in a single forward pass. An individual request may wait up to that timeout, but under sustained load the shorter queue more than compensates: on a T4 GPU, batching typically reduces per-request latency by 3–5x and increases throughput 10–20x compared to single-request inference, with the exact gain depending on the model and batch size.
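
The hold-then-flush behaviour described here can be sketched with plain asyncio. This is an illustrative micro-batcher, not the @serve.batch implementation; the class name `MicroBatcher` and its attributes are hypothetical, and error handling is omitted. Callers submit single items, and a background loop groups them until either max_batch_size is reached or batch_wait_timeout_s has elapsed since the first item of the batch.

```python
import asyncio


class MicroBatcher:
    """Collects single requests and hands them to batch_fn as one list."""

    def __init__(self, batch_fn, max_batch_size, batch_wait_timeout_s):
        self.batch_fn = batch_fn
        self.max_batch_size = max_batch_size
        self.timeout = batch_wait_timeout_s
        self.queue = asyncio.Queue()
        self.batch_sizes = []      # recorded so the grouping can be inspected
        self._worker = None

    async def submit(self, item):
        if self._worker is None:   # start the batching loop on first use
            self._worker = asyncio.create_task(self._run())
        future = asyncio.get_running_loop().create_future()
        await self.queue.put((item, future))
        return await future

    async def _run(self):
        loop = asyncio.get_running_loop()
        while True:
            batch = [await self.queue.get()]       # block until a first request arrives
            deadline = loop.time() + self.timeout
            while len(batch) < self.max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self.queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            self.batch_sizes.append(len(batch))
            outputs = self.batch_fn([item for item, _ in batch])  # one "forward pass"
            for (_, future), output in zip(batch, outputs):
                future.set_result(output)

    async def close(self):
        if self._worker is not None:
            self._worker.cancel()
            try:
                await self._worker
            except asyncio.CancelledError:
                pass


async def main():
    batcher = MicroBatcher(lambda xs: [x * 2 for x in xs],
                           max_batch_size=4, batch_wait_timeout_s=0.05)
    results = await asyncio.gather(*(batcher.submit(i) for i in range(10)))
    await batcher.close()
    return results, batcher.batch_sizes


results, batch_sizes = asyncio.run(main())
print(results, batch_sizes)
```

Ten concurrent submissions are served by far fewer calls to batch_fn than requests, which is the same trade the decorator makes: a bounded wait in exchange for fewer, larger forward passes.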

KubeRay — Running Ray on Kubernetes

KubeRay is the Kubernetes operator for Ray. It introduces three CRDs: RayCluster for persistent clusters, RayJob for submitting a single job to an ephemeral cluster, and RayService for Ray Serve deployments with zero-downtime upgrades. This approach parallels how the Spark Operator manages SparkApplication CRDs, with one difference: a RayCluster is persistent (the head node stays up between jobs), whereas a SparkApplication is ephemeral per job. RayJob provides the comparable per-run lifecycle.

# ── Install KubeRay operator ──────────────────────────────────────────
helm repo add kuberay https://ray-project.github.io/kuberay-helm/
helm repo update

helm install kuberay-operator kuberay/kuberay-operator \
  --namespace kuberay-system \
  --create-namespace \
  --version 1.1.0

# ── RayCluster: persistent cluster for interactive workloads ──────────
# raycluster.yaml
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: ray-ml-cluster
  namespace: ml
spec:
  rayVersion: "2.10.0"

  headGroupSpec:
    rayStartParams:
      dashboard-host: "0.0.0.0"
      num-cpus: "0"          # head node: no tasks, only scheduling
    template:
      spec:
        containers:
          - name: ray-head
            image: rayproject/ray-ml:2.10.0-gpu
            resources:
              limits:
                cpu: "4"
                memory: "16Gi"
            ports:
              - containerPort: 6379   # GCS port
              - containerPort: 8265   # Dashboard
              - containerPort: 10001  # Client port

  workerGroupSpecs:
    - groupName: gpu-workers
      replicas: 4
      minReplicas: 1
      maxReplicas: 8         # autoscale between 1–8 GPU workers
      rayStartParams:
        num-gpus: "1"
      template:
        spec:
          containers:
            - name: ray-worker
              image: rayproject/ray-ml:2.10.0-gpu
              resources:
                limits:
                  cpu: "8"
                  memory: "32Gi"
                  nvidia.com/gpu: "1"
              env:
                - name: RAY_worker_register_timeout_seconds
                  value: "120"
          nodeSelector:
            node.kubernetes.io/instance-type: "g4dn.2xlarge"
          tolerations:
            - key: "nvidia.com/gpu"
              operator: "Exists"
              effect: "NoSchedule"

# ── RayJob: ephemeral cluster per job run ─────────────────────────────
# rayjob.yaml
apiVersion: ray.io/v1
kind: RayJob
metadata:
  name: training-run-001
  namespace: ml
spec:
  submissionMode: K8sJobMode
  entrypoint: "python /app/train.py --config /app/config.yaml"
  runtimeEnvYaml: |
    pip:
      - torch==2.2.0
      - ray[train]==2.10.0
      - mlflow==2.13.0
    env_vars:
      MLFLOW_TRACKING_URI: "http://mlflow.ml.svc:5000"
      S3_BUCKET: "ml-data"
  shutdownAfterJobFinishes: true   # auto-cleanup cluster when job completes
  ttlSecondsAfterFinished: 3600    # delete Job resource after 1 hour
  rayClusterSpec:
    rayVersion: "2.10.0"
    headGroupSpec:
      rayStartParams: {num-cpus: "0"}
      template:
        spec:
          containers:
            - name: ray-head
              image: my-registry/ml-training:v1.2.0
              resources:
                limits: {cpu: "4", memory: "16Gi"}
    workerGroupSpecs:
      - groupName: gpu-workers
        replicas: 2
        rayStartParams: {num-gpus: "1"}
        template:
          spec:
            containers:
              - name: ray-worker
                image: my-registry/ml-training:v1.2.0
                resources:
                  limits: {cpu: "8", memory: "32Gi", "nvidia.com/gpu": "1"}

# ── RayService: serving with zero-downtime upgrade ────────────────────
# rayservice.yaml
apiVersion: ray.io/v1
kind: RayService
metadata:
  name: classifier-service
  namespace: ml
spec:
  serviceUnhealthySecondThreshold: 120
  deploymentUnhealthySecondThreshold: 60
  serveConfigV2: |
    applications:
      - name: classifier
        route_prefix: /predict
        import_path: serve_app:router
        runtime_env:
          working_dir: "s3://ml-apps/serve/v1.3.0.zip"
          pip: ["torch==2.2.0", "numpy==1.26.0"]
        deployments:
          - name: Router
            num_replicas: 1
          - name: ClassifierDeployment
            num_replicas: 2
            ray_actor_options:
              num_gpus: 1
  rayClusterConfig:
    headGroupSpec:
      rayStartParams: {num-cpus: "0"}
      template:
        spec:
          containers:
            - name: ray-head
              image: rayproject/ray-ml:2.10.0-gpu
              resources:
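                limits: {cpu: "4", memory: "16Gi"}
    # workerGroupSpecs is omitted here. A GPU worker group like the one in the
    # RayJob example is required for ClassifierDeployment (num_gpus: 1) to be
    # scheduled, because the head node advertises no CPUs and has no GPU.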

Integrating Ray with MLflow and Feature Stores

Ray Train and Ray Tune both integrate with MLflow for experiment tracking. Every trial in a Tune run can log metrics, parameters, and checkpoints to an MLflow tracking server, making HPO results searchable alongside baseline experiments. The MLflow autolog feature works inside Ray workers as long as the tracking URI is set as an environment variable on all worker pods.

import ray
import mlflow
from ray import tune
from ray.air.integrations.mlflow import MLflowLoggerCallback, setup_mlflow

ray.init()

# ── MLflow autologging inside a Ray Train worker ──────────────────────
def train_func_with_mlflow(config: dict):
    import mlflow.pytorch

    # setup_mlflow() initializes MLflow on each worker with the experiment name
    mlflow_config = setup_mlflow(
        config,
        experiment_name="ray-distributed-training",
        tracking_uri=config.get("mlflow_uri", "http://mlflow:5000"),
    )

    mlflow.pytorch.autolog(log_every_n_epoch=1, checkpoint=False)

    with mlflow.start_run(nested=True):
        mlflow.log_params(config)
        # ... training loop as before ...
        for epoch in range(config["epochs"]):
            loss = run_epoch(config)  # placeholder: run_epoch stands for your own training step
            mlflow.log_metric("train_loss", loss, step=epoch)
            ray.train.report({"loss": loss})

# ── MLflow callback for Ray Tune: log all trial metrics automatically ─
tuner = tune.Tuner(
    train_func_with_mlflow,
    param_space={
        "lr":          tune.loguniform(1e-4, 1e-2),
        "hidden_dim":  tune.choice([128, 256]),
        "epochs":      tune.choice([10, 20]),
        "mlflow_uri":  "http://mlflow.ml.svc:5000",
    },
    tune_config=tune.TuneConfig(num_samples=20),
    run_config=ray.train.RunConfig(
        callbacks=[
            MLflowLoggerCallback(
                tracking_uri="http://mlflow.ml.svc:5000",
                experiment_name="ray-tune-hpo",
                save_artifact=True,     # save trial checkpoints as MLflow artifacts
            )
        ]
    ),
)
results = tuner.fit()

# ── Feature loading from Feast inside Ray Data ─────────────────────────
# Ray + Feast: load historical (offline-store) features into a Ray dataset for training
def load_features_from_feast(entity_df_path: str):
    """Fetch historical features from Feast for model training."""
    from feast import FeatureStore
    import pandas as pd

    store = FeatureStore(repo_path="/feast/feature_repo")
    entity_df = pd.read_parquet(entity_df_path)

    training_df = store.get_historical_features(
        entity_df=entity_df,
        features=[
            "user_features:click_rate_7d",
            "user_features:session_count_30d",
            "item_features:popularity_score",
            "item_features:avg_rating",
        ],
    ).to_df()
    return training_df

# Load features via Ray remote task — runs distributed across workers
@ray.remote
def fetch_partition(partition_path: str) -> "pd.DataFrame":
    return load_features_from_feast(partition_path)

partitions = [f"s3://data/entities/partition_{i}.parquet" for i in range(50)]
feature_refs = [fetch_partition.remote(p) for p in partitions]

# Build the dataset straight from the object refs: the 50 DataFrames stay in the
# object store instead of being pulled to the driver and concatenated there
ray_dataset = ray.data.from_pandas_refs(feature_refs)

Note

When integrating Ray with a Feast feature store, call get_historical_features() inside Ray tasks to generate training data, as in the example above. At serving time, create the FeatureStore client once in the Ray Serve deployment's __init__ and call get_online_features() per request to read from the Redis online store. This mirrors the pattern described in the Feast feature retrieval documentation and reduces training-serving skew, because both paths use the same feature definitions.
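The serving side of that pattern might look like the sketch below. It is a minimal illustration, not a reference implementation: the entity keys `user_id` and `item_id`, the `FeatureLookup` class, and the `build_app` helper are hypothetical names, and the feature references are copied from the training example so both paths request the same features.

```python
from typing import Any, Dict, List

# Same feature references as the training-side get_historical_features() call.
FEATURE_REFS: List[str] = [
    "user_features:click_rate_7d",
    "user_features:session_count_30d",
    "item_features:popularity_score",
    "item_features:avg_rating",
]


def build_entity_rows(user_id: int, item_id: int) -> List[Dict[str, Any]]:
    """Entity keys are hypothetical; use the join keys defined in your feature repo."""
    return [{"user_id": user_id, "item_id": item_id}]


def build_app(repo_path: str = "/feast/feature_repo"):
    """Build the Serve application. Ray and Feast are imported lazily here."""
    from feast import FeatureStore
    from ray import serve

    @serve.deployment
    class FeatureLookup:
        def __init__(self):
            # Create the Feast client once per replica, not once per request.
            self.store = FeatureStore(repo_path=repo_path)

        async def __call__(self, request) -> Dict[str, Any]:
            body = await request.json()
            rows = build_entity_rows(body["user_id"], body["item_id"])
            # The online lookup runs per request against the online store.
            # This call is blocking; acceptable for a sketch.
            return self.store.get_online_features(
                features=FEATURE_REFS, entity_rows=rows
            ).to_dict()

    return FeatureLookup.bind()
```

Passing the result of `build_app()` to `serve.run()` would start the deployment. In a real service the returned features would be fed to the model rather than returned to the caller.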

Ray Production Checklist

1

Pin Ray versions across all components: the Ray cluster image, Python SDK, and all AI library packages (ray[train], ray[tune], ray[serve]) must match exactly — version skew between head and workers causes silent serialization failures

2

Set num-cpus: '0' in the head group's rayStartParams in KubeRay to prevent the scheduler from placing computation tasks there: the head node runs GCS and the dashboard, and CPU contention on it degrades scheduling latency cluster-wide

3

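Set max_ongoing_requests per Ray Serve replica explicitly to bound memory usage: Serve always applies a limit, but the default has changed across Ray releases, and a value set too high for the model's inference throughput lets in-flight requests pile up on a replica until it OOMs during upstream traffic spikes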

4

Use RayJob (not RayCluster + manual submission) for batch training runs: with shutdownAfterJobFinishes: true, RayJob creates an ephemeral cluster, submits the job, waits for completion, then deletes the cluster and its pods, eliminating idle GPU costs between training runs

5

Enable Ray Serve autoscaling with a conservative downscale_delay_s (120s or more) to prevent replica thrashing on bursty traffic: every replica that is removed and then started again has to reload the model into GPU memory, which can take 30–90 seconds per replica

6

Store Ray Train checkpoints on S3 or GCS (not local disk) by setting storage_path in RunConfig: checkpoints on local disk are lost when the pod that holds them is replaced, which makes fault-tolerant resumption impossible

7

Set RAY_worker_register_timeout_seconds to 120 in worker pod env — the default 30s is too short for nodes pulling large container images (GPU containers are typically 5–15 GB) and causes spurious worker registration failures

8

Configure Prometheus scraping for Ray metrics: the head pod exposes /metrics on port 8080, and each worker pod exposes its own metrics endpoint. Worker pods are not fronted by a Service, so scrape them dynamically with a PodMonitor CRD and use a ServiceMonitor for the head service

9

Test fault tolerance explicitly: kill a worker pod mid-training and verify that Ray Train resumes from the last checkpoint on the replacement worker — do this in staging before relying on it in production spot instance pools

10

Use RayService for Ray Serve production deployments instead of serve.run() in a RayCluster — RayService handles zero-downtime rolling upgrades by deploying the new serve config before draining traffic from the old deployment
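Items 3 and 5 can be expressed together in one deployment definition. The sketch below is illustrative only: the numeric values are placeholders to be tuned against measured latency and memory, and `SERVE_OPTIONS`, `check_serve_options`, `build_app`, and the `Classifier` body are hypothetical names rather than code from earlier in this guide.

```python
from typing import Any, Dict

# Placeholder values; tune them against measured latency and GPU memory.
SERVE_OPTIONS: Dict[str, Any] = {
    "max_ongoing_requests": 8,            # hard cap on in-flight requests per replica
    "ray_actor_options": {"num_gpus": 1},
    "autoscaling_config": {
        "min_replicas": 1,
        "max_replicas": 4,
        "target_ongoing_requests": 4,     # autoscaler target per replica
        "downscale_delay_s": 120,         # wait before removing idle replicas
    },
}


def check_serve_options(opts: Dict[str, Any]) -> None:
    """Reject configurations that contradict the checklist before deploying."""
    auto = opts["autoscaling_config"]
    if auto["min_replicas"] > auto["max_replicas"]:
        raise ValueError("min_replicas must not exceed max_replicas")
    if auto["target_ongoing_requests"] > opts["max_ongoing_requests"]:
        raise ValueError("target_ongoing_requests must not exceed max_ongoing_requests")
    if auto["downscale_delay_s"] < 120:
        raise ValueError("downscale_delay_s below 120s risks replica thrashing")


def build_app():
    """Build the Serve application. Ray is imported lazily here."""
    from ray import serve

    check_serve_options(SERVE_OPTIONS)

    @serve.deployment(**SERVE_OPTIONS)
    class Classifier:
        def __init__(self):
            self.model = None  # placeholder: load the real model here

        async def __call__(self, request) -> Dict[str, str]:
            return {"status": "ok"}  # placeholder response

    return Classifier.bind()
```

Keeping the options in a plain dictionary lets a CI step run the check without starting a cluster. The same values can also be set in the Serve config that a RayService applies.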

For related patterns: fine-tuning smaller open-source models on a single GPU is covered in Fine-Tuning Open Models with LoRA and QLoRA — Ray Train scales those same techniques to multi-node clusters when your dataset or model exceeds single-GPU capacity.

Hitting single-machine limits on ML training or hyperparameter search, managing separate Spark, training, and serving clusters, or looking to consolidate your ML infrastructure on Kubernetes?

We design and implement distributed ML platforms with Ray — from Ray Core actor design and Ray Data preprocessing pipeline configuration to Ray Train multi-GPU training setup with fault-tolerant S3 checkpointing, Ray Tune HPO with ASHA and Optuna search, MLflow integration for experiment tracking across all trials, Ray Serve deployment graph design with request batching and autoscaling, KubeRay operator installation with RayCluster and RayJob CRD configuration for ephemeral training jobs, RayService rolling upgrade configuration, GPU node pool setup with spot instance toleration and Karpenter autoscaler integration, and Prometheus metrics scraping for cluster observability. Let’s talk.
