From 2e83a7424a188e846ac4f8ed580828105bdb9675 Mon Sep 17 00:00:00 2001 From: Bowen Date: Wed, 7 Aug 2024 22:27:56 -0700 Subject: [PATCH] Deployed c6668d9 with MkDocs version: 1.6.0 --- index.html | 2 +- mnist/index.html | 35 ++--------------------------------- search/search_index.json | 2 +- 3 files changed, 4 insertions(+), 35 deletions(-) diff --git a/index.html b/index.html index 793014c..7e337b1 100644 --- a/index.html +++ b/index.html @@ -509,7 +509,7 @@

RedCoast

  • Loss function: executes the model and computes the loss (e.g., cross-entropy);
  • Predict function: runs the model and delivers its outputs (e.g., beam search).
  • -

    Redco automates all the remaining pipeline execution, such as data parallelism, multi-host-related processing, distributed checkpointing, randomness control, logging, etc.

    +

    Redco automates all the remaining pipeline execution, such as data and model parallelism, multi-host-related processing, distributed checkpointing, randomness control, logging, etc.

    diff --git a/mnist/index.html b/mnist/index.html index f4038ee..35a45f5 100644 --- a/mnist/index.html +++ b/mnist/index.html @@ -76,7 +76,7 @@
    - + Skip to content @@ -378,27 +378,12 @@ @@ -544,27 +529,12 @@ @@ -587,8 +557,7 @@

    MNIST Example

    -

    MNIST

    -

    A trivial MNIST example with RedCoast. Runnable by +

    This is a trivial MNIST example with RedCoast. Runnable by

    python main.py
     

    To simulate multiple devices in cpu-only envs, diff --git a/search/search_index.json b/search/search_index.json index 0f30c97..43a3381 100644 --- a/search/search_index.json +++ b/search/search_index.json @@ -1 +1 @@ -{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"RedCoast","text":"

    Red Coast (redco) is a lightweight and user-friendly tool that automates distributed training and inference for large models while simplifying ML pipeline development, without requiring MLSys expertise from users.

    RedCoast supports Large Models + Complex Algorithms, in a lightweight and user-friendly manner:

    With RedCoast, only three functions are needed to define an ML pipeline:

    Redco automates all the remaining pipeline execution, such as data parallelism, multi-host-related processing, distributed checkpointing, randomness control, logging, etc.
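    As a rough sketch (following the signatures used in the MNIST example later on this page), the user only supplies functions like these; the bodies below are placeholders:

    # collate function: raw examples -> model inputs (numpy arrays)
    def collate_fn(examples):
        ...

    # loss function: model inputs -> scalar loss
    def loss_fn(train_rng, state, params, batch, is_training):
        ...

    # predict function: model inputs -> model outputs
    def pred_fn(pred_rng, params, batch):
        ...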

    "},{"location":"deployer/","title":"Deployer","text":""},{"location":"deployer/#redco.deployers.deployer.Deployer","title":"Deployer","text":"

    Handles low-level operations to support Trainer and Predictor, e.g., automatic data/model parallelism, distributed checkpointing, data processing, logging, randomness control, etc.

    Attributes:

    Name Type Description workdir str

    Working directory for saving checkpoints and logs.

    mesh `jax.sharding.Mesh`

    Mesh used for model sharding.

    Source code in redco/deployers/deployer.py
    class Deployer:\n    \"\"\" Handles low-level operations to support Trainer and Predictor,\n        e.g., automatic data/model parallelism, distributed checkpointing,\n        data processing, logging, randomness controlling, etc.\n\n    Attributes:\n        workdir (str): Working directory for saving checkpoints and logs.\n        mesh (`jax.sharding.Mesh`): Mesh used for model sharding.\n    \"\"\"\n    def __init__(self,\n                 jax_seed,\n                 n_model_shards=1,\n                 verbose=True,\n                 workdir=None,\n                 n_processes=None,\n                 host0_address=None,\n                 host0_port=None,\n                 process_id=None,\n                 n_local_devices=None,\n                 run_tensorboard=False,\n                 wandb_init_kwargs=None):\n        \"\"\" Initializes a Deployer.\n\n        Args:\n            jax_seed (`jax.numpy.Array`): Seed for random number generation.\n            n_model_shards (int): Number of shards for running large model.\n            verbose (bool): Whether to enable verbose logging.\n            workdir (str):  Directory for saving logs and checkpoints.\n            n_processes (int):  For multi-host, number of processes/nodes.\n            host0_address (str):  For multi-host, address of the host0.\n            host0_port (int): For multi-host, port of the host0.\n            process_id (int): For multi-host, index of the current process.\n            n_local_devices (int): For multi-host, number of local devices.\n            run_tensorboard (bool):  Whether to enable TensorBoard logging.\n            wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n        \"\"\"\n        if n_processes is None:\n            if 'SLURM_JOB_NUM_NODES' in os.environ:\n                n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n                process_id = int(os.environ['SLURM_NODEID'])\n            else:\n                n_processes = 1\n\n        if n_processes > 1:\n            local_device_ids = None if n_local_devices is None \\\n                else list(range(n_local_devices))\n\n            if host0_port is None:\n                host0_port = DEFAULT_HOST0_PORT\n\n            jax.distributed.initialize(\n                coordinator_address=f'{host0_address}:{host0_port}',\n                num_processes=n_processes,\n                process_id=process_id,\n                local_device_ids=local_device_ids)\n\n        if workdir is not None:\n            os.makedirs(workdir, exist_ok=True)\n\n        self._verbose = verbose\n        self._workdir = workdir\n        self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n        if wandb_init_kwargs is not None and jax.process_index() == 0:\n            import wandb\n            wandb.init(**wandb_init_kwargs)\n            self._wandb_log_fn = wandb.log\n        else:\n            self._wandb_log_fn = None\n\n        if run_tensorboard and jax.process_index() == 0:\n            from flax.metrics import tensorboard\n            self._summary_writer = tensorboard.SummaryWriter(workdir)\n        else:\n            self._summary_writer = None\n\n        self.log_info(\n            f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n        self._rng = jax.random.PRNGKey(seed=jax_seed)\n        self._mesh = get_mesh(n_model_shards=n_model_shards)\n        self._checkpointer = ocp.PyTreeCheckpointer()\n\n    def get_local_global_micro_batch_size(self, per_device_batch_size):\n        \"\"\"Get 
local/global micro batch sizes based on per-device batch size.\"\"\"\n        if self._mesh is None:\n            local_micro_batch_size = \\\n                per_device_batch_size * jax.local_device_count()\n            global_micro_batch_size = \\\n                local_micro_batch_size * jax.process_count()\n        else:\n            global_micro_batch_size = local_micro_batch_size = \\\n                per_device_batch_size * self._mesh.shape['dp']\n\n        return local_micro_batch_size, global_micro_batch_size\n\n    def get_accumulate_grad_batches(\n            self, global_batch_size, per_device_batch_size):\n        \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n        _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n        assert global_batch_size % global_micro_batch_size == 0\n        accumulate_grad_batches = global_batch_size // global_micro_batch_size\n\n        return accumulate_grad_batches\n\n    def get_model_input_batches(self,\n                                examples,\n                                per_device_batch_size,\n                                collate_fn,\n                                shuffle,\n                                shuffle_rng,\n                                desc,\n                                is_train=False,\n                                accumulate_grad_batches=None):\n        \"\"\"Prepares model input batches from examples.\n\n        Args:\n            examples (list): List of input examples.\n            per_device_batch_size (int): Batch size per device.\n            collate_fn (Callable): Function to collate the examples.\n            shuffle (bool): Whether to shuffle the examples.\n            shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n            desc (str): Description in the progress bar.\n            is_train (bool): Whether the data is for training.\n            accumulate_grad_batches (int): gradient accumulation batches.\n\n        Returns:\n            (generator): A python generator of batched model inputs.\n        \"\"\"\n        local_micro_batch_size, global_micro_batch_size = \\\n            self.get_local_global_micro_batch_size(\n                per_device_batch_size=per_device_batch_size)\n\n        examples = get_host_examples(\n            examples=examples,\n            global_micro_batch_size=global_micro_batch_size,\n            shuffle=shuffle,\n            shuffle_rng=shuffle_rng,\n            mesh=self._mesh)\n\n        if not is_train:\n            desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n        elif accumulate_grad_batches is None:\n            desc = \\\n                f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n        else:\n            desc = (f'{desc} ('\n                    f'global_micro_batch_size = {global_micro_batch_size}, '\n                    f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n        return get_data_batches(\n            examples=examples,\n            batch_size=local_micro_batch_size,\n            collate_fn=collate_fn,\n            mesh=self._mesh,\n            desc=desc,\n            verbose=self._verbose)\n\n    def get_lr_schedule_fn(self,\n                           train_size,\n                           per_device_batch_size,\n                           n_epochs,\n                           learning_rate,\n                           schedule_type='linear',\n                     
      warmup_ratio=0.,\n                           warmup_steps=None,\n                           init_learning_rate=0.,\n                           end_learning_rate=0.):\n        \"\"\"Creates a learning rate schedule function.\n\n        Args:\n            train_size (int): Number of training examples per epoch.\n            per_device_batch_size (int): Batch size per device.\n            n_epochs (int): Number of epochs.\n            learning_rate (float): Peak learning rate.\n            schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n            warmup_ratio (float): Ratio of lr warmup.\n            warmup_steps (int): Number of warmup steps.\n            init_learning_rate (float): Initial learning rate before warmup.\n            end_learning_rate (float): End learning rate for the schedule.\n\n        Returns:\n            (Callable): A lr schedule function, step -> learning rate.\n        \"\"\"\n        _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n        total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n        if warmup_steps is None:\n            warmup_steps = int(total_train_steps * warmup_ratio)\n\n        return get_lr_schedule_fn(\n            schedule_type=schedule_type,\n            total_train_steps=total_train_steps,\n            warmup_steps=warmup_steps,\n            init_learning_rate=init_learning_rate,\n            learning_rate=learning_rate,\n            end_learning_rate=end_learning_rate)\n\n    def get_sharding_rules(self, params_shape_or_params):\n        \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n        if self._mesh is None:\n            return None\n        else:\n            sharding_rules = get_sharding_rules(\n                params_shape_or_params=params_shape_or_params,\n                n_model_shards=self._mesh.shape['mp'])\n            return sharding_rules\n\n    def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n        \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n        return get_params_spec(\n            params_shape_or_params=params_shape_or_params,\n            params_sharding_rules=params_sharding_rules)\n\n    def get_opt_state_spec(\n            self, params_shape_or_params, params_spec, optimizer):\n        \"\"\"Get optimizer state specs\"\"\"\n        return get_opt_state_spec(\n            params_shape_or_params=params_shape_or_params,\n            params_spec=params_spec,\n            optimizer=optimizer)\n\n    def shard_params(self, params, params_spec, desc='params'):\n        \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n        self.log_info(info=f'Sharding {desc} ...')\n        return shard_params(\n            mesh=self._mesh, params=params, params_spec=params_spec)\n\n    def run_model_step(self, step_fn, input_args):\n        \"\"\"Executes a model step function with the provided inputs.\"\"\"\n        if self._mesh is None:\n            return step_fn(*input_args)\n        else:\n            with self._mesh:\n                return step_fn(*input_args)\n\n    def gen_rng(self):\n        \"\"\"Get a new random number generator key and update the random state.\"\"\"\n        self._rng, new_rng = jax.random.split(self._rng)\n        return new_rng\n\n    def log_info(self, info, title=None, step=None):\n        \"\"\"Logs a messages\"\"\"\n        log_info(\n            info=info,\n            
title=title,\n            logger=self._logger,\n            summary_writer=self._summary_writer,\n            step=step)\n\n    def log_metrics(self, metrics, step):\n        \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n        if self._summary_writer is not None:\n            for metric_name, value in metrics.items():\n                self._summary_writer.scalar(metric_name, value, step=step)\n\n        if self._wandb_log_fn is not None:\n            self._wandb_log_fn(metrics, step)\n\n    def save_outputs(self, outputs, desc, step):\n        \"\"\"Saves model outputs to workdir.\"\"\"\n        if self._workdir is not None and jax.process_index() == 0:\n            save_outputs(\n                workdir=self._workdir,\n                outputs=outputs,\n                desc=desc,\n                step=step,\n                logger=self._logger,\n                summary_writer=self._summary_writer)\n\n    def save_ckpt(\n            self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n        \"\"\"Saves a checkpoint to the specified directory.\n\n        Args:\n            ckpt_dir (str): Directory to save the checkpoint.\n            params (dict): Model parameters.\n            opt_state (dict): Optimizer state.\n            float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n            **kwargs (dict): Additional information to be saved into\n                info.json, e.g., current training step, epoch index, etc.\n        \"\"\"\n        ckpt_dir = os.path.abspath(ckpt_dir)\n        self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n        save_ckpt(\n            ckpt_dir=ckpt_dir,\n            checkpointer=self._checkpointer,\n            params=params,\n            opt_state=opt_state,\n            float_dtype=float_dtype,\n            rng=self._rng,\n            **kwargs)\n        self.log_info(f'Ckpt saved into {ckpt_dir}')\n\n    def load_params_shape(self, ckpt_dir):\n        \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n        return load_params_shape(ckpt_dir=ckpt_dir)\n\n    def load_ckpt(self,\n                  ckpt_dir,\n                  params_sharding_rules=None,\n                  optimizer=None,\n                  float_dtype=None,\n                  load_params=True,\n                  load_opt_state=True,\n                  update_rng=False):\n        \"\"\"Loads a checkpoint from the specified directory.\n\n        Args:\n            ckpt_dir (str): Directory of the checkpoint.\n            params_sharding_rules (list[tuple]): Sharding rules for parameters.\n            optimizer (`optax.optimizer`): Optimizer for loading opt_state.\n            float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n            load_params (bool): Whether to load the parameters.\n            load_opt_state (bool): Whether to load the optimizer state.\n            update_rng (bool): if updating the random state of the deployer.\n\n        Returns:\n            (tuple): A tuple with the loaded checkpoint (in a dict with\n                `\"params\"` and `\"opt_state\"`) and additional information (in a\n                dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n        \"\"\"\n        ckpt_dir = os.path.abspath(ckpt_dir)\n        self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n        params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n        specs = {}\n        if self._mesh is not None:\n            if params_sharding_rules is None:\n            
    params_sharding_rules = self.get_sharding_rules(\n                    params_shape_or_params=params_shape)\n\n            specs['params'] = self.get_params_spec(\n                params_shape_or_params=params_shape,\n                params_sharding_rules=params_sharding_rules)\n            if optimizer is not None:\n                specs['opt_state'] = self.get_opt_state_spec(\n                    params_shape_or_params=params_shape,\n                    params_spec=specs['params'],\n                    optimizer=optimizer)\n\n        ckpt, info = load_ckpt(\n            ckpt_dir=ckpt_dir,\n            checkpointer=self._checkpointer,\n            params_shape_or_params=params_shape,\n            optimizer=optimizer,\n            float_dtype=float_dtype,\n            mesh=self._mesh,\n            specs=specs,\n            load_params=load_params,\n            load_opt_state=load_opt_state)\n\n        for key, value in info.items():\n            if not update_rng and key == 'rng':\n                continue\n            self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n        if update_rng:\n            self._rng = info['rng']\n            self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n        return ckpt, info\n\n    def load_last_ckpt(self,\n                       optimizer=None,\n                       params_sharding_rules=None,\n                       float_dtype=None,\n                       load_params=True,\n                       load_opt_state=True,\n                       update_rng=True):\n        \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n        See load_ckpt() for the explanation of arguments.\n        \"\"\"\n        try:\n            last_ckpt_name = open(\n                f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n        except:\n            self.log_info(\n                f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n                f'no ckpt loaded.')\n            return None, None\n\n        return self.load_ckpt(\n            ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n            optimizer=optimizer,\n            float_dtype=float_dtype,\n            params_sharding_rules=params_sharding_rules,\n            load_params=load_params,\n            load_opt_state=load_opt_state,\n            update_rng=update_rng)\n\n    @property\n    def mesh(self):\n        \"\"\"Returns the mesh for model sharding\"\"\"\n        return self._mesh\n\n    @property\n    def workdir(self):\n        \"\"\"Returns the work directory.\"\"\"\n        return self._workdir\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.mesh","title":"mesh property","text":"

    Returns the mesh for model sharding

    "},{"location":"deployer/#redco.deployers.deployer.Deployer.workdir","title":"workdir property","text":"

    Returns the work directory.

    "},{"location":"deployer/#redco.deployers.deployer.Deployer.__init__","title":"__init__(jax_seed, n_model_shards=1, verbose=True, workdir=None, n_processes=None, host0_address=None, host0_port=None, process_id=None, n_local_devices=None, run_tensorboard=False, wandb_init_kwargs=None)","text":"

    Initializes a Deployer.

    Parameters:

    Name Type Description Default jax_seed `jax.numpy.Array`

    Seed for random number generation.

    required n_model_shards int

    Number of shards for running a large model.

    1 verbose bool

    Whether to enable verbose logging.

    True workdir str

    Directory for saving logs and checkpoints.

    None n_processes int

    For multi-host, number of processes/nodes.

    None host0_address str

    For multi-host, address of the host0.

    None host0_port int

    For multi-host, port of the host0.

    None process_id int

    For multi-host, index of the current process.

    None n_local_devices int

    For multi-host, number of local devices.

    None run_tensorboard bool

    Whether to enable TensorBoard logging.

    False wandb_init_kwargs dict

    wandb.init arguments if using wandb.

    None Source code in redco/deployers/deployer.py
    def __init__(self,\n             jax_seed,\n             n_model_shards=1,\n             verbose=True,\n             workdir=None,\n             n_processes=None,\n             host0_address=None,\n             host0_port=None,\n             process_id=None,\n             n_local_devices=None,\n             run_tensorboard=False,\n             wandb_init_kwargs=None):\n    \"\"\" Initializes a Deployer.\n\n    Args:\n        jax_seed (`jax.numpy.Array`): Seed for random number generation.\n        n_model_shards (int): Number of shards for running large model.\n        verbose (bool): Whether to enable verbose logging.\n        workdir (str):  Directory for saving logs and checkpoints.\n        n_processes (int):  For multi-host, number of processes/nodes.\n        host0_address (str):  For multi-host, address of the host0.\n        host0_port (int): For multi-host, port of the host0.\n        process_id (int): For multi-host, index of the current process.\n        n_local_devices (int): For multi-host, number of local devices.\n        run_tensorboard (bool):  Whether to enable TensorBoard logging.\n        wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n    \"\"\"\n    if n_processes is None:\n        if 'SLURM_JOB_NUM_NODES' in os.environ:\n            n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n            process_id = int(os.environ['SLURM_NODEID'])\n        else:\n            n_processes = 1\n\n    if n_processes > 1:\n        local_device_ids = None if n_local_devices is None \\\n            else list(range(n_local_devices))\n\n        if host0_port is None:\n            host0_port = DEFAULT_HOST0_PORT\n\n        jax.distributed.initialize(\n            coordinator_address=f'{host0_address}:{host0_port}',\n            num_processes=n_processes,\n            process_id=process_id,\n            local_device_ids=local_device_ids)\n\n    if workdir is not None:\n        os.makedirs(workdir, exist_ok=True)\n\n    self._verbose = verbose\n    self._workdir = workdir\n    self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n    if wandb_init_kwargs is not None and jax.process_index() == 0:\n        import wandb\n        wandb.init(**wandb_init_kwargs)\n        self._wandb_log_fn = wandb.log\n    else:\n        self._wandb_log_fn = None\n\n    if run_tensorboard and jax.process_index() == 0:\n        from flax.metrics import tensorboard\n        self._summary_writer = tensorboard.SummaryWriter(workdir)\n    else:\n        self._summary_writer = None\n\n    self.log_info(\n        f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n    self._rng = jax.random.PRNGKey(seed=jax_seed)\n    self._mesh = get_mesh(n_model_shards=n_model_shards)\n    self._checkpointer = ocp.PyTreeCheckpointer()\n
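    A minimal single-host construction, following the MNIST example on this site (the seed and workdir are illustrative; the multi-host arguments are only needed when launching several processes):

    from redco import Deployer

    deployer = Deployer(jax_seed=42, workdir='./workdir')
    # multi-host sketch (all values are placeholders):
    # deployer = Deployer(jax_seed=42, n_processes=2, host0_address='10.0.0.1',
    #                     host0_port=8888, process_id=0, n_local_devices=8)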
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.gen_rng","title":"gen_rng()","text":"

    Get a new random number generator key and update the random state.

    Source code in redco/deployers/deployer.py
    def gen_rng(self):\n    \"\"\"Get a new random number generator key and update the random state.\"\"\"\n    self._rng, new_rng = jax.random.split(self._rng)\n    return new_rng\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_accumulate_grad_batches","title":"get_accumulate_grad_batches(global_batch_size, per_device_batch_size)","text":"

    Calculates the number of gradient accumulation batches.

    Source code in redco/deployers/deployer.py
    def get_accumulate_grad_batches(\n        self, global_batch_size, per_device_batch_size):\n    \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n    _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n        per_device_batch_size=per_device_batch_size)\n    assert global_batch_size % global_micro_batch_size == 0\n    accumulate_grad_batches = global_batch_size // global_micro_batch_size\n\n    return accumulate_grad_batches\n
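    A hypothetical example of the arithmetic: with 8 data-parallel devices and per_device_batch_size=4, the global micro batch size is 32, so a global batch size of 128 yields 4 accumulation steps:

    accumulate_grad_batches = deployer.get_accumulate_grad_batches(
        global_batch_size=128, per_device_batch_size=4)  # 128 // 32 = 4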
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_local_global_micro_batch_size","title":"get_local_global_micro_batch_size(per_device_batch_size)","text":"

    Get local/global micro batch sizes based on per-device batch size.

    Source code in redco/deployers/deployer.py
    def get_local_global_micro_batch_size(self, per_device_batch_size):\n    \"\"\"Get local/global micro batch sizes based on per-device batch size.\"\"\"\n    if self._mesh is None:\n        local_micro_batch_size = \\\n            per_device_batch_size * jax.local_device_count()\n        global_micro_batch_size = \\\n            local_micro_batch_size * jax.process_count()\n    else:\n        global_micro_batch_size = local_micro_batch_size = \\\n            per_device_batch_size * self._mesh.shape['dp']\n\n    return local_micro_batch_size, global_micro_batch_size\n
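    A sketch of the arithmetic, assuming 2 hosts with 4 local devices each and no model sharding:

    local_bs, global_bs = deployer.get_local_global_micro_batch_size(
        per_device_batch_size=8)
    # local_bs  = 8 * 4 = 32   (per-device batch size * local device count)
    # global_bs = 32 * 2 = 64  (local batch size * process count)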
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_lr_schedule_fn","title":"get_lr_schedule_fn(train_size, per_device_batch_size, n_epochs, learning_rate, schedule_type='linear', warmup_ratio=0.0, warmup_steps=None, init_learning_rate=0.0, end_learning_rate=0.0)","text":"

    Creates a learning rate schedule function.

    Parameters:

    Name Type Description Default train_size int

    Number of training examples per epoch.

    required per_device_batch_size int

    Batch size per device.

    required n_epochs int

    Number of epochs.

    required learning_rate float

    Peak learning rate.

    required schedule_type str

    Type of lr schedule, \"linear\" or \"cosine\".

    'linear' warmup_ratio float

    Ratio of lr warmup.

    0.0 warmup_steps int

    Number of warmup steps.

    None init_learning_rate float

    Initial learning rate before warmup.

    0.0 end_learning_rate float

    End learning rate for the schedule.

    0.0

    Returns:

    Type Description Callable

    A lr schedule function, step -> learning rate.

    Source code in redco/deployers/deployer.py
    def get_lr_schedule_fn(self,\n                       train_size,\n                       per_device_batch_size,\n                       n_epochs,\n                       learning_rate,\n                       schedule_type='linear',\n                       warmup_ratio=0.,\n                       warmup_steps=None,\n                       init_learning_rate=0.,\n                       end_learning_rate=0.):\n    \"\"\"Creates a learning rate schedule function.\n\n    Args:\n        train_size (int): Number of training examples per epoch.\n        per_device_batch_size (int): Batch size per device.\n        n_epochs (int): Number of epochs.\n        learning_rate (float): Peak learning rate.\n        schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n        warmup_ratio (float): Ratio of lr warmup.\n        warmup_steps (int): Number of warmup steps.\n        init_learning_rate (float): Initial learning rate before warmup.\n        end_learning_rate (float): End learning rate for the schedule.\n\n    Returns:\n        (Callable): A lr schedule function, step -> learning rate.\n    \"\"\"\n    _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n        per_device_batch_size=per_device_batch_size)\n    total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n    if warmup_steps is None:\n        warmup_steps = int(total_train_steps * warmup_ratio)\n\n    return get_lr_schedule_fn(\n        schedule_type=schedule_type,\n        total_train_steps=total_train_steps,\n        warmup_steps=warmup_steps,\n        init_learning_rate=init_learning_rate,\n        learning_rate=learning_rate,\n        end_learning_rate=end_learning_rate)\n
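    A usage sketch with MNIST-sized numbers (all values are illustrative); optax optimizers accept a schedule function in place of a constant learning rate:

    import optax

    lr_schedule_fn = deployer.get_lr_schedule_fn(
        train_size=60000, per_device_batch_size=64, n_epochs=2,
        learning_rate=1e-3, schedule_type='linear', warmup_ratio=0.1)
    optimizer = optax.adamw(learning_rate=lr_schedule_fn)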
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_model_input_batches","title":"get_model_input_batches(examples, per_device_batch_size, collate_fn, shuffle, shuffle_rng, desc, is_train=False, accumulate_grad_batches=None)","text":"

    Prepares model input batches from examples.

    Parameters:

    Name Type Description Default examples list

    List of input examples.

    required per_device_batch_size int

    Batch size per device.

    required collate_fn Callable

    Function to collate the examples.

    required shuffle bool

    Whether to shuffle the examples.

    required shuffle_rng `jax.numpy.Array`

    RNG for randomness of shuffling.

    required desc str

    Description in the progress bar.

    required is_train bool

    Whether the data is for training.

    False accumulate_grad_batches int

    Number of gradient accumulation batches.

    None

    Returns:

    Type Description generator

    A python generator of batched model inputs.

    Source code in redco/deployers/deployer.py
    def get_model_input_batches(self,\n                            examples,\n                            per_device_batch_size,\n                            collate_fn,\n                            shuffle,\n                            shuffle_rng,\n                            desc,\n                            is_train=False,\n                            accumulate_grad_batches=None):\n    \"\"\"Prepares model input batches from examples.\n\n    Args:\n        examples (list): List of input examples.\n        per_device_batch_size (int): Batch size per device.\n        collate_fn (Callable): Function to collate the examples.\n        shuffle (bool): Whether to shuffle the examples.\n        shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n        desc (str): Description in the progress bar.\n        is_train (bool): Whether the data is for training.\n        accumulate_grad_batches (int): gradient accumulation batches.\n\n    Returns:\n        (generator): A python generator of batched model inputs.\n    \"\"\"\n    local_micro_batch_size, global_micro_batch_size = \\\n        self.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n\n    examples = get_host_examples(\n        examples=examples,\n        global_micro_batch_size=global_micro_batch_size,\n        shuffle=shuffle,\n        shuffle_rng=shuffle_rng,\n        mesh=self._mesh)\n\n    if not is_train:\n        desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n    elif accumulate_grad_batches is None:\n        desc = \\\n            f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n    else:\n        desc = (f'{desc} ('\n                f'global_micro_batch_size = {global_micro_batch_size}, '\n                f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n    return get_data_batches(\n        examples=examples,\n        batch_size=local_micro_batch_size,\n        collate_fn=collate_fn,\n        mesh=self._mesh,\n        desc=desc,\n        verbose=self._verbose)\n
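    A usage sketch, assuming examples and collate_fn as in the MNIST example; the description string is arbitrary:

    data_batches = deployer.get_model_input_batches(
        examples=dataset['train'], per_device_batch_size=64,
        collate_fn=collate_fn, shuffle=True, shuffle_rng=deployer.gen_rng(),
        desc='Epoch 0', is_train=True)
    for batch in data_batches:
        ...  # each batch is already placed on the local devices / mesh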
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_opt_state_spec","title":"get_opt_state_spec(params_shape_or_params, params_spec, optimizer)","text":"

    Gets optimizer state specs.

    Source code in redco/deployers/deployer.py
    def get_opt_state_spec(\n        self, params_shape_or_params, params_spec, optimizer):\n    \"\"\"Get optimizer state specs\"\"\"\n    return get_opt_state_spec(\n        params_shape_or_params=params_shape_or_params,\n        params_spec=params_spec,\n        optimizer=optimizer)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_params_spec","title":"get_params_spec(params_shape_or_params, params_sharding_rules)","text":"

    Generates parameter specs based on sharding rules.

    Source code in redco/deployers/deployer.py
    def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n    \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n    return get_params_spec(\n        params_shape_or_params=params_shape_or_params,\n        params_sharding_rules=params_sharding_rules)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_sharding_rules","title":"get_sharding_rules(params_shape_or_params)","text":"

    Get sharding rules based on the parameter shapes.

    Source code in redco/deployers/deployer.py
    def get_sharding_rules(self, params_shape_or_params):\n    \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n    if self._mesh is None:\n        return None\n    else:\n        sharding_rules = get_sharding_rules(\n            params_shape_or_params=params_shape_or_params,\n            n_model_shards=self._mesh.shape['mp'])\n        return sharding_rules\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.load_ckpt","title":"load_ckpt(ckpt_dir, params_sharding_rules=None, optimizer=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=False)","text":"

    Loads a checkpoint from the specified directory.

    Parameters:

    Name Type Description Default ckpt_dir str

    Directory of the checkpoint.

    required params_sharding_rules list[tuple]

    Sharding rules for parameters.

    None optimizer `optax.optimizer`

    Optimizer for loading opt_state.

    None float_dtype `jax.numpy.dtype`

    Dtype for floating point numbers.

    None load_params bool

    Whether to load the parameters.

    True load_opt_state bool

    Whether to load the optimizer state.

    True update_rng bool

    Whether to update the random state of the deployer.

    False

    Returns:

    Type Description tuple

    A tuple with the loaded checkpoint (in a dict with \"params\" and \"opt_state\") and additional information (in a dict, usually including \"steps\", \"epoch_idx\", and \"rng\").

    Source code in redco/deployers/deployer.py
    def load_ckpt(self,\n              ckpt_dir,\n              params_sharding_rules=None,\n              optimizer=None,\n              float_dtype=None,\n              load_params=True,\n              load_opt_state=True,\n              update_rng=False):\n    \"\"\"Loads a checkpoint from the specified directory.\n\n    Args:\n        ckpt_dir (str): Directory of the checkpoint.\n        params_sharding_rules (list[tuple]): Sharding rules for parameters.\n        optimizer (`optax.optimizer`): Optimizer for loading opt_state.\n        float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n        load_params (bool): Whether to load the parameters.\n        load_opt_state (bool): Whether to load the optimizer state.\n        update_rng (bool): if updating the random state of the deployer.\n\n    Returns:\n        (tuple): A tuple with the loaded checkpoint (in a dict with\n            `\"params\"` and `\"opt_state\"`) and additional information (in a\n            dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n    \"\"\"\n    ckpt_dir = os.path.abspath(ckpt_dir)\n    self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n    params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n    specs = {}\n    if self._mesh is not None:\n        if params_sharding_rules is None:\n            params_sharding_rules = self.get_sharding_rules(\n                params_shape_or_params=params_shape)\n\n        specs['params'] = self.get_params_spec(\n            params_shape_or_params=params_shape,\n            params_sharding_rules=params_sharding_rules)\n        if optimizer is not None:\n            specs['opt_state'] = self.get_opt_state_spec(\n                params_shape_or_params=params_shape,\n                params_spec=specs['params'],\n                optimizer=optimizer)\n\n    ckpt, info = load_ckpt(\n        ckpt_dir=ckpt_dir,\n        checkpointer=self._checkpointer,\n        params_shape_or_params=params_shape,\n        optimizer=optimizer,\n        float_dtype=float_dtype,\n        mesh=self._mesh,\n        specs=specs,\n        load_params=load_params,\n        load_opt_state=load_opt_state)\n\n    for key, value in info.items():\n        if not update_rng and key == 'rng':\n            continue\n        self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n    if update_rng:\n        self._rng = info['rng']\n        self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n    return ckpt, info\n
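    A usage sketch (the checkpoint path is illustrative); passing the optimizer also restores opt_state:

    ckpt, info = deployer.load_ckpt(
        ckpt_dir='./workdir/ckpts/ckpt_example', optimizer=optimizer)
    params, opt_state = ckpt['params'], ckpt['opt_state']
    # info usually carries 'steps', 'epoch_idx' and 'rng'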
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.load_last_ckpt","title":"load_last_ckpt(optimizer=None, params_sharding_rules=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=True)","text":"

    Loads the last checkpoint from the work directory (self.workdir). See load_ckpt() for the explanation of arguments.

    Source code in redco/deployers/deployer.py
    def load_last_ckpt(self,\n                   optimizer=None,\n                   params_sharding_rules=None,\n                   float_dtype=None,\n                   load_params=True,\n                   load_opt_state=True,\n                   update_rng=True):\n    \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n    See load_ckpt() for the explanation of arguments.\n    \"\"\"\n    try:\n        last_ckpt_name = open(\n            f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n    except:\n        self.log_info(\n            f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n            f'no ckpt loaded.')\n        return None, None\n\n    return self.load_ckpt(\n        ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n        optimizer=optimizer,\n        float_dtype=float_dtype,\n        params_sharding_rules=params_sharding_rules,\n        load_params=load_params,\n        load_opt_state=load_opt_state,\n        update_rng=update_rng)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.load_params_shape","title":"load_params_shape(ckpt_dir)","text":"

    Loads the shape of the parameters from a checkpoint.

    Source code in redco/deployers/deployer.py
    def load_params_shape(self, ckpt_dir):\n    \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n    return load_params_shape(ckpt_dir=ckpt_dir)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.log_info","title":"log_info(info, title=None, step=None)","text":"

    Logs a message.

    Source code in redco/deployers/deployer.py
    def log_info(self, info, title=None, step=None):\n    \"\"\"Logs a messages\"\"\"\n    log_info(\n        info=info,\n        title=title,\n        logger=self._logger,\n        summary_writer=self._summary_writer,\n        step=step)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.log_metrics","title":"log_metrics(metrics, step)","text":"

    Logs metrics to TensorBoard and Weights and Biases (wandb).

    Source code in redco/deployers/deployer.py
    def log_metrics(self, metrics, step):\n    \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n    if self._summary_writer is not None:\n        for metric_name, value in metrics.items():\n            self._summary_writer.scalar(metric_name, value, step=step)\n\n    if self._wandb_log_fn is not None:\n        self._wandb_log_fn(metrics, step)\n
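    For example (metric names and values are illustrative):

    deployer.log_metrics(metrics={'loss': 0.123, 'lr': 1e-3}, step=100)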
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.run_model_step","title":"run_model_step(step_fn, input_args)","text":"

    Executes a model step function with the provided inputs.

    Source code in redco/deployers/deployer.py
    def run_model_step(self, step_fn, input_args):\n    \"\"\"Executes a model step function with the provided inputs.\"\"\"\n    if self._mesh is None:\n        return step_fn(*input_args)\n    else:\n        with self._mesh:\n            return step_fn(*input_args)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.save_ckpt","title":"save_ckpt(ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs)","text":"

    Saves a checkpoint to the specified directory.

    Parameters:

    Name Type Description Default ckpt_dir str

    Directory to save the checkpoint.

    required params dict

    Model parameters.

    required opt_state dict

    Optimizer state.

    None float_dtype `jax.numpy.dtype`

    Dtype for floating point numbers.

    None **kwargs dict

    Additional information to be saved into info.json, e.g., current training step, epoch index, etc.

    {} Source code in redco/deployers/deployer.py
    def save_ckpt(\n        self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n    \"\"\"Saves a checkpoint to the specified directory.\n\n    Args:\n        ckpt_dir (str): Directory to save the checkpoint.\n        params (dict): Model parameters.\n        opt_state (dict): Optimizer state.\n        float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n        **kwargs (dict): Additional information to be saved into\n            info.json, e.g., current training step, epoch index, etc.\n    \"\"\"\n    ckpt_dir = os.path.abspath(ckpt_dir)\n    self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n    save_ckpt(\n        ckpt_dir=ckpt_dir,\n        checkpointer=self._checkpointer,\n        params=params,\n        opt_state=opt_state,\n        float_dtype=float_dtype,\n        rng=self._rng,\n        **kwargs)\n    self.log_info(f'Ckpt saved into {ckpt_dir}')\n
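    A usage sketch (directory name and extra kwargs are illustrative); any extra keyword arguments end up in info.json:

    import jax.numpy as jnp

    deployer.save_ckpt(
        ckpt_dir='./workdir/ckpts/ckpt_example', params=params,
        opt_state=opt_state, float_dtype=jnp.float32,
        epoch_idx=1, step=1000)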
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.save_outputs","title":"save_outputs(outputs, desc, step)","text":"

    Saves model outputs to workdir.

    Source code in redco/deployers/deployer.py
    def save_outputs(self, outputs, desc, step):\n    \"\"\"Saves model outputs to workdir.\"\"\"\n    if self._workdir is not None and jax.process_index() == 0:\n        save_outputs(\n            workdir=self._workdir,\n            outputs=outputs,\n            desc=desc,\n            step=step,\n            logger=self._logger,\n            summary_writer=self._summary_writer)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.shard_params","title":"shard_params(params, params_spec, desc='params')","text":"

    Distributes parameters to all devices based on the provided specs.

    Source code in redco/deployers/deployer.py
    def shard_params(self, params, params_spec, desc='params'):\n    \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n    self.log_info(info=f'Sharding {desc} ...')\n    return shard_params(\n        mesh=self._mesh, params=params, params_spec=params_spec)\n
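    A sketch of the typical model-parallel flow (mirroring what load_ckpt() does internally): derive sharding rules, turn them into parameter specs, then shard the parameters across the mesh:

    sharding_rules = deployer.get_sharding_rules(params_shape_or_params=params)
    if sharding_rules is not None:  # None when no model parallelism is used
        params_spec = deployer.get_params_spec(
            params_shape_or_params=params, params_sharding_rules=sharding_rules)
        params = deployer.shard_params(params=params, params_spec=params_spec)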
    "},{"location":"mnist/","title":"MNIST Example","text":""},{"location":"mnist/#mnist","title":"MNIST","text":"

    A trivial MNIST example with RedCoast. Runnable by

    python main.py\n

    To simulate multiple devices in CPU-only environments,

    XLA_FLAGS=\"--xla_force_host_platform_device_count=8\" python main.py\n

    "},{"location":"mnist/#source-code","title":"Source Code","text":"
    from functools import partial\nimport fire\nimport numpy as np\nfrom flax import linen as nn\nimport optax\nfrom torchvision.datasets import MNIST\nfrom redco import Deployer, Trainer, Predictor\n\n\n# A simple CNN model \n# Copied from https://github.com/google/flax/blob/main/examples/mnist/train.py\nclass CNN(nn.Module):\n    @nn.compact\n    def __call__(self, x):\n        x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n        x = nn.relu(x)\n        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n        x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n        x = nn.relu(x)\n        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n        x = x.reshape((x.shape[0], -1))  # flatten\n        x = nn.Dense(features=256)(x)\n        x = nn.relu(x)\n        x = nn.Dense(features=10)(x)\n        return x\n\n\n# Collate function converting a batch of raw examples to model inputs (in numpy) \ndef collate_fn(examples):\n    images = np.stack(\n        [np.array(example['image'])[:, :, None] for example in examples])\n    labels = np.array([example['label'] for example in examples])\n\n    return {'images': images, 'labels': labels}\n\n\n# Loss function converting model inputs to a scalar loss\ndef loss_fn(train_rng, state, params, batch, is_training):\n    logits = state.apply_fn({'params': params}, batch['images'])\n    return optax.softmax_cross_entropy_with_integer_labels(\n        logits=logits, labels=batch['labels']).mean()\n\n\n# Predict function converting model inputs to the model outputs\ndef pred_fn(pred_rng, params, batch, model):\n    accs = model.apply({'params': params}, batch['images']).argmax(axis=-1)\n    return {'acc': accs}\n\n\n# (Optional) Evaluation function in trainer.fit. Here it computes accuracy.\ndef eval_metric_fn(examples, preds):\n    preds = np.array([pred['acc'] for pred in preds])\n    labels = np.array([example['label'] for example in examples])\n    return {'acc': np.mean(preds == labels).item()}\n\n\ndef main(per_device_batch_size=64, learning_rate=1e-3, jax_seed=42):\n    deployer = Deployer(jax_seed=jax_seed, workdir='./workdir')\n\n    dataset = {\n        'train': [{'image': t[0], 'label': t[1]} for t in list(\n            MNIST('./data', train=True, download=True))],\n        'test': [{'image': t[0], 'label': t[1]} for t in list(\n            MNIST('./data', train=False, download=True))],\n    }\n\n    model = CNN()\n    dummy_batch = collate_fn(examples=[dataset['train'][0]])\n    params = model.init(deployer.gen_rng(), dummy_batch['images'])['params']\n\n    trainer = Trainer(\n        deployer=deployer,\n        collate_fn=collate_fn,\n        apply_fn=model.apply,\n        loss_fn=loss_fn,\n        params=params,\n        optimizer=optax.adamw(learning_rate=learning_rate))\n\n    predictor = Predictor(\n        deployer=deployer,\n        collate_fn=collate_fn,\n        pred_fn=partial(pred_fn, model=model))\n\n    trainer.fit(\n        train_examples=dataset['train'],\n        per_device_batch_size=per_device_batch_size,\n        n_epochs=2,\n        eval_examples=dataset['test'],\n        eval_predictor=predictor,\n        eval_metric_fn=eval_metric_fn)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n
    "},{"location":"predictor/","title":"Predictor","text":""},{"location":"predictor/#redco.predictors.predictor.Predictor","title":"Predictor","text":"

    Predictor class managing the distributed inference process.

    Attributes:

    Name Type Description mesh `jax.sharding.Mesh`

    Mesh used for distributed inference.

    Source code in redco/predictors/predictor.py
    class Predictor:\n    \"\"\"Predictor class managing distributed inference process.\n\n    Attributes:\n        mesh (`jax.sharding.Mesh`): Mesh used for distributed inference.\n    \"\"\"\n    def __init__(self,\n                 deployer,\n                 collate_fn,\n                 pred_fn,\n                 output_fn=None,\n                 params_sharding_rules=None):\n        \"\"\"Initializes a Predictor instance.\n\n        Args:\n            deployer (`redco.Deployer`): A deployer for low-level operations.\n            collate_fn (Callable): A function converting a data batch to model inputs,\n                e.g., tokenizing sentences into input_ids.\n            pred_fn (Callable): A function to produce model outputs with model inputs,\n                e.g., running beam search with a language model.\n            output_fn (Callable): A function finalizing model outputs (on CPU),\n                e.g., decoding generated ids to text.\n            params_sharding_rules (list[tuple]): Rules for sharding parameters.\n        \"\"\"\n        self._deployer = deployer\n        self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n        self._params_sharding_rules = params_sharding_rules\n        self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n        self._p_pred_step = None\n\n        if output_fn is None:\n            self._output_fn = default_output_fn\n        else:\n            self._output_fn = output_fn\n\n    def setup_running_step(self, dummy_batch, params_shape_or_params):\n        \"\"\"Sets up the prediction step function for distributed inference.\n\n        Args:\n            dummy_batch (PyTree): A dummy batch used to determine data shapes.\n            params_shape_or_params (dict): The shape of params or actual params.\n        \"\"\"\n        pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n        if self.mesh is None:\n            self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n        else:\n            data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n            params_spec = self._deployer.get_params_spec(\n                params_shape_or_params=params_shape_or_params,\n                params_sharding_rules=self._params_sharding_rules)\n            self._p_pred_step = pjit(\n                pred_step_fn,\n                in_shardings=(None, params_spec, data_spec),\n                out_shardings=None)\n\n    def predict(self,\n                examples,\n                per_device_batch_size,\n                params,\n                params_replicated=False,\n                params_sharded=False,\n                desc=None):\n        \"\"\"Runs distributed prediction on a list of examples.\n\n        Args:\n            examples (list): Input examples for prediction.\n            per_device_batch_size (int): Batch size per device.\n            params (dict): Model parameters in a dict/FrozenDict.\n            params_replicated (bool): if the params are already replicated.\n            params_sharded (bool): if the parameters are already sharded.\n            desc (str): Description to show in the progress bar.\n\n        Returns:\n            (list): A list of predictions corresponding to the input examples.\n        \"\"\"\n        raw_n_inputs = len(examples)\n        _, global_micro_batch_size = \\\n            self._deployer.get_local_global_micro_batch_size(\n                per_device_batch_size=per_device_batch_size)\n        examples = examples + [examples[0]] * 
(global_micro_batch_size - 1)\n        examples = add_idxes(examples=examples)\n\n        data_batches = self._deployer.get_model_input_batches(\n            examples=examples,\n            per_device_batch_size=per_device_batch_size,\n            collate_fn=self._collate_fn,\n            shuffle=False,\n            shuffle_rng=None,\n            desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n        params = freeze(params)\n        if (self.mesh is None) and (not params_replicated):\n            params = replicate(params)\n        if (self.mesh is not None) and (not params_sharded):\n            params_spec = self._deployer.get_params_spec(\n                params_shape_or_params=params,\n                params_sharding_rules=self._params_sharding_rules)\n            params = self._deployer.shard_params(\n                params=params, params_spec=params_spec)\n\n        preds = []\n        for batch in data_batches:\n            if self._p_pred_step is None:\n                self.setup_running_step(\n                    dummy_batch=batch, params_shape_or_params=params)\n\n            pred_rng = self._deployer.gen_rng()\n            if self.mesh is None:\n                pred_rng = jax.random.split(\n                    pred_rng, num=jax.process_count())[jax.process_index()]\n                pred_rng = shard_prng_key(pred_rng)\n\n            batch_preds_with_idxes = self._deployer.run_model_step(\n                step_fn=self._p_pred_step,\n                input_args=(pred_rng, params, batch))\n            batch_preds = process_batch_preds(\n                batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n            batch_preds = self._output_fn(batch_preds)\n\n            assert isinstance(batch_preds, list) and \\\n                   len(batch_preds) == global_micro_batch_size\n            preds.extend(batch_preds)\n\n        return preds[:raw_n_inputs]\n\n    @property\n    def mesh(self):\n        \"\"\"Returns the mesh used for distributed inference.\"\"\"\n        return self._deployer.mesh\n
    "},{"location":"predictor/#redco.predictors.predictor.Predictor.mesh","title":"mesh property","text":"

    Returns the mesh used for distributed inference.

    "},{"location":"predictor/#redco.predictors.predictor.Predictor.__init__","title":"__init__(deployer, collate_fn, pred_fn, output_fn=None, params_sharding_rules=None)","text":"

    Initializes a Predictor instance.

    Parameters:

    Name Type Description Default deployer `redco.Deployer`

    A deployer for low-level operations.

    required collate_fn Callable

    A function converting a data batch to model inputs, e.g., tokenizing sentences into input_ids.

    required pred_fn Callable

    A function to produce model outputs with model inputs, e.g., running beam search with a language model.

    required output_fn Callable

    A function finalizing model outputs (on CPU), e.g., decoding generated ids to text.

    None params_sharding_rules list[tuple]

    Rules for sharding parameters.

    None Source code in redco/predictors/predictor.py
    def __init__(self,\n             deployer,\n             collate_fn,\n             pred_fn,\n             output_fn=None,\n             params_sharding_rules=None):\n    \"\"\"Initializes a Predictor instance.\n\n    Args:\n        deployer (`redco.Deployer`): A deployer for low-level operations.\n        collate_fn (Callable): A function converting a data batch to model inputs,\n            e.g., tokenizing sentences into input_ids.\n        pred_fn (Callable): A function to produce model outputs with model inputs,\n            e.g., running beam search with a language model.\n        output_fn (Callable): A function finalizing model outputs (on CPU),\n            e.g., decoding generated ids to text.\n        params_sharding_rules (list[tuple]): Rules for sharding parameters.\n    \"\"\"\n    self._deployer = deployer\n    self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n    self._params_sharding_rules = params_sharding_rules\n    self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n    self._p_pred_step = None\n\n    if output_fn is None:\n        self._output_fn = default_output_fn\n    else:\n        self._output_fn = output_fn\n
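    Constructing a Predictor as in the MNIST example on this site, where pred_fn takes an extra model argument bound via functools.partial:

    from functools import partial

    predictor = Predictor(
        deployer=deployer, collate_fn=collate_fn,
        pred_fn=partial(pred_fn, model=model))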
    "},{"location":"predictor/#redco.predictors.predictor.Predictor.predict","title":"predict(examples, per_device_batch_size, params, params_replicated=False, params_sharded=False, desc=None)","text":"

    Runs distributed prediction on a list of examples.

    Parameters:

    Name Type Description Default examples list

    Input examples for prediction.

    required per_device_batch_size int

    Batch size per device.

    required params dict

    Model parameters in a dict/FrozenDict.

    required params_replicated bool

    Whether the params are already replicated.

    False params_sharded bool

    Whether the parameters are already sharded.

    False desc str

    Description to show in the progress bar.

    None

    Returns:

    Type Description list

    A list of predictions corresponding to the input examples.

    Source code in redco/predictors/predictor.py
    def predict(self,\n            examples,\n            per_device_batch_size,\n            params,\n            params_replicated=False,\n            params_sharded=False,\n            desc=None):\n    \"\"\"Runs distributed prediction on a list of examples.\n\n    Args:\n        examples (list): Input examples for prediction.\n        per_device_batch_size (int): Batch size per device.\n        params (dict): Model parameters in a dict/FrozenDict.\n        params_replicated (bool): if the params are already replicated.\n        params_sharded (bool): if the parameters are already sharded.\n        desc (str): Description to show in the progress bar.\n\n    Returns:\n        (list): A list of predictions corresponding to the input examples.\n    \"\"\"\n    raw_n_inputs = len(examples)\n    _, global_micro_batch_size = \\\n        self._deployer.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n    examples = examples + [examples[0]] * (global_micro_batch_size - 1)\n    examples = add_idxes(examples=examples)\n\n    data_batches = self._deployer.get_model_input_batches(\n        examples=examples,\n        per_device_batch_size=per_device_batch_size,\n        collate_fn=self._collate_fn,\n        shuffle=False,\n        shuffle_rng=None,\n        desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n    params = freeze(params)\n    if (self.mesh is None) and (not params_replicated):\n        params = replicate(params)\n    if (self.mesh is not None) and (not params_sharded):\n        params_spec = self._deployer.get_params_spec(\n            params_shape_or_params=params,\n            params_sharding_rules=self._params_sharding_rules)\n        params = self._deployer.shard_params(\n            params=params, params_spec=params_spec)\n\n    preds = []\n    for batch in data_batches:\n        if self._p_pred_step is None:\n            self.setup_running_step(\n                dummy_batch=batch, params_shape_or_params=params)\n\n        pred_rng = self._deployer.gen_rng()\n        if self.mesh is None:\n            pred_rng = jax.random.split(\n                pred_rng, num=jax.process_count())[jax.process_index()]\n            pred_rng = shard_prng_key(pred_rng)\n\n        batch_preds_with_idxes = self._deployer.run_model_step(\n            step_fn=self._p_pred_step,\n            input_args=(pred_rng, params, batch))\n        batch_preds = process_batch_preds(\n            batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n        batch_preds = self._output_fn(batch_preds)\n\n        assert isinstance(batch_preds, list) and \\\n               len(batch_preds) == global_micro_batch_size\n        preds.extend(batch_preds)\n\n    return preds[:raw_n_inputs]\n
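    A usage sketch following the MNIST example; the returned list is aligned with the input examples:

    preds = predictor.predict(
        examples=dataset['test'], per_device_batch_size=64, params=params)
    # len(preds) == len(dataset['test']); with the MNIST pred_fn each item looks like {'acc': 7}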
    "},{"location":"predictor/#redco.predictors.predictor.Predictor.setup_running_step","title":"setup_running_step(dummy_batch, params_shape_or_params)","text":"

    Sets up the prediction step function for distributed inference.

Parameters:

- dummy_batch (PyTree): A dummy batch used to determine data shapes. Required.
- params_shape_or_params (dict): The shape of params or actual params. Required.

Source code in redco/predictors/predictor.py
    def setup_running_step(self, dummy_batch, params_shape_or_params):\n    \"\"\"Sets up the prediction step function for distributed inference.\n\n    Args:\n        dummy_batch (PyTree): A dummy batch used to determine data shapes.\n        params_shape_or_params (dict): The shape of params or actual params.\n    \"\"\"\n    pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n    if self.mesh is None:\n        self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n    else:\n        data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n        params_spec = self._deployer.get_params_spec(\n            params_shape_or_params=params_shape_or_params,\n            params_sharding_rules=self._params_sharding_rules)\n        self._p_pred_step = pjit(\n            pred_step_fn,\n            in_shardings=(None, params_spec, data_spec),\n            out_shardings=None)\n
    "},{"location":"trainer/","title":"Trainer","text":""},{"location":"trainer/#redco.trainers.trainer.Trainer","title":"Trainer","text":"

    Trainer class managing distributed training process.

Attributes:

- step (int): Current training step.
- workdir (str): Working directory for saving checkpoints and logs.
- mesh (jax Mesh): Mesh used for distributed training.
- state (flax TrainState): Current training state.

    Source code in redco/trainers/trainer.py
    class Trainer:\n    \"\"\"Trainer class managing distributed training process.\n\n    Attributes:\n        step (int): Current training step.\n        workdir (str): Working directory for saving checkpoints and logs.\n        mesh (jax Mesh): Mesh used for distributed training.\n        state (flax TrainState): Current training state.\n    \"\"\"\n    def __init__(self,\n                 deployer,\n                 collate_fn,\n                 apply_fn,\n                 loss_fn,\n                 params,\n                 optimizer,\n                 opt_state=None,\n                 compute_dtype=jnp.float32,\n                 last_ckpt_info=None,\n                 lr_schedule_fn=None,\n                 accumulate_grad_batches=None,\n                 params_sharding_rules=None):\n        \"\"\"Initializes the Trainer with initial parameters, etc.\n\n        Args:\n            deployer (Deployer): A deployer supporting low-level operations.\n            collate_fn (Callable): The function converting a data batch to model\n                inputs, e.g., tokenizing sentences into input_ids.\n            apply_fn (Callable): The function to apply the model, such as\n                model.apply for Flax modules, or model itself for HuggingFace\n                models. It would be set as state.apply_fn, and used in loss_fn.\n            loss_fn (Callable): The loss function converting model inputs to a\n                scalar loss, e.g., computing cross-entropy loss from input_ids.\n            params (dict): Initial model parameters.\n            optimizer (`optax.optimizer`): The optimizer used for training.\n            opt_state (dict): optimizer state.\n            compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n                independent of param dtypes. 
(for mixed-precision training)\n            last_ckpt_info (dict): the beginning step and epoch.\n            lr_schedule_fn (Callable): The learning rate schedule\n                function converting step to learning rate.\n            accumulate_grad_batches (int): Gradient accumulation step.\n            params_sharding_rules (list): Sharding rules.\n        \"\"\"\n        self._deployer = deployer\n        self._collate_fn = collate_fn\n        self._apply_fn = apply_fn\n        self._loss_fn = loss_fn\n        self._optimizer = optimizer\n        self._compute_dtype = compute_dtype\n        self._lr_schedule_fn = lr_schedule_fn\n        self._accumulate_grad_batches = accumulate_grad_batches\n        self._params_sharding_rules = params_sharding_rules\n\n        self._state = None\n        self._state_spec = None\n        self._p_train_step = None\n        self._p_eval_step = None\n\n        self._init_step = 0\n        self._init_epoch_idx = 0\n        if last_ckpt_info is not None:\n            self._init_step = last_ckpt_info.get('step', 0)\n            self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n        n_params = sum([param.size for param in jax.tree.leaves(params)])\n        self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n        self.set_train_state(\n            apply_fn=self._apply_fn,\n            params=params,\n            optimizer=self._optimizer,\n            step=self._init_step,\n            opt_state=opt_state)\n\n    def set_train_state(\n            self, apply_fn, params, optimizer, step, opt_state=None):\n        \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n        Args:\n            apply_fn (Callable): The function to apply the model.\n            params (dict): Model parameters.\n            optimizer (dict): The optimizer used for training.\n            step (int): The training step.\n            opt_state (dict): The state of the optimizer.\n        \"\"\"\n        self._deployer.log_info('Setting train_state ...')\n        params = freeze(params)\n\n        if self.mesh is None:\n            params = jax.device_put(params, jax.local_devices()[0])\n            if opt_state is None:\n                self._deployer.log_info('Initializing opt_state ...')\n                opt_state = optimizer.init(params)\n            else:\n                opt_state = jax.device_put(opt_state, jax.local_devices()[0])\n\n            self._state = train_state.TrainState(\n                step=step,\n                apply_fn=apply_fn,\n                params=params,\n                tx=optimizer,\n                opt_state=opt_state)\n            self._state = replicate(self._state)\n        else:\n            params_spec = self._deployer.get_params_spec(\n                params_shape_or_params=params,\n                params_sharding_rules=self._params_sharding_rules)\n            params = self._deployer.shard_params(\n                params=params, params_spec=params_spec)\n\n            if opt_state is None:\n                self._deployer.log_info('Initializing opt_state ...')\n                opt_state = optimizer.init(params)\n\n            opt_state_spec = self._deployer.get_opt_state_spec(\n                params_shape_or_params=params,\n                params_spec=params_spec,\n                optimizer=optimizer)\n            opt_state = self._deployer.shard_params(\n                params=opt_state,\n                params_spec=opt_state_spec,\n                desc='opt_state')\n\n            
self._state = train_state.TrainState(\n                apply_fn=apply_fn,\n                params=params,\n                tx=optimizer,\n                opt_state=opt_state,\n                step=step)\n\n            self._state_spec = train_state.TrainState(\n                apply_fn=apply_fn,\n                params=params_spec,\n                tx=optimizer,\n                opt_state=opt_state_spec,\n                step=None)\n\n    def setup_running_step(self, dummy_batch):\n        \"\"\"Sets up the running step functions for training and evaluation.\n\n        Args:\n            dummy_batch (PyTree): A dummy batch of data.\n        \"\"\"\n        train_step_fn = partial(\n            train_step,\n            loss_fn=self._loss_fn,\n            lr_schedule_fn=self._lr_schedule_fn,\n            mesh=self.mesh,\n            compute_dtype=self._compute_dtype)\n        eval_step_fn = partial(\n            eval_step,\n            loss_fn=self._loss_fn,\n            mesh=self.mesh,\n            compute_dtype=self._compute_dtype)\n\n        if self.mesh is None:\n            self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n            self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n        else:\n            data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n            self._p_train_step = pjit(\n                train_step_fn,\n                in_shardings=(None, self._state_spec, data_spec),\n                out_shardings=(self._state_spec, None),\n                donate_argnums=(1, ))\n            self._p_eval_step = pjit(\n                eval_step_fn,\n                in_shardings=(self._state_spec, data_spec),\n                out_shardings=None)\n\n    def train(self, examples, per_device_batch_size, desc=None):\n        \"\"\"Trains the model on the provided examples.\n\n        Args:\n            examples (list): Training examples in python list.\n            per_device_batch_size (int): The batch size per device.\n            desc (str): Description in the progress bar.\n        \"\"\"\n        data_batches = self._deployer.get_model_input_batches(\n            examples=examples,\n            per_device_batch_size=per_device_batch_size,\n            collate_fn=self._collate_fn,\n            shuffle=True,\n            shuffle_rng=self._deployer.gen_rng(),\n            desc=f'Training ({desc})' if desc is not None else 'Training',\n            is_train=True,\n            accumulate_grad_batches=self._accumulate_grad_batches)\n\n        for batch in data_batches:\n            if self._p_train_step is None:\n                self.setup_running_step(dummy_batch=batch)\n\n            train_rng = self._deployer.gen_rng()\n            if self.mesh is None:\n                train_rng = jax.random.split(\n                    train_rng, num=jax.process_count())[jax.process_index()]\n                train_rng = shard_prng_key(train_rng)\n            self._state, metrics = self._deployer.run_model_step(\n                step_fn=self._p_train_step,\n                input_args=(train_rng, self._state, batch))\n\n            if self.mesh is None:\n                metrics = unreplicate(metrics)\n            data_batches.set_postfix(**metrics)\n            self._deployer.log_metrics(metrics=metrics, step=self.step)\n\n    def eval_loss(self, examples, per_device_batch_size, desc=None):\n        \"\"\"Evaluates the loss on the provided examples.\n\n        Args:\n            examples (list): Evaluation examples in list.\n            per_device_batch_size (int): The 
batch size per device.\n            desc (str): Description in the progress bar.\n\n        Returns:\n            (float): The average loss over the evaluation examples.\n        \"\"\"\n        data_batches = self._deployer.get_model_input_batches(\n            examples=examples,\n            per_device_batch_size=per_device_batch_size,\n            collate_fn=self._collate_fn,\n            shuffle=False,\n            shuffle_rng=None,\n            desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n        losses = []\n        for batch in data_batches:\n            if self._p_eval_step is None:\n                self.setup_running_step(dummy_batch=batch)\n\n            metrics = self._deployer.run_model_step(\n                step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n            if self.mesh is None:\n                metrics = unreplicate(metrics)\n\n            losses.append(metrics['loss'].item())\n            data_batches.set_postfix(**metrics)\n\n        return np.mean(losses).item()\n\n    def fit(self,\n            train_examples,\n            per_device_batch_size,\n            n_epochs,\n            eval_examples=None,\n            eval_per_device_batch_size=None,\n            eval_loss=True,\n            eval_predictor=None,\n            eval_metric_fn=None,\n            eval_sanity_check=True,\n            save_every_ckpt=False,\n            save_last_ckpt=False,\n            save_argmin_ckpt_by_metrics=None,\n            save_argmax_ckpt_by_metrics=None,\n            save_opt_states=True,\n            save_float_dtype=None):\n        \"\"\"Fits the model on the training data for a given number of epochs,\n        optionally evaluating and saving checkpoints.\n\n        Args:\n            train_examples (list or Callable): Training examples, can be a\n                list or a function of epoch_idx (for assigning different\n                examples in separate epochs/chunks),\n                e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n            per_device_batch_size (int): The batch size per device.\n            n_epochs (int): Number of epochs to train.\n            eval_examples (list): Examples for evaluation and prediction.\n            eval_per_device_batch_size (int): Batch size for evaluation\n            eval_loss (bool): Whether to evaluate loss.\n            eval_predictor (`redco.Predictor`): Predicting on `eval_examples`.\n            eval_metric_fn (Callable): Metric function for prediction.\n            eval_sanity_check (bool): if to run a sanity check for\n                evaluation & predict functions before training.\n            save_every_ckpt (bool): if to save a ckpt after every epoch.\n            save_last_ckpt (bool): Whether to save the last checkpoint.\n            save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n                based on minimum values.\n            save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n                based on maximum values.\n            save_opt_states (bool): of to save optimizer states in ckpts.\n            save_float_dtype (bool): The data type for saving checkpoints.\n        \"\"\"\n        if eval_per_device_batch_size is None:\n            eval_per_device_batch_size = per_device_batch_size\n\n        if save_argmax_ckpt_by_metrics is None:\n            save_argmax_ckpt_by_metrics = []\n        if save_argmin_ckpt_by_metrics is None:\n            save_argmin_ckpt_by_metrics = []\n        min_metrics, max_metrics 
= {}, {}\n\n        if os.path.exists(f'{self.workdir}/min_metrics.json'):\n            min_metrics = json.load(open(\n                f'{self.workdir}/min_metrics.json'))\n            self._deployer.log_info(\n                json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n        if os.path.exists(f'{self.workdir}/max_metrics.json'):\n            max_metrics = json.load(open(\n                f'{self.workdir}/max_metrics.json'))\n            self._deployer.log_info(\n                json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n        if eval_sanity_check and eval_examples is not None:\n            rng_backup = self._deployer._rng\n            _, eval_global_micro_batch_size = \\\n                self._deployer.get_local_global_micro_batch_size(\n                    per_device_batch_size=eval_per_device_batch_size)\n\n            if eval_loss:\n                self.eval_loss(\n                    examples=eval_examples[:eval_global_micro_batch_size],\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'Sanity check')\n                self._deployer.log_info(\n                    'Sanity check (for evaluation loss) passed.')\n\n            if eval_predictor is not None:\n                preds = eval_predictor.predict(\n                    examples=eval_examples[:eval_global_micro_batch_size],\n                    params=self._state.params,\n                    params_replicated=(self.mesh is None),\n                    params_sharded=(self.mesh is not None),\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'Sanity check')\n                self._deployer.log_info(\n                    'Sanity check (for prediction) passed.')\n\n                if eval_metric_fn is not None:\n                    json.dumps(eval_metric_fn(\n                        examples=eval_examples[:eval_global_micro_batch_size],\n                        preds=preds))\n                    self._deployer.log_info(\n                        'Sanity check (for evaluation metrics) passed.')\n\n            self._deployer._rng = rng_backup\n\n        for epoch_idx in range(self._init_epoch_idx, n_epochs):\n            if isinstance(train_examples, list):\n                epoch_train_examples = train_examples\n            else:\n                epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n            self.train(\n                examples=epoch_train_examples,\n                per_device_batch_size=per_device_batch_size,\n                desc=f'epoch {epoch_idx} / {n_epochs}')\n\n            save_ckpt_kwargs = {\n                'epoch_idx': epoch_idx,\n                'save_opt_state': save_opt_states,\n                'float_dtype': save_float_dtype\n            }\n\n            if eval_examples is None:\n                self._deployer.log_info(\n                    'No evaluation cuz \\'eval_examples\\' is None.')\n            else:\n                eval_metrics = {}\n\n                if eval_loss:\n                    loss = self.eval_loss(\n                        examples=eval_examples,\n                        per_device_batch_size=eval_per_device_batch_size,\n                        desc=f'epoch {epoch_idx} / {n_epochs}')\n                    eval_metrics['loss'] = loss\n\n                if eval_predictor is not None:\n                    preds = eval_predictor.predict(\n                        examples=eval_examples,\n                        
params=self._state.params,\n                        params_replicated=(self.mesh is None),\n                        params_sharded=(self.mesh is not None),\n                        per_device_batch_size=eval_per_device_batch_size,\n                        desc=f'epoch {epoch_idx} / {n_epochs}')\n\n                    if eval_metric_fn is not None:\n                        eval_metrics.update(eval_metric_fn(\n                            examples=eval_examples, preds=preds))\n\n                    eval_outputs = [\n                        {'example': example, 'pred': pred}\n                        for example, pred in zip(eval_examples, preds)]\n\n                    self._deployer.save_outputs(\n                        outputs=eval_outputs,\n                        desc=f'epoch{epoch_idx}',\n                        step=self.step)\n\n                self._deployer.log_info(\n                    info=json.dumps(eval_metrics, indent=4),\n                    title=f'Eval results',\n                    step=self.step)\n                self._deployer.log_metrics(metrics={\n                    f'eval_{key}': value\n                    for key, value in eval_metrics.items()\n                }, step=self.step)\n\n                if self.workdir is not None:\n                    result_filepath = \\\n                        f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n                    json.dump(\n                        eval_metrics, open(result_filepath, 'w'), indent=4)\n                    self._deployer.log_info(\n                        f'eval_results saved into {result_filepath}.')\n\n                for key in save_argmin_ckpt_by_metrics:\n                    assert self.workdir is not None\n                    if eval_metrics[key] < min_metrics.get(key, float('inf')):\n                        min_metrics[key] = eval_metrics[key]\n\n                        if jax.process_index() == 0:\n                            self._deployer.log_info(\n                                f'minimal {key} updated to {min_metrics[key]}')\n                            json.dump(min_metrics, open(\n                                f'{self.workdir}/min_metrics.json', 'w'))\n\n                        self.save_ckpt(\n                            ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n                for key in save_argmax_ckpt_by_metrics:\n                    assert self.workdir is not None\n                    if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n                        max_metrics[key] = eval_metrics[key]\n\n                        if jax.process_index() == 0:\n                            self._deployer.log_info(\n                                f'maximal {key} updated to {max_metrics[key]}')\n                            json.dump(max_metrics, open(\n                                f'{self.workdir}/max_metrics.json', 'w'))\n\n                        self.save_ckpt(\n                            ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n            if save_every_ckpt:\n                self.save_ckpt(\n                    ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n            elif save_last_ckpt:\n                self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n\n    def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n        \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n        Args:\n            epoch_idx (int): The current epoch index.\n            ckpt_name (str): The name of the checkpoint.\n            save_opt_state 
(bool): Whether to save the optimizer state.\n            float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n        \"\"\"\n        if self.mesh is None:\n            params = jax.tree.map(\n                fully_replicated_host_local_array_to_global_array,\n                self._state.params)\n        else:\n            params = self._state.params\n\n        opt_state = None\n        if save_opt_state:\n            if self.mesh is None:\n                opt_state = jax.tree.map(\n                    fully_replicated_host_local_array_to_global_array,\n                    self._state.opt_state)\n            else:\n                opt_state = self._state.opt_state\n\n        ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n        self._deployer.save_ckpt(\n            ckpt_dir=ckpt_dir,\n            params=params,\n            opt_state=opt_state,\n            float_dtype=float_dtype,\n            step=self.step,\n            epoch_idx=epoch_idx)\n\n        if jax.process_index() == 0:\n            open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n            self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n\n    @property\n    def step(self):\n        \"\"\"Returns the current training step.\"\"\"\n        if self.mesh is None:\n            return unreplicate(self._state.step).item()\n        else:\n            return self._state.step.item()\n\n    @property\n    def workdir(self):\n        \"\"\"Returns the working directory for saving checkpoints and logs.\"\"\"\n        return self._deployer.workdir\n\n    @property\n    def mesh(self):\n        \"\"\"Returns the mesh used for distributed training.\"\"\"\n        return self._deployer.mesh\n\n    @property\n    def state(self):\n        \"\"\"Returns the current training state.\"\"\"\n        return self._state\n
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.mesh","title":"mesh property","text":"

    Returns the mesh used for distributed training.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.state","title":"state property","text":"

    Returns the current training state.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.step","title":"step property","text":"

    Returns the current training step.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.workdir","title":"workdir property","text":"

    Returns the working directory for saving checkpoints and logs.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.__init__","title":"__init__(deployer, collate_fn, apply_fn, loss_fn, params, optimizer, opt_state=None, compute_dtype=jnp.float32, last_ckpt_info=None, lr_schedule_fn=None, accumulate_grad_batches=None, params_sharding_rules=None)","text":"

    Initializes the Trainer with initial parameters, etc.

Parameters:

- deployer (Deployer): A deployer supporting low-level operations. Required.
- collate_fn (Callable): The function converting a data batch to model inputs, e.g., tokenizing sentences into input_ids. Required.
- apply_fn (Callable): The function to apply the model, such as model.apply for Flax modules, or the model itself for HuggingFace models. It is set as state.apply_fn and used in loss_fn. Required.
- loss_fn (Callable): The loss function converting model inputs to a scalar loss, e.g., computing cross-entropy loss from input_ids. Required.
- params (dict): Initial model parameters. Required.
- optimizer (`optax.optimizer`): The optimizer used for training. Required.
- opt_state (dict): Optimizer state. Default: None.
- compute_dtype (dtype): Computation dtype for mixed-precision training, e.g., jnp.bfloat16, independent of param dtypes. Default: jnp.float32.
- last_ckpt_info (dict): The beginning step and epoch. Default: None.
- lr_schedule_fn (Callable): The learning rate schedule function converting step to learning rate. Default: None.
- accumulate_grad_batches (int): Gradient accumulation steps. Default: None.
- params_sharding_rules (list): Sharding rules. Default: None.

Source code in redco/trainers/trainer.py
    def __init__(self,\n             deployer,\n             collate_fn,\n             apply_fn,\n             loss_fn,\n             params,\n             optimizer,\n             opt_state=None,\n             compute_dtype=jnp.float32,\n             last_ckpt_info=None,\n             lr_schedule_fn=None,\n             accumulate_grad_batches=None,\n             params_sharding_rules=None):\n    \"\"\"Initializes the Trainer with initial parameters, etc.\n\n    Args:\n        deployer (Deployer): A deployer supporting low-level operations.\n        collate_fn (Callable): The function converting a data batch to model\n            inputs, e.g., tokenizing sentences into input_ids.\n        apply_fn (Callable): The function to apply the model, such as\n            model.apply for Flax modules, or model itself for HuggingFace\n            models. It would be set as state.apply_fn, and used in loss_fn.\n        loss_fn (Callable): The loss function converting model inputs to a\n            scalar loss, e.g., computing cross-entropy loss from input_ids.\n        params (dict): Initial model parameters.\n        optimizer (`optax.optimizer`): The optimizer used for training.\n        opt_state (dict): optimizer state.\n        compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n            independent of param dtypes. (for mixed-precision training)\n        last_ckpt_info (dict): the beginning step and epoch.\n        lr_schedule_fn (Callable): The learning rate schedule\n            function converting step to learning rate.\n        accumulate_grad_batches (int): Gradient accumulation step.\n        params_sharding_rules (list): Sharding rules.\n    \"\"\"\n    self._deployer = deployer\n    self._collate_fn = collate_fn\n    self._apply_fn = apply_fn\n    self._loss_fn = loss_fn\n    self._optimizer = optimizer\n    self._compute_dtype = compute_dtype\n    self._lr_schedule_fn = lr_schedule_fn\n    self._accumulate_grad_batches = accumulate_grad_batches\n    self._params_sharding_rules = params_sharding_rules\n\n    self._state = None\n    self._state_spec = None\n    self._p_train_step = None\n    self._p_eval_step = None\n\n    self._init_step = 0\n    self._init_epoch_idx = 0\n    if last_ckpt_info is not None:\n        self._init_step = last_ckpt_info.get('step', 0)\n        self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n    n_params = sum([param.size for param in jax.tree.leaves(params)])\n    self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n    self.set_train_state(\n        apply_fn=self._apply_fn,\n        params=params,\n        optimizer=self._optimizer,\n        step=self._init_step,\n        opt_state=opt_state)\n
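A minimal construction sketch, assuming a Flax-style model and user-defined `collate_fn` / `loss_fn`; all names below other than the documented argument names are hypothetical placeholders.

```python
# Hypothetical sketch of constructing a Trainer; `deployer`, `model`,
# `params`, `collate_fn`, and `loss_fn` are assumed to be defined elsewhere.
import jax.numpy as jnp
import optax

optimizer = optax.adamw(learning_rate=1e-4)

trainer = Trainer(
    deployer=deployer,
    collate_fn=collate_fn,       # raw examples -> batched model inputs
    apply_fn=model.apply,        # e.g., a Flax module's apply function
    loss_fn=loss_fn,             # model inputs -> scalar loss
    params=params,
    optimizer=optimizer,
    compute_dtype=jnp.bfloat16,  # mixed-precision computation
    accumulate_grad_batches=2)   # optional gradient accumulation
```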
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.eval_loss","title":"eval_loss(examples, per_device_batch_size, desc=None)","text":"

    Evaluates the loss on the provided examples.

Parameters:

- examples (list): Evaluation examples in a list. Required.
- per_device_batch_size (int): The batch size per device. Required.
- desc (str): Description in the progress bar. Default: None.

Returns:

- (float): The average loss over the evaluation examples.

    Source code in redco/trainers/trainer.py
    def eval_loss(self, examples, per_device_batch_size, desc=None):\n    \"\"\"Evaluates the loss on the provided examples.\n\n    Args:\n        examples (list): Evaluation examples in list.\n        per_device_batch_size (int): The batch size per device.\n        desc (str): Description in the progress bar.\n\n    Returns:\n        (float): The average loss over the evaluation examples.\n    \"\"\"\n    data_batches = self._deployer.get_model_input_batches(\n        examples=examples,\n        per_device_batch_size=per_device_batch_size,\n        collate_fn=self._collate_fn,\n        shuffle=False,\n        shuffle_rng=None,\n        desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n    losses = []\n    for batch in data_batches:\n        if self._p_eval_step is None:\n            self.setup_running_step(dummy_batch=batch)\n\n        metrics = self._deployer.run_model_step(\n            step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n        if self.mesh is None:\n            metrics = unreplicate(metrics)\n\n        losses.append(metrics['loss'].item())\n        data_batches.set_postfix(**metrics)\n\n    return np.mean(losses).item()\n
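A quick usage sketch; `trainer` and `val_examples` are hypothetical placeholders.

```python
# Hypothetical sketch: evaluating loss on held-out data with an existing
# redco.Trainer; `trainer` and `val_examples` come from elsewhere.
val_loss = trainer.eval_loss(
    examples=val_examples,
    per_device_batch_size=16,
    desc='validation')
print(f'validation loss: {val_loss:.4f}')
```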
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.fit","title":"fit(train_examples, per_device_batch_size, n_epochs, eval_examples=None, eval_per_device_batch_size=None, eval_loss=True, eval_predictor=None, eval_metric_fn=None, eval_sanity_check=True, save_every_ckpt=False, save_last_ckpt=False, save_argmin_ckpt_by_metrics=None, save_argmax_ckpt_by_metrics=None, save_opt_states=True, save_float_dtype=None)","text":"

    Fits the model on the training data for a given number of epochs, optionally evaluating and saving checkpoints.

Parameters:

- train_examples (list or Callable): Training examples; can be a list or a function of epoch_idx (for assigning different examples in separate epochs/chunks), e.g., train_examples=lambda epoch_idx: load_data(chunk_idx). Required.
- per_device_batch_size (int): The batch size per device. Required.
- n_epochs (int): Number of epochs to train. Required.
- eval_examples (list): Examples for evaluation and prediction. Default: None.
- eval_per_device_batch_size (int): Batch size for evaluation. Default: None.
- eval_loss (bool): Whether to evaluate loss. Default: True.
- eval_predictor (`redco.Predictor`): Predictor for predicting on eval_examples. Default: None.
- eval_metric_fn (Callable): Metric function for prediction. Default: None.
- eval_sanity_check (bool): Whether to run a sanity check for the evaluation & predict functions before training. Default: True.
- save_every_ckpt (bool): Whether to save a checkpoint after every epoch. Default: False.
- save_last_ckpt (bool): Whether to save the last checkpoint. Default: False.
- save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints based on minimum values. Default: None.
- save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints based on maximum values. Default: None.
- save_opt_states (bool): Whether to save optimizer states in checkpoints. Default: True.
- save_float_dtype (dtype): The data type for saving checkpoints. Default: None.

Source code in redco/trainers/trainer.py
    def fit(self,\n        train_examples,\n        per_device_batch_size,\n        n_epochs,\n        eval_examples=None,\n        eval_per_device_batch_size=None,\n        eval_loss=True,\n        eval_predictor=None,\n        eval_metric_fn=None,\n        eval_sanity_check=True,\n        save_every_ckpt=False,\n        save_last_ckpt=False,\n        save_argmin_ckpt_by_metrics=None,\n        save_argmax_ckpt_by_metrics=None,\n        save_opt_states=True,\n        save_float_dtype=None):\n    \"\"\"Fits the model on the training data for a given number of epochs,\n    optionally evaluating and saving checkpoints.\n\n    Args:\n        train_examples (list or Callable): Training examples, can be a\n            list or a function of epoch_idx (for assigning different\n            examples in separate epochs/chunks),\n            e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n        per_device_batch_size (int): The batch size per device.\n        n_epochs (int): Number of epochs to train.\n        eval_examples (list): Examples for evaluation and prediction.\n        eval_per_device_batch_size (int): Batch size for evaluation\n        eval_loss (bool): Whether to evaluate loss.\n        eval_predictor (`redco.Predictor`): Predicting on `eval_examples`.\n        eval_metric_fn (Callable): Metric function for prediction.\n        eval_sanity_check (bool): if to run a sanity check for\n            evaluation & predict functions before training.\n        save_every_ckpt (bool): if to save a ckpt after every epoch.\n        save_last_ckpt (bool): Whether to save the last checkpoint.\n        save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n            based on minimum values.\n        save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n            based on maximum values.\n        save_opt_states (bool): of to save optimizer states in ckpts.\n        save_float_dtype (bool): The data type for saving checkpoints.\n    \"\"\"\n    if eval_per_device_batch_size is None:\n        eval_per_device_batch_size = per_device_batch_size\n\n    if save_argmax_ckpt_by_metrics is None:\n        save_argmax_ckpt_by_metrics = []\n    if save_argmin_ckpt_by_metrics is None:\n        save_argmin_ckpt_by_metrics = []\n    min_metrics, max_metrics = {}, {}\n\n    if os.path.exists(f'{self.workdir}/min_metrics.json'):\n        min_metrics = json.load(open(\n            f'{self.workdir}/min_metrics.json'))\n        self._deployer.log_info(\n            json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n    if os.path.exists(f'{self.workdir}/max_metrics.json'):\n        max_metrics = json.load(open(\n            f'{self.workdir}/max_metrics.json'))\n        self._deployer.log_info(\n            json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n    if eval_sanity_check and eval_examples is not None:\n        rng_backup = self._deployer._rng\n        _, eval_global_micro_batch_size = \\\n            self._deployer.get_local_global_micro_batch_size(\n                per_device_batch_size=eval_per_device_batch_size)\n\n        if eval_loss:\n            self.eval_loss(\n                examples=eval_examples[:eval_global_micro_batch_size],\n                per_device_batch_size=eval_per_device_batch_size,\n                desc=f'Sanity check')\n            self._deployer.log_info(\n                'Sanity check (for evaluation loss) passed.')\n\n        if eval_predictor is not None:\n            preds = eval_predictor.predict(\n   
             examples=eval_examples[:eval_global_micro_batch_size],\n                params=self._state.params,\n                params_replicated=(self.mesh is None),\n                params_sharded=(self.mesh is not None),\n                per_device_batch_size=eval_per_device_batch_size,\n                desc=f'Sanity check')\n            self._deployer.log_info(\n                'Sanity check (for prediction) passed.')\n\n            if eval_metric_fn is not None:\n                json.dumps(eval_metric_fn(\n                    examples=eval_examples[:eval_global_micro_batch_size],\n                    preds=preds))\n                self._deployer.log_info(\n                    'Sanity check (for evaluation metrics) passed.')\n\n        self._deployer._rng = rng_backup\n\n    for epoch_idx in range(self._init_epoch_idx, n_epochs):\n        if isinstance(train_examples, list):\n            epoch_train_examples = train_examples\n        else:\n            epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n        self.train(\n            examples=epoch_train_examples,\n            per_device_batch_size=per_device_batch_size,\n            desc=f'epoch {epoch_idx} / {n_epochs}')\n\n        save_ckpt_kwargs = {\n            'epoch_idx': epoch_idx,\n            'save_opt_state': save_opt_states,\n            'float_dtype': save_float_dtype\n        }\n\n        if eval_examples is None:\n            self._deployer.log_info(\n                'No evaluation cuz \\'eval_examples\\' is None.')\n        else:\n            eval_metrics = {}\n\n            if eval_loss:\n                loss = self.eval_loss(\n                    examples=eval_examples,\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'epoch {epoch_idx} / {n_epochs}')\n                eval_metrics['loss'] = loss\n\n            if eval_predictor is not None:\n                preds = eval_predictor.predict(\n                    examples=eval_examples,\n                    params=self._state.params,\n                    params_replicated=(self.mesh is None),\n                    params_sharded=(self.mesh is not None),\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'epoch {epoch_idx} / {n_epochs}')\n\n                if eval_metric_fn is not None:\n                    eval_metrics.update(eval_metric_fn(\n                        examples=eval_examples, preds=preds))\n\n                eval_outputs = [\n                    {'example': example, 'pred': pred}\n                    for example, pred in zip(eval_examples, preds)]\n\n                self._deployer.save_outputs(\n                    outputs=eval_outputs,\n                    desc=f'epoch{epoch_idx}',\n                    step=self.step)\n\n            self._deployer.log_info(\n                info=json.dumps(eval_metrics, indent=4),\n                title=f'Eval results',\n                step=self.step)\n            self._deployer.log_metrics(metrics={\n                f'eval_{key}': value\n                for key, value in eval_metrics.items()\n            }, step=self.step)\n\n            if self.workdir is not None:\n                result_filepath = \\\n                    f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n                json.dump(\n                    eval_metrics, open(result_filepath, 'w'), indent=4)\n                self._deployer.log_info(\n                    f'eval_results saved into {result_filepath}.')\n\n            for key in 
save_argmin_ckpt_by_metrics:\n                assert self.workdir is not None\n                if eval_metrics[key] < min_metrics.get(key, float('inf')):\n                    min_metrics[key] = eval_metrics[key]\n\n                    if jax.process_index() == 0:\n                        self._deployer.log_info(\n                            f'minimal {key} updated to {min_metrics[key]}')\n                        json.dump(min_metrics, open(\n                            f'{self.workdir}/min_metrics.json', 'w'))\n\n                    self.save_ckpt(\n                        ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n            for key in save_argmax_ckpt_by_metrics:\n                assert self.workdir is not None\n                if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n                    max_metrics[key] = eval_metrics[key]\n\n                    if jax.process_index() == 0:\n                        self._deployer.log_info(\n                            f'maximal {key} updated to {max_metrics[key]}')\n                        json.dump(max_metrics, open(\n                            f'{self.workdir}/max_metrics.json', 'w'))\n\n                    self.save_ckpt(\n                        ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n        if save_every_ckpt:\n            self.save_ckpt(\n                ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n        elif save_last_ckpt:\n            self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n
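A rough end-to-end sketch of `fit` with per-epoch evaluation; `trainer`, `predictor`, the example lists, and `my_metric_fn` are hypothetical placeholders, and the metric name used for checkpoint selection is only illustrative.

```python
# Hypothetical sketch of fit() with evaluation and best-checkpoint saving;
# trainer, predictor, train_examples, eval_examples, and my_metric_fn are
# assumed to be defined elsewhere (my_metric_fn returns a dict of metrics).
trainer.fit(
    train_examples=train_examples,
    per_device_batch_size=8,
    n_epochs=3,
    eval_examples=eval_examples,
    eval_loss=True,                            # report eval loss each epoch
    eval_predictor=predictor,                  # also run prediction each epoch
    eval_metric_fn=my_metric_fn,               # (examples, preds) -> metric dict
    save_argmax_ckpt_by_metrics=['accuracy'],  # keep the best-accuracy ckpt
    save_last_ckpt=True)
```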
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.save_ckpt","title":"save_ckpt(epoch_idx, ckpt_name, save_opt_state, float_dtype)","text":"

    Saves a checkpoint into {self.workdir}/ckpts.

Parameters:

- epoch_idx (int): The current epoch index. Required.
- ckpt_name (str): The name of the checkpoint. Required.
- save_opt_state (bool): Whether to save the optimizer state. Required.
- float_dtype (`jax.numpy.dtype`): Data type for saving float params. Required.

Source code in redco/trainers/trainer.py
    def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n    \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n    Args:\n        epoch_idx (int): The current epoch index.\n        ckpt_name (str): The name of the checkpoint.\n        save_opt_state (bool): Whether to save the optimizer state.\n        float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n    \"\"\"\n    if self.mesh is None:\n        params = jax.tree.map(\n            fully_replicated_host_local_array_to_global_array,\n            self._state.params)\n    else:\n        params = self._state.params\n\n    opt_state = None\n    if save_opt_state:\n        if self.mesh is None:\n            opt_state = jax.tree.map(\n                fully_replicated_host_local_array_to_global_array,\n                self._state.opt_state)\n        else:\n            opt_state = self._state.opt_state\n\n    ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n    self._deployer.save_ckpt(\n        ckpt_dir=ckpt_dir,\n        params=params,\n        opt_state=opt_state,\n        float_dtype=float_dtype,\n        step=self.step,\n        epoch_idx=epoch_idx)\n\n    if jax.process_index() == 0:\n        open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n        self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n
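For illustration, a manual checkpoint save outside of `fit` might look like the sketch below; `trainer` is a hypothetical placeholder and is assumed to have a Deployer with a non-None workdir.

```python
# Hypothetical sketch: manually saving a checkpoint; it lands in
# {workdir}/ckpts/manual_ckpt as described above.
import jax.numpy as jnp

trainer.save_ckpt(
    epoch_idx=0,
    ckpt_name='manual_ckpt',
    save_opt_state=True,       # also store the optimizer state
    float_dtype=jnp.bfloat16)  # downcast float params when saving
```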
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.set_train_state","title":"set_train_state(apply_fn, params, optimizer, step, opt_state=None)","text":"

    Sets/Resets the training state with given parameters and optimizer.

Parameters:

- apply_fn (Callable): The function to apply the model. Required.
- params (dict): Model parameters. Required.
- optimizer (`optax.optimizer`): The optimizer used for training. Required.
- step (int): The training step. Required.
- opt_state (dict): The state of the optimizer. Default: None.

Source code in redco/trainers/trainer.py
    def set_train_state(\n        self, apply_fn, params, optimizer, step, opt_state=None):\n    \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n    Args:\n        apply_fn (Callable): The function to apply the model.\n        params (dict): Model parameters.\n        optimizer (dict): The optimizer used for training.\n        step (int): The training step.\n        opt_state (dict): The state of the optimizer.\n    \"\"\"\n    self._deployer.log_info('Setting train_state ...')\n    params = freeze(params)\n\n    if self.mesh is None:\n        params = jax.device_put(params, jax.local_devices()[0])\n        if opt_state is None:\n            self._deployer.log_info('Initializing opt_state ...')\n            opt_state = optimizer.init(params)\n        else:\n            opt_state = jax.device_put(opt_state, jax.local_devices()[0])\n\n        self._state = train_state.TrainState(\n            step=step,\n            apply_fn=apply_fn,\n            params=params,\n            tx=optimizer,\n            opt_state=opt_state)\n        self._state = replicate(self._state)\n    else:\n        params_spec = self._deployer.get_params_spec(\n            params_shape_or_params=params,\n            params_sharding_rules=self._params_sharding_rules)\n        params = self._deployer.shard_params(\n            params=params, params_spec=params_spec)\n\n        if opt_state is None:\n            self._deployer.log_info('Initializing opt_state ...')\n            opt_state = optimizer.init(params)\n\n        opt_state_spec = self._deployer.get_opt_state_spec(\n            params_shape_or_params=params,\n            params_spec=params_spec,\n            optimizer=optimizer)\n        opt_state = self._deployer.shard_params(\n            params=opt_state,\n            params_spec=opt_state_spec,\n            desc='opt_state')\n\n        self._state = train_state.TrainState(\n            apply_fn=apply_fn,\n            params=params,\n            tx=optimizer,\n            opt_state=opt_state,\n            step=step)\n\n        self._state_spec = train_state.TrainState(\n            apply_fn=apply_fn,\n            params=params_spec,\n            tx=optimizer,\n            opt_state=opt_state_spec,\n            step=None)\n
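A sketch of resetting the training state from externally loaded parameters; all names other than the documented arguments are hypothetical placeholders (e.g., `loaded_params` might come from a loaded checkpoint).

```python
# Hypothetical sketch: re-initializing the train state after loading params;
# model, optimizer, loaded_params, and loaded_opt_state are placeholders.
trainer.set_train_state(
    apply_fn=model.apply,
    params=loaded_params,
    optimizer=optimizer,
    step=1000,                   # resume the step counter from 1000
    opt_state=loaded_opt_state)  # pass None to re-initialize the opt state
```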
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.setup_running_step","title":"setup_running_step(dummy_batch)","text":"

    Sets up the running step functions for training and evaluation.

Parameters:

- dummy_batch (PyTree): A dummy batch of data. Required.

Source code in redco/trainers/trainer.py
    def setup_running_step(self, dummy_batch):\n    \"\"\"Sets up the running step functions for training and evaluation.\n\n    Args:\n        dummy_batch (PyTree): A dummy batch of data.\n    \"\"\"\n    train_step_fn = partial(\n        train_step,\n        loss_fn=self._loss_fn,\n        lr_schedule_fn=self._lr_schedule_fn,\n        mesh=self.mesh,\n        compute_dtype=self._compute_dtype)\n    eval_step_fn = partial(\n        eval_step,\n        loss_fn=self._loss_fn,\n        mesh=self.mesh,\n        compute_dtype=self._compute_dtype)\n\n    if self.mesh is None:\n        self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n        self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n    else:\n        data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n        self._p_train_step = pjit(\n            train_step_fn,\n            in_shardings=(None, self._state_spec, data_spec),\n            out_shardings=(self._state_spec, None),\n            donate_argnums=(1, ))\n        self._p_eval_step = pjit(\n            eval_step_fn,\n            in_shardings=(self._state_spec, data_spec),\n            out_shardings=None)\n
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.train","title":"train(examples, per_device_batch_size, desc=None)","text":"

    Trains the model on the provided examples.

Parameters:

- examples (list): Training examples in a python list. Required.
- per_device_batch_size (int): The batch size per device. Required.
- desc (str): Description in the progress bar. Default: None.

Source code in redco/trainers/trainer.py
    def train(self, examples, per_device_batch_size, desc=None):\n    \"\"\"Trains the model on the provided examples.\n\n    Args:\n        examples (list): Training examples in python list.\n        per_device_batch_size (int): The batch size per device.\n        desc (str): Description in the progress bar.\n    \"\"\"\n    data_batches = self._deployer.get_model_input_batches(\n        examples=examples,\n        per_device_batch_size=per_device_batch_size,\n        collate_fn=self._collate_fn,\n        shuffle=True,\n        shuffle_rng=self._deployer.gen_rng(),\n        desc=f'Training ({desc})' if desc is not None else 'Training',\n        is_train=True,\n        accumulate_grad_batches=self._accumulate_grad_batches)\n\n    for batch in data_batches:\n        if self._p_train_step is None:\n            self.setup_running_step(dummy_batch=batch)\n\n        train_rng = self._deployer.gen_rng()\n        if self.mesh is None:\n            train_rng = jax.random.split(\n                train_rng, num=jax.process_count())[jax.process_index()]\n            train_rng = shard_prng_key(train_rng)\n        self._state, metrics = self._deployer.run_model_step(\n            step_fn=self._p_train_step,\n            input_args=(train_rng, self._state, batch))\n\n        if self.mesh is None:\n            metrics = unreplicate(metrics)\n        data_batches.set_postfix(**metrics)\n        self._deployer.log_metrics(metrics=metrics, step=self.step)\n
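A minimal call sketch (`trainer` and `train_examples` are hypothetical placeholders); `fit` invokes this method once per epoch.

```python
# Hypothetical sketch: running one pass over the data directly via train()
# instead of using fit(); trainer and train_examples come from elsewhere.
trainer.train(
    examples=train_examples,
    per_device_batch_size=8,
    desc='epoch 0')
```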
    "}]} \ No newline at end of file +{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"RedCoast","text":"

    Red Coast (redco) is a lightweight and user-friendly tool designed to automate distributed training and inference for large models while simplifying the ML pipeline development process without necessitating MLSys expertise from users.

    RedCoast supports Large Models + Complex Algorithms, in a lightweight and user-friendly manner:

With RedCoast, to define an ML pipeline, only three functions are needed:

Redco automates all the remaining pipeline execution, such as data and model parallelism, multi-host related processing, distributed checkpointing, randomness controlling, logging, etc.
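As a rough illustration of this three-function contract, the sketch below shows plausible collate/loss/predict functions for a toy classifier. The function bodies and exact signatures are assumptions for illustration (they mirror the `collate_fn`, `loss_fn`, and `pred_fn` arguments documented for Trainer and Predictor), not the library's required interface.

```python
# Hypothetical sketch of the three user-defined functions a redco pipeline
# is built from; bodies and exact signatures are illustrative assumptions.
import flax.linen as nn
import jax.numpy as jnp
import numpy as np
import optax


class TinyClassifier(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x.reshape((x.shape[0], -1))
        return nn.Dense(10)(nn.relu(nn.Dense(128)(x)))


model = TinyClassifier()


def collate_fn(examples):
    # raw examples -> batched model inputs
    return {
        'images': np.stack([ex['image'] for ex in examples]),
        'labels': np.array([ex['label'] for ex in examples]),
    }


def loss_fn(rng, state, params, batch, is_training):
    # model inputs -> scalar loss (cross-entropy here)
    logits = state.apply_fn({'params': params}, batch['images'])
    return optax.softmax_cross_entropy_with_integer_labels(
        logits, batch['labels']).mean()


def pred_fn(rng, params, batch):
    # model inputs -> predicted class ids
    logits = model.apply({'params': params}, batch['images'])
    return jnp.argmax(logits, axis=-1)
```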

    "},{"location":"deployer/","title":"Deployer","text":""},{"location":"deployer/#redco.deployers.deployer.Deployer","title":"Deployer","text":"

    Handles low-level operations to support Trainer and Predictor, e.g., automatic data/model parallelism, distributed checkpointing, data processing, logging, randomness controlling, etc.

Attributes:

- workdir (str): Working directory for saving checkpoints and logs.
- mesh (`jax.sharding.Mesh`): Mesh used for model sharding.

    Source code in redco/deployers/deployer.py
    class Deployer:\n    \"\"\" Handles low-level operations to support Trainer and Predictor,\n        e.g., automatic data/model parallelism, distributed checkpointing,\n        data processing, logging, randomness controlling, etc.\n\n    Attributes:\n        workdir (str): Working directory for saving checkpoints and logs.\n        mesh (`jax.sharding.Mesh`): Mesh used for model sharding.\n    \"\"\"\n    def __init__(self,\n                 jax_seed,\n                 n_model_shards=1,\n                 verbose=True,\n                 workdir=None,\n                 n_processes=None,\n                 host0_address=None,\n                 host0_port=None,\n                 process_id=None,\n                 n_local_devices=None,\n                 run_tensorboard=False,\n                 wandb_init_kwargs=None):\n        \"\"\" Initializes a Deployer.\n\n        Args:\n            jax_seed (`jax.numpy.Array`): Seed for random number generation.\n            n_model_shards (int): Number of shards for running large model.\n            verbose (bool): Whether to enable verbose logging.\n            workdir (str):  Directory for saving logs and checkpoints.\n            n_processes (int):  For multi-host, number of processes/nodes.\n            host0_address (str):  For multi-host, address of the host0.\n            host0_port (int): For multi-host, port of the host0.\n            process_id (int): For multi-host, index of the current process.\n            n_local_devices (int): For multi-host, number of local devices.\n            run_tensorboard (bool):  Whether to enable TensorBoard logging.\n            wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n        \"\"\"\n        if n_processes is None:\n            if 'SLURM_JOB_NUM_NODES' in os.environ:\n                n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n                process_id = int(os.environ['SLURM_NODEID'])\n            else:\n                n_processes = 1\n\n        if n_processes > 1:\n            local_device_ids = None if n_local_devices is None \\\n                else list(range(n_local_devices))\n\n            if host0_port is None:\n                host0_port = DEFAULT_HOST0_PORT\n\n            jax.distributed.initialize(\n                coordinator_address=f'{host0_address}:{host0_port}',\n                num_processes=n_processes,\n                process_id=process_id,\n                local_device_ids=local_device_ids)\n\n        if workdir is not None:\n            os.makedirs(workdir, exist_ok=True)\n\n        self._verbose = verbose\n        self._workdir = workdir\n        self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n        if wandb_init_kwargs is not None and jax.process_index() == 0:\n            import wandb\n            wandb.init(**wandb_init_kwargs)\n            self._wandb_log_fn = wandb.log\n        else:\n            self._wandb_log_fn = None\n\n        if run_tensorboard and jax.process_index() == 0:\n            from flax.metrics import tensorboard\n            self._summary_writer = tensorboard.SummaryWriter(workdir)\n        else:\n            self._summary_writer = None\n\n        self.log_info(\n            f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n        self._rng = jax.random.PRNGKey(seed=jax_seed)\n        self._mesh = get_mesh(n_model_shards=n_model_shards)\n        self._checkpointer = ocp.PyTreeCheckpointer()\n\n    def get_local_global_micro_batch_size(self, per_device_batch_size):\n        \"\"\"Get 
local/global micro batch sizes based on per-device batch size.\"\"\"\n        if self._mesh is None:\n            local_micro_batch_size = \\\n                per_device_batch_size * jax.local_device_count()\n            global_micro_batch_size = \\\n                local_micro_batch_size * jax.process_count()\n        else:\n            global_micro_batch_size = local_micro_batch_size = \\\n                per_device_batch_size * self._mesh.shape['dp']\n\n        return local_micro_batch_size, global_micro_batch_size\n\n    def get_accumulate_grad_batches(\n            self, global_batch_size, per_device_batch_size):\n        \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n        _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n        assert global_batch_size % global_micro_batch_size == 0\n        accumulate_grad_batches = global_batch_size // global_micro_batch_size\n\n        return accumulate_grad_batches\n\n    def get_model_input_batches(self,\n                                examples,\n                                per_device_batch_size,\n                                collate_fn,\n                                shuffle,\n                                shuffle_rng,\n                                desc,\n                                is_train=False,\n                                accumulate_grad_batches=None):\n        \"\"\"Prepares model input batches from examples.\n\n        Args:\n            examples (list): List of input examples.\n            per_device_batch_size (int): Batch size per device.\n            collate_fn (Callable): Function to collate the examples.\n            shuffle (bool): Whether to shuffle the examples.\n            shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n            desc (str): Description in the progress bar.\n            is_train (bool): Whether the data is for training.\n            accumulate_grad_batches (int): gradient accumulation batches.\n\n        Returns:\n            (generator): A python generator of batched model inputs.\n        \"\"\"\n        local_micro_batch_size, global_micro_batch_size = \\\n            self.get_local_global_micro_batch_size(\n                per_device_batch_size=per_device_batch_size)\n\n        examples = get_host_examples(\n            examples=examples,\n            global_micro_batch_size=global_micro_batch_size,\n            shuffle=shuffle,\n            shuffle_rng=shuffle_rng,\n            mesh=self._mesh)\n\n        if not is_train:\n            desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n        elif accumulate_grad_batches is None:\n            desc = \\\n                f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n        else:\n            desc = (f'{desc} ('\n                    f'global_micro_batch_size = {global_micro_batch_size}, '\n                    f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n        return get_data_batches(\n            examples=examples,\n            batch_size=local_micro_batch_size,\n            collate_fn=collate_fn,\n            mesh=self._mesh,\n            desc=desc,\n            verbose=self._verbose)\n\n    def get_lr_schedule_fn(self,\n                           train_size,\n                           per_device_batch_size,\n                           n_epochs,\n                           learning_rate,\n                           schedule_type='linear',\n                     
      warmup_ratio=0.,\n                           warmup_steps=None,\n                           init_learning_rate=0.,\n                           end_learning_rate=0.):\n        \"\"\"Creates a learning rate schedule function.\n\n        Args:\n            train_size (int): Number of training examples per epoch.\n            per_device_batch_size (int): Batch size per device.\n            n_epochs (int): Number of epochs.\n            learning_rate (float): Peak learning rate.\n            schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n            warmup_ratio (float): Ratio of lr warmup.\n            warmup_steps (int): Number of warmup steps.\n            init_learning_rate (float): Initial learning rate before warmup.\n            end_learning_rate (float): End learning rate for the schedule.\n\n        Returns:\n            (Callable): A lr schedule function, step -> learning rate.\n        \"\"\"\n        _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n        total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n        if warmup_steps is None:\n            warmup_steps = int(total_train_steps * warmup_ratio)\n\n        return get_lr_schedule_fn(\n            schedule_type=schedule_type,\n            total_train_steps=total_train_steps,\n            warmup_steps=warmup_steps,\n            init_learning_rate=init_learning_rate,\n            learning_rate=learning_rate,\n            end_learning_rate=end_learning_rate)\n\n    def get_sharding_rules(self, params_shape_or_params):\n        \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n        if self._mesh is None:\n            return None\n        else:\n            sharding_rules = get_sharding_rules(\n                params_shape_or_params=params_shape_or_params,\n                n_model_shards=self._mesh.shape['mp'])\n            return sharding_rules\n\n    def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n        \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n        return get_params_spec(\n            params_shape_or_params=params_shape_or_params,\n            params_sharding_rules=params_sharding_rules)\n\n    def get_opt_state_spec(\n            self, params_shape_or_params, params_spec, optimizer):\n        \"\"\"Get optimizer state specs\"\"\"\n        return get_opt_state_spec(\n            params_shape_or_params=params_shape_or_params,\n            params_spec=params_spec,\n            optimizer=optimizer)\n\n    def shard_params(self, params, params_spec, desc='params'):\n        \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n        self.log_info(info=f'Sharding {desc} ...')\n        return shard_params(\n            mesh=self._mesh, params=params, params_spec=params_spec)\n\n    def run_model_step(self, step_fn, input_args):\n        \"\"\"Executes a model step function with the provided inputs.\"\"\"\n        if self._mesh is None:\n            return step_fn(*input_args)\n        else:\n            with self._mesh:\n                return step_fn(*input_args)\n\n    def gen_rng(self):\n        \"\"\"Get a new random number generator key and update the random state.\"\"\"\n        self._rng, new_rng = jax.random.split(self._rng)\n        return new_rng\n\n    def log_info(self, info, title=None, step=None):\n        \"\"\"Logs a messages\"\"\"\n        log_info(\n            info=info,\n            
title=title,\n            logger=self._logger,\n            summary_writer=self._summary_writer,\n            step=step)\n\n    def log_metrics(self, metrics, step):\n        \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n        if self._summary_writer is not None:\n            for metric_name, value in metrics.items():\n                self._summary_writer.scalar(metric_name, value, step=step)\n\n        if self._wandb_log_fn is not None:\n            self._wandb_log_fn(metrics, step)\n\n    def save_outputs(self, outputs, desc, step):\n        \"\"\"Saves model outputs to workdir.\"\"\"\n        if self._workdir is not None and jax.process_index() == 0:\n            save_outputs(\n                workdir=self._workdir,\n                outputs=outputs,\n                desc=desc,\n                step=step,\n                logger=self._logger,\n                summary_writer=self._summary_writer)\n\n    def save_ckpt(\n            self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n        \"\"\"Saves a checkpoint to the specified directory.\n\n        Args:\n            ckpt_dir (str): Directory to save the checkpoint.\n            params (dict): Model parameters.\n            opt_state (dict): Optimizer state.\n            float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n            **kwargs (dict): Additional information to be saved into\n                info.json, e.g., current training step, epoch index, etc.\n        \"\"\"\n        ckpt_dir = os.path.abspath(ckpt_dir)\n        self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n        save_ckpt(\n            ckpt_dir=ckpt_dir,\n            checkpointer=self._checkpointer,\n            params=params,\n            opt_state=opt_state,\n            float_dtype=float_dtype,\n            rng=self._rng,\n            **kwargs)\n        self.log_info(f'Ckpt saved into {ckpt_dir}')\n\n    def load_params_shape(self, ckpt_dir):\n        \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n        return load_params_shape(ckpt_dir=ckpt_dir)\n\n    def load_ckpt(self,\n                  ckpt_dir,\n                  params_sharding_rules=None,\n                  optimizer=None,\n                  float_dtype=None,\n                  load_params=True,\n                  load_opt_state=True,\n                  update_rng=False):\n        \"\"\"Loads a checkpoint from the specified directory.\n\n        Args:\n            ckpt_dir (str): Directory of the checkpoint.\n            params_sharding_rules (list[tuple]): Sharding rules for parameters.\n            optimizer (`optax.optimizer`): Optimizer for loading opt_state.\n            float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n            load_params (bool): Whether to load the parameters.\n            load_opt_state (bool): Whether to load the optimizer state.\n            update_rng (bool): if updating the random state of the deployer.\n\n        Returns:\n            (tuple): A tuple with the loaded checkpoint (in a dict with\n                `\"params\"` and `\"opt_state\"`) and additional information (in a\n                dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n        \"\"\"\n        ckpt_dir = os.path.abspath(ckpt_dir)\n        self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n        params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n        specs = {}\n        if self._mesh is not None:\n            if params_sharding_rules is None:\n            
    params_sharding_rules = self.get_sharding_rules(\n                    params_shape_or_params=params_shape)\n\n            specs['params'] = self.get_params_spec(\n                params_shape_or_params=params_shape,\n                params_sharding_rules=params_sharding_rules)\n            if optimizer is not None:\n                specs['opt_state'] = self.get_opt_state_spec(\n                    params_shape_or_params=params_shape,\n                    params_spec=specs['params'],\n                    optimizer=optimizer)\n\n        ckpt, info = load_ckpt(\n            ckpt_dir=ckpt_dir,\n            checkpointer=self._checkpointer,\n            params_shape_or_params=params_shape,\n            optimizer=optimizer,\n            float_dtype=float_dtype,\n            mesh=self._mesh,\n            specs=specs,\n            load_params=load_params,\n            load_opt_state=load_opt_state)\n\n        for key, value in info.items():\n            if not update_rng and key == 'rng':\n                continue\n            self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n        if update_rng:\n            self._rng = info['rng']\n            self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n        return ckpt, info\n\n    def load_last_ckpt(self,\n                       optimizer=None,\n                       params_sharding_rules=None,\n                       float_dtype=None,\n                       load_params=True,\n                       load_opt_state=True,\n                       update_rng=True):\n        \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n        See load_ckpt() for the explanation of arguments.\n        \"\"\"\n        try:\n            last_ckpt_name = open(\n                f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n        except:\n            self.log_info(\n                f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n                f'no ckpt loaded.')\n            return None, None\n\n        return self.load_ckpt(\n            ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n            optimizer=optimizer,\n            float_dtype=float_dtype,\n            params_sharding_rules=params_sharding_rules,\n            load_params=load_params,\n            load_opt_state=load_opt_state,\n            update_rng=update_rng)\n\n    @property\n    def mesh(self):\n        \"\"\"Returns the mesh for model sharding\"\"\"\n        return self._mesh\n\n    @property\n    def workdir(self):\n        \"\"\"Returns the work directory.\"\"\"\n        return self._workdir\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.mesh","title":"mesh property","text":"

    Returns the mesh for model sharding

    "},{"location":"deployer/#redco.deployers.deployer.Deployer.workdir","title":"workdir property","text":"

    Returns the work directory.

    "},{"location":"deployer/#redco.deployers.deployer.Deployer.__init__","title":"__init__(jax_seed, n_model_shards=1, verbose=True, workdir=None, n_processes=None, host0_address=None, host0_port=None, process_id=None, n_local_devices=None, run_tensorboard=False, wandb_init_kwargs=None)","text":"

    Initializes a Deployer.

    Parameters:

    Name Type Description Default jax_seed `jax.numpy.Array`

    Seed for random number generation.

    required n_model_shards int

    Number of shards for running a large model.

    1 verbose bool

    Whether to enable verbose logging.

    True workdir str

    Directory for saving logs and checkpoints.

    None n_processes int

    For multi-host, number of processes/nodes.

    None host0_address str

    For multi-host, address of the host0.

    None host0_port int

    For multi-host, port of the host0.

    None process_id int

    For multi-host, index of the current process.

    None n_local_devices int

    For multi-host, number of local devices.

    None run_tensorboard bool

    Whether to enable TensorBoard logging.

    False wandb_init_kwargs dict

    wandb.init arguments if using wandb.

    None Source code in redco/deployers/deployer.py
    def __init__(self,\n             jax_seed,\n             n_model_shards=1,\n             verbose=True,\n             workdir=None,\n             n_processes=None,\n             host0_address=None,\n             host0_port=None,\n             process_id=None,\n             n_local_devices=None,\n             run_tensorboard=False,\n             wandb_init_kwargs=None):\n    \"\"\" Initializes a Deployer.\n\n    Args:\n        jax_seed (`jax.numpy.Array`): Seed for random number generation.\n        n_model_shards (int): Number of shards for running large model.\n        verbose (bool): Whether to enable verbose logging.\n        workdir (str):  Directory for saving logs and checkpoints.\n        n_processes (int):  For multi-host, number of processes/nodes.\n        host0_address (str):  For multi-host, address of the host0.\n        host0_port (int): For multi-host, port of the host0.\n        process_id (int): For multi-host, index of the current process.\n        n_local_devices (int): For multi-host, number of local devices.\n        run_tensorboard (bool):  Whether to enable TensorBoard logging.\n        wandb_init_kwargs (dict): wandb.init arguments if using wandb.\n    \"\"\"\n    if n_processes is None:\n        if 'SLURM_JOB_NUM_NODES' in os.environ:\n            n_processes = int(os.environ['SLURM_JOB_NUM_NODES'])\n            process_id = int(os.environ['SLURM_NODEID'])\n        else:\n            n_processes = 1\n\n    if n_processes > 1:\n        local_device_ids = None if n_local_devices is None \\\n            else list(range(n_local_devices))\n\n        if host0_port is None:\n            host0_port = DEFAULT_HOST0_PORT\n\n        jax.distributed.initialize(\n            coordinator_address=f'{host0_address}:{host0_port}',\n            num_processes=n_processes,\n            process_id=process_id,\n            local_device_ids=local_device_ids)\n\n    if workdir is not None:\n        os.makedirs(workdir, exist_ok=True)\n\n    self._verbose = verbose\n    self._workdir = workdir\n    self._logger = get_logger(verbose=verbose, workdir=workdir)\n\n    if wandb_init_kwargs is not None and jax.process_index() == 0:\n        import wandb\n        wandb.init(**wandb_init_kwargs)\n        self._wandb_log_fn = wandb.log\n    else:\n        self._wandb_log_fn = None\n\n    if run_tensorboard and jax.process_index() == 0:\n        from flax.metrics import tensorboard\n        self._summary_writer = tensorboard.SummaryWriter(workdir)\n    else:\n        self._summary_writer = None\n\n    self.log_info(\n        f'Local Devices: {jax.local_device_count()} / {jax.device_count()}')\n\n    self._rng = jax.random.PRNGKey(seed=jax_seed)\n    self._mesh = get_mesh(n_model_shards=n_model_shards)\n    self._checkpointer = ocp.PyTreeCheckpointer()\n
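A minimal construction sketch; the seed, shard count, and coordinator address below are illustrative placeholders, not values from the docs.

from redco import Deployer

# Plain data parallelism across all local devices.
deployer = Deployer(jax_seed=42, workdir='./workdir')

# Data + model parallelism: shard the model over 4 devices (illustrative count).
sharded_deployer = Deployer(jax_seed=42, n_model_shards=4, workdir='./workdir')

# Multi-host run: every process passes the address/port of host 0 (placeholders).
# multi_host_deployer = Deployer(
#     jax_seed=42, workdir='./workdir', n_processes=2,
#     host0_address='10.0.0.1', host0_port=11111, process_id=0)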
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.gen_rng","title":"gen_rng()","text":"

    Get a new random number generator key and update the random state.

    Source code in redco/deployers/deployer.py
    def gen_rng(self):\n    \"\"\"Get a new random number generator key and update the random state.\"\"\"\n    self._rng, new_rng = jax.random.split(self._rng)\n    return new_rng\n
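A small sketch, assuming deployer is a Deployer instance: every call to gen_rng() splits the internal key, so successive calls return fresh, independent keys.

init_rng = deployer.gen_rng()     # e.g., for model.init(...)
dropout_rng = deployer.gen_rng()  # a different key on each call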
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_accumulate_grad_batches","title":"get_accumulate_grad_batches(global_batch_size, per_device_batch_size)","text":"

    Calculates the number of gradient accumulation batches.

    Source code in redco/deployers/deployer.py
    def get_accumulate_grad_batches(\n        self, global_batch_size, per_device_batch_size):\n    \"\"\"Calculates the number of gradient accumulation batches.\"\"\"\n    _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n        per_device_batch_size=per_device_batch_size)\n    assert global_batch_size % global_micro_batch_size == 0\n    accumulate_grad_batches = global_batch_size // global_micro_batch_size\n\n    return accumulate_grad_batches\n
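A worked sketch with illustrative numbers, assuming 8 local devices on a single process: with per_device_batch_size=4 the global micro batch size is 32, so a global batch size of 128 requires 4 accumulation steps. Note that global_batch_size must be divisible by the global micro batch size, otherwise the assertion fails.

# 128 // (4 * 8 * 1) == 4 under the assumptions above
accumulate_grad_batches = deployer.get_accumulate_grad_batches(
    global_batch_size=128, per_device_batch_size=4)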
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_local_global_micro_batch_size","title":"get_local_global_micro_batch_size(per_device_batch_size)","text":"

    Get local/global micro batch sizes based on per-device batch size.

    Source code in redco/deployers/deployer.py
    def get_local_global_micro_batch_size(self, per_device_batch_size):\n    \"\"\"Get local/global micro batch sizes based on per-device batch size.\"\"\"\n    if self._mesh is None:\n        local_micro_batch_size = \\\n            per_device_batch_size * jax.local_device_count()\n        global_micro_batch_size = \\\n            local_micro_batch_size * jax.process_count()\n    else:\n        global_micro_batch_size = local_micro_batch_size = \\\n            per_device_batch_size * self._mesh.shape['dp']\n\n    return local_micro_batch_size, global_micro_batch_size\n
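A short sketch, assuming deployer is a Deployer instance; the comments spell out the two cases implemented above.

local_bs, global_bs = deployer.get_local_global_micro_batch_size(
    per_device_batch_size=4)
# Without model sharding: local_bs = 4 * jax.local_device_count(),
#                         global_bs = local_bs * jax.process_count().
# With a model-sharding mesh: local_bs == global_bs == 4 * mesh.shape['dp'].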
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_lr_schedule_fn","title":"get_lr_schedule_fn(train_size, per_device_batch_size, n_epochs, learning_rate, schedule_type='linear', warmup_ratio=0.0, warmup_steps=None, init_learning_rate=0.0, end_learning_rate=0.0)","text":"

    Creates a learning rate schedule function.

    Parameters:

    Name Type Description Default train_size int

    Number of training examples per epoch.

    required per_device_batch_size int

    Batch size per device.

    required n_epochs int

    Number of epochs.

    required learning_rate float

    Peak learning rate.

    required schedule_type str

    Type of lr schedule, \"linear\" or \"cosine\".

    'linear' warmup_ratio float

    Ratio of lr warmup.

    0.0 warmup_steps int

    Number of warmup steps.

    None init_learning_rate float

    Initial learning rate before warmup.

    0.0 end_learning_rate float

    End learning rate for the schedule.

    0.0

    Returns:

    Type Description Callable

    A lr schedule function, step -> learning rate.

    Source code in redco/deployers/deployer.py
    def get_lr_schedule_fn(self,\n                       train_size,\n                       per_device_batch_size,\n                       n_epochs,\n                       learning_rate,\n                       schedule_type='linear',\n                       warmup_ratio=0.,\n                       warmup_steps=None,\n                       init_learning_rate=0.,\n                       end_learning_rate=0.):\n    \"\"\"Creates a learning rate schedule function.\n\n    Args:\n        train_size (int): Number of training examples per epoch.\n        per_device_batch_size (int): Batch size per device.\n        n_epochs (int): Number of epochs.\n        learning_rate (float): Peak learning rate.\n        schedule_type (str): Type of lr schedule, \"linear\" or \"cosine\".\n        warmup_ratio (float): Ratio of lr warmup.\n        warmup_steps (int): Number of warmup steps.\n        init_learning_rate (float): Initial learning rate before warmup.\n        end_learning_rate (float): End learning rate for the schedule.\n\n    Returns:\n        (Callable): A lr schedule function, step -> learning rate.\n    \"\"\"\n    _, global_micro_batch_size = self.get_local_global_micro_batch_size(\n        per_device_batch_size=per_device_batch_size)\n    total_train_steps = n_epochs * (train_size // global_micro_batch_size)\n\n    if warmup_steps is None:\n        warmup_steps = int(total_train_steps * warmup_ratio)\n\n    return get_lr_schedule_fn(\n        schedule_type=schedule_type,\n        total_train_steps=total_train_steps,\n        warmup_steps=warmup_steps,\n        init_learning_rate=init_learning_rate,\n        learning_rate=learning_rate,\n        end_learning_rate=end_learning_rate)\n
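A usage sketch, assuming a training set of 60,000 examples (illustrative); the returned schedule can be handed directly to an optax optimizer and also passed to Trainer via lr_schedule_fn.

import optax

lr_schedule_fn = deployer.get_lr_schedule_fn(
    train_size=60000,          # assumed number of training examples
    per_device_batch_size=64,
    n_epochs=2,
    learning_rate=1e-3,
    schedule_type='linear',
    warmup_ratio=0.1)
optimizer = optax.adamw(learning_rate=lr_schedule_fn)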
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_model_input_batches","title":"get_model_input_batches(examples, per_device_batch_size, collate_fn, shuffle, shuffle_rng, desc, is_train=False, accumulate_grad_batches=None)","text":"

    Prepares model input batches from examples.

    Parameters:

    Name Type Description Default examples list

    List of input examples.

    required per_device_batch_size int

    Batch size per device.

    required collate_fn Callable

    Function to collate the examples.

    required shuffle bool

    Whether to shuffle the examples.

    required shuffle_rng `jax.numpy.Array`

    RNG for randomness of shuffling.

    required desc str

    Description in the progress bar.

    required is_train bool

    Whether the data is for training.

    False accumulate_grad_batches int

    Number of gradient accumulation batches.

    None

    Returns:

    Type Description generator

    A python generator of batched model inputs.

    Source code in redco/deployers/deployer.py
    def get_model_input_batches(self,\n                            examples,\n                            per_device_batch_size,\n                            collate_fn,\n                            shuffle,\n                            shuffle_rng,\n                            desc,\n                            is_train=False,\n                            accumulate_grad_batches=None):\n    \"\"\"Prepares model input batches from examples.\n\n    Args:\n        examples (list): List of input examples.\n        per_device_batch_size (int): Batch size per device.\n        collate_fn (Callable): Function to collate the examples.\n        shuffle (bool): Whether to shuffle the examples.\n        shuffle_rng (`jax.numpy.Array`): RNG for randomness of shuffling.\n        desc (str): Description in the progress bar.\n        is_train (bool): Whether the data is for training.\n        accumulate_grad_batches (int): gradient accumulation batches.\n\n    Returns:\n        (generator): A python generator of batched model inputs.\n    \"\"\"\n    local_micro_batch_size, global_micro_batch_size = \\\n        self.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n\n    examples = get_host_examples(\n        examples=examples,\n        global_micro_batch_size=global_micro_batch_size,\n        shuffle=shuffle,\n        shuffle_rng=shuffle_rng,\n        mesh=self._mesh)\n\n    if not is_train:\n        desc = f'{desc} (global_batch_size = {global_micro_batch_size})'\n    elif accumulate_grad_batches is None:\n        desc = \\\n            f'{desc} (global_micro_batch_size = {global_micro_batch_size})'\n    else:\n        desc = (f'{desc} ('\n                f'global_micro_batch_size = {global_micro_batch_size}, '\n                f'accumulate_grad_batches = {accumulate_grad_batches})')\n\n    return get_data_batches(\n        examples=examples,\n        batch_size=local_micro_batch_size,\n        collate_fn=collate_fn,\n        mesh=self._mesh,\n        desc=desc,\n        verbose=self._verbose)\n
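An iteration sketch, assuming train_examples and collate_fn are defined as in the MNIST example; the returned generator also drives the progress bar described by desc.

batches = deployer.get_model_input_batches(
    examples=train_examples,
    per_device_batch_size=64,
    collate_fn=collate_fn,
    shuffle=True,
    shuffle_rng=deployer.gen_rng(),
    desc='Training',
    is_train=True)
for batch in batches:
    ...  # feed each batch into a train/eval/predict step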
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_opt_state_spec","title":"get_opt_state_spec(params_shape_or_params, params_spec, optimizer)","text":"

    Get optimizer state specs.

    Source code in redco/deployers/deployer.py
    def get_opt_state_spec(\n        self, params_shape_or_params, params_spec, optimizer):\n    \"\"\"Get optimizer state specs\"\"\"\n    return get_opt_state_spec(\n        params_shape_or_params=params_shape_or_params,\n        params_spec=params_spec,\n        optimizer=optimizer)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_params_spec","title":"get_params_spec(params_shape_or_params, params_sharding_rules)","text":"

    Generates parameter specs based on sharding rules.

    Source code in redco/deployers/deployer.py
    def get_params_spec(self, params_shape_or_params, params_sharding_rules):\n    \"\"\"Generates parameter specs based on sharding rules.\"\"\"\n    return get_params_spec(\n        params_shape_or_params=params_shape_or_params,\n        params_sharding_rules=params_sharding_rules)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.get_sharding_rules","title":"get_sharding_rules(params_shape_or_params)","text":"

    Get sharding rules based on the parameter shapes.

    Source code in redco/deployers/deployer.py
    def get_sharding_rules(self, params_shape_or_params):\n    \"\"\"Get sharding rules based on the parameter shapes.\"\"\"\n    if self._mesh is None:\n        return None\n    else:\n        sharding_rules = get_sharding_rules(\n            params_shape_or_params=params_shape_or_params,\n            n_model_shards=self._mesh.shape['mp'])\n        return sharding_rules\n
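A sketch of the typical sharding flow when n_model_shards > 1, assuming params is a Flax parameter dict: infer rules from the parameter shapes, build the parameter specs, then shard.

sharding_rules = deployer.get_sharding_rules(params_shape_or_params=params)
if sharding_rules is not None:  # None when no model-parallel mesh is in use
    params_spec = deployer.get_params_spec(
        params_shape_or_params=params,
        params_sharding_rules=sharding_rules)
    params = deployer.shard_params(params=params, params_spec=params_spec)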
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.load_ckpt","title":"load_ckpt(ckpt_dir, params_sharding_rules=None, optimizer=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=False)","text":"

    Loads a checkpoint from the specified directory.

    Parameters:

    Name Type Description Default ckpt_dir str

    Directory of the checkpoint.

    required params_sharding_rules list[tuple]

    Sharding rules for parameters.

    None optimizer `optax.optimizer`

    Optimizer for loading opt_state.

    None float_dtype `jax.numpy.dtype`

    Dtype for floating point numbers.

    None load_params bool

    Whether to load the parameters.

    True load_opt_state bool

    Whether to load the optimizer state.

    True update_rng bool

    Whether to update the random state of the deployer.

    False

    Returns:

    Type Description tuple

    A tuple with the loaded checkpoint (in a dict with \"params\" and \"opt_state\") and additional information (in a dict, usually including \"steps\", \"epoch_idx\", and \"rng\").

    Source code in redco/deployers/deployer.py
    def load_ckpt(self,\n              ckpt_dir,\n              params_sharding_rules=None,\n              optimizer=None,\n              float_dtype=None,\n              load_params=True,\n              load_opt_state=True,\n              update_rng=False):\n    \"\"\"Loads a checkpoint from the specified directory.\n\n    Args:\n        ckpt_dir (str): Directory of the checkpoint.\n        params_sharding_rules (list[tuple]): Sharding rules for parameters.\n        optimizer (`optax.optimizer`): Optimizer for loading opt_state.\n        float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n        load_params (bool): Whether to load the parameters.\n        load_opt_state (bool): Whether to load the optimizer state.\n        update_rng (bool): if updating the random state of the deployer.\n\n    Returns:\n        (tuple): A tuple with the loaded checkpoint (in a dict with\n            `\"params\"` and `\"opt_state\"`) and additional information (in a\n            dict, usually including `\"steps\"`, `\"epoch_idx\"`, and `\"rng\"`).\n    \"\"\"\n    ckpt_dir = os.path.abspath(ckpt_dir)\n    self.log_info(f'Loading ckpt from {ckpt_dir} ...')\n\n    params_shape = self.load_params_shape(ckpt_dir=ckpt_dir)\n\n    specs = {}\n    if self._mesh is not None:\n        if params_sharding_rules is None:\n            params_sharding_rules = self.get_sharding_rules(\n                params_shape_or_params=params_shape)\n\n        specs['params'] = self.get_params_spec(\n            params_shape_or_params=params_shape,\n            params_sharding_rules=params_sharding_rules)\n        if optimizer is not None:\n            specs['opt_state'] = self.get_opt_state_spec(\n                params_shape_or_params=params_shape,\n                params_spec=specs['params'],\n                optimizer=optimizer)\n\n    ckpt, info = load_ckpt(\n        ckpt_dir=ckpt_dir,\n        checkpointer=self._checkpointer,\n        params_shape_or_params=params_shape,\n        optimizer=optimizer,\n        float_dtype=float_dtype,\n        mesh=self._mesh,\n        specs=specs,\n        load_params=load_params,\n        load_opt_state=load_opt_state)\n\n    for key, value in info.items():\n        if not update_rng and key == 'rng':\n            continue\n        self.log_info(f'{ckpt_dir}::{key} = {value}')\n\n    if update_rng:\n        self._rng = info['rng']\n        self.log_info(f'rng updated to {self._rng} (by {ckpt_dir})')\n\n    return ckpt, info\n
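A loading sketch; the checkpoint path is an illustrative placeholder and optimizer is assumed to be the optax optimizer used at save time (needed only to restore opt_state).

ckpt, info = deployer.load_ckpt(
    ckpt_dir='./workdir/ckpts/ckpt_10000',  # illustrative path
    optimizer=optimizer,
    update_rng=False)
params, opt_state = ckpt['params'], ckpt['opt_state']
# info carries the extra fields saved alongside, e.g., the step and epoch index.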
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.load_last_ckpt","title":"load_last_ckpt(optimizer=None, params_sharding_rules=None, float_dtype=None, load_params=True, load_opt_state=True, update_rng=True)","text":"

    Loads the last checkpoint from the work directory (self.workdir). See load_ckpt() for the explanation of arguments.

    Source code in redco/deployers/deployer.py
    def load_last_ckpt(self,\n                   optimizer=None,\n                   params_sharding_rules=None,\n                   float_dtype=None,\n                   load_params=True,\n                   load_opt_state=True,\n                   update_rng=True):\n    \"\"\"Loads the last checkpoint from the work directory (self.workdir).\n    See load_ckpt() for the explanation of arguments.\n    \"\"\"\n    try:\n        last_ckpt_name = open(\n            f'{self._workdir}/ckpts/last_ckpt.txt').read().strip()\n    except:\n        self.log_info(\n            f'{self._workdir}/ckpts/last_ckpt.txt not found. '\n            f'no ckpt loaded.')\n        return None, None\n\n    return self.load_ckpt(\n        ckpt_dir=f'{self._workdir}/ckpts/{last_ckpt_name}',\n        optimizer=optimizer,\n        float_dtype=float_dtype,\n        params_sharding_rules=params_sharding_rules,\n        load_params=load_params,\n        load_opt_state=load_opt_state,\n        update_rng=update_rng)\n
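A resuming sketch, assuming the Deployer was created with a workdir; when no checkpoint has been recorded yet, (None, None) is returned.

last_ckpt, last_ckpt_info = deployer.load_last_ckpt(optimizer=optimizer)
if last_ckpt is None:
    print('No previous checkpoint found, starting from scratch.')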
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.load_params_shape","title":"load_params_shape(ckpt_dir)","text":"

    Loads the shape of the parameters from a checkpoint.

    Source code in redco/deployers/deployer.py
    def load_params_shape(self, ckpt_dir):\n    \"\"\"Loads the shape of the parameters from a checkpoint.\"\"\"\n    return load_params_shape(ckpt_dir=ckpt_dir)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.log_info","title":"log_info(info, title=None, step=None)","text":"

    Logs a message.

    Source code in redco/deployers/deployer.py
    def log_info(self, info, title=None, step=None):\n    \"\"\"Logs a message.\"\"\"\n    log_info(\n        info=info,\n        title=title,\n        logger=self._logger,\n        summary_writer=self._summary_writer,\n        step=step)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.log_metrics","title":"log_metrics(metrics, step)","text":"

    Logs metrics to TensorBoard and Weights and Biases (wandb).

    Source code in redco/deployers/deployer.py
    def log_metrics(self, metrics, step):\n    \"\"\"Logs metrics to TensorBoard and Weights and Biases (wandb).\"\"\"\n    if self._summary_writer is not None:\n        for metric_name, value in metrics.items():\n            self._summary_writer.scalar(metric_name, value, step=step)\n\n    if self._wandb_log_fn is not None:\n        self._wandb_log_fn(metrics, step)\n
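A one-line sketch with illustrative metric values; metrics are forwarded to whichever of TensorBoard / wandb is enabled.

deployer.log_metrics(metrics={'loss': 0.37, 'lr': 1e-3}, step=100)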
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.run_model_step","title":"run_model_step(step_fn, input_args)","text":"

    Executes a model step function with the provided inputs.

    Source code in redco/deployers/deployer.py
    def run_model_step(self, step_fn, input_args):\n    \"\"\"Executes a model step function with the provided inputs.\"\"\"\n    if self._mesh is None:\n        return step_fn(*input_args)\n    else:\n        with self._mesh:\n            return step_fn(*input_args)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.save_ckpt","title":"save_ckpt(ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs)","text":"

    Saves a checkpoint to the specified directory.

    Parameters:

    Name Type Description Default ckpt_dir str

    Directory to save the checkpoint.

    required params dict

    Model parameters.

    required opt_state dict

    Optimizer state.

    None float_dtype `jax.numpy.dtype`

    Dtype for floating point numbers.

    None **kwargs dict

    Additional information to be saved into info.json, e.g., current training step, epoch index, etc.

    {} Source code in redco/deployers/deployer.py
    def save_ckpt(\n        self, ckpt_dir, params, opt_state=None, float_dtype=None, **kwargs):\n    \"\"\"Saves a checkpoint to the specified directory.\n\n    Args:\n        ckpt_dir (str): Directory to save the checkpoint.\n        params (dict): Model parameters.\n        opt_state (dict): Optimizer state.\n        float_dtype (`jax.numpy.dtype`): Dtype for floating point numbers.\n        **kwargs (dict): Additional information to be saved into\n            info.json, e.g., current training step, epoch index, etc.\n    \"\"\"\n    ckpt_dir = os.path.abspath(ckpt_dir)\n    self.log_info(f'Saving ckpt to {ckpt_dir} ...')\n    save_ckpt(\n        ckpt_dir=ckpt_dir,\n        checkpointer=self._checkpointer,\n        params=params,\n        opt_state=opt_state,\n        float_dtype=float_dtype,\n        rng=self._rng,\n        **kwargs)\n    self.log_info(f'Ckpt saved into {ckpt_dir}')\n
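A saving sketch, assuming params and opt_state come from the current training state; the extra keyword arguments (here step and epoch_idx) are stored in info.json.

import jax.numpy as jnp

deployer.save_ckpt(
    ckpt_dir='./workdir/ckpts/ckpt_10000',  # illustrative directory
    params=params,
    opt_state=opt_state,
    float_dtype=jnp.bfloat16,               # optionally down-cast floats on save
    step=10000,
    epoch_idx=1)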
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.save_outputs","title":"save_outputs(outputs, desc, step)","text":"

    Saves model outputs to workdir.

    Source code in redco/deployers/deployer.py
    def save_outputs(self, outputs, desc, step):\n    \"\"\"Saves model outputs to workdir.\"\"\"\n    if self._workdir is not None and jax.process_index() == 0:\n        save_outputs(\n            workdir=self._workdir,\n            outputs=outputs,\n            desc=desc,\n            step=step,\n            logger=self._logger,\n            summary_writer=self._summary_writer)\n
    "},{"location":"deployer/#redco.deployers.deployer.Deployer.shard_params","title":"shard_params(params, params_spec, desc='params')","text":"

    Distributes parameters to all devices based on the provided specs.

    Source code in redco/deployers/deployer.py
    def shard_params(self, params, params_spec, desc='params'):\n    \"\"\"Distributes parameters to all devices based on the provided specs.\"\"\"\n    self.log_info(info=f'Sharding {desc} ...')\n    return shard_params(\n        mesh=self._mesh, params=params, params_spec=params_spec)\n
    "},{"location":"mnist/","title":"MNIST Example","text":"

    This is a trivial MNIST example with RedCoast. Runnable by

    python main.py\n

    To simulate multiple devices in CPU-only environments,

    XLA_FLAGS=\"--xla_force_host_platform_device_count=8\" python main.py\n

    "},{"location":"mnist/#source-code","title":"Source Code","text":"
    from functools import partial\nimport fire\nimport numpy as np\nfrom flax import linen as nn\nimport optax\nfrom torchvision.datasets import MNIST\nfrom redco import Deployer, Trainer, Predictor\n\n\n# A simple CNN model \n# Copied from https://github.com/google/flax/blob/main/examples/mnist/train.py\nclass CNN(nn.Module):\n    @nn.compact\n    def __call__(self, x):\n        x = nn.Conv(features=32, kernel_size=(3, 3))(x)\n        x = nn.relu(x)\n        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n        x = nn.Conv(features=64, kernel_size=(3, 3))(x)\n        x = nn.relu(x)\n        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n        x = x.reshape((x.shape[0], -1))  # flatten\n        x = nn.Dense(features=256)(x)\n        x = nn.relu(x)\n        x = nn.Dense(features=10)(x)\n        return x\n\n\n# Collate function converting a batch of raw examples to model inputs (in numpy) \ndef collate_fn(examples):\n    images = np.stack(\n        [np.array(example['image'])[:, :, None] for example in examples])\n    labels = np.array([example['label'] for example in examples])\n\n    return {'images': images, 'labels': labels}\n\n\n# Loss function converting model inputs to a scalar loss\ndef loss_fn(train_rng, state, params, batch, is_training):\n    logits = state.apply_fn({'params': params}, batch['images'])\n    return optax.softmax_cross_entropy_with_integer_labels(\n        logits=logits, labels=batch['labels']).mean()\n\n\n# Predict function converting model inputs to the model outputs\ndef pred_fn(pred_rng, params, batch, model):\n    accs = model.apply({'params': params}, batch['images']).argmax(axis=-1)\n    return {'acc': accs}\n\n\n# (Optional) Evaluation function in trainer.fit. Here it computes accuracy.\ndef eval_metric_fn(examples, preds):\n    preds = np.array([pred['acc'] for pred in preds])\n    labels = np.array([example['label'] for example in examples])\n    return {'acc': np.mean(preds == labels).item()}\n\n\ndef main(per_device_batch_size=64, learning_rate=1e-3, jax_seed=42):\n    deployer = Deployer(jax_seed=jax_seed, workdir='./workdir')\n\n    dataset = {\n        'train': [{'image': t[0], 'label': t[1]} for t in list(\n            MNIST('./data', train=True, download=True))],\n        'test': [{'image': t[0], 'label': t[1]} for t in list(\n            MNIST('./data', train=False, download=True))],\n    }\n\n    model = CNN()\n    dummy_batch = collate_fn(examples=[dataset['train'][0]])\n    params = model.init(deployer.gen_rng(), dummy_batch['images'])['params']\n\n    trainer = Trainer(\n        deployer=deployer,\n        collate_fn=collate_fn,\n        apply_fn=model.apply,\n        loss_fn=loss_fn,\n        params=params,\n        optimizer=optax.adamw(learning_rate=learning_rate))\n\n    predictor = Predictor(\n        deployer=deployer,\n        collate_fn=collate_fn,\n        pred_fn=partial(pred_fn, model=model))\n\n    trainer.fit(\n        train_examples=dataset['train'],\n        per_device_batch_size=per_device_batch_size,\n        n_epochs=2,\n        eval_examples=dataset['test'],\n        eval_predictor=predictor,\n        eval_metric_fn=eval_metric_fn)\n\n\nif __name__ == '__main__':\n    fire.Fire(main)\n
    "},{"location":"predictor/","title":"Predictor","text":""},{"location":"predictor/#redco.predictors.predictor.Predictor","title":"Predictor","text":"

    Predictor class managing the distributed inference process.

    Attributes:

    Name Type Description mesh `jax.sharding.Mesh`

    Mesh used for distributed inference.

    Source code in redco/predictors/predictor.py
    class Predictor:\n    \"\"\"Predictor class managing distributed inference process.\n\n    Attributes:\n        mesh (`jax.sharding.Mesh`): Mesh used for distributed inference.\n    \"\"\"\n    def __init__(self,\n                 deployer,\n                 collate_fn,\n                 pred_fn,\n                 output_fn=None,\n                 params_sharding_rules=None):\n        \"\"\"Initializes a Predictor instance.\n\n        Args:\n            deployer (`redco.Deployer`): A deployer for low-level operations.\n            collate_fn (Callable): A function converting a data batch to model inputs,\n                e.g., tokenizing sentences into input_ids.\n            pred_fn (Callable): A function to produce model outputs with model inputs,\n                e.g., running beam search with a language model.\n            output_fn (Callable): A function finalizing model outputs (on CPU),\n                e.g., decoding generated ids to text.\n            params_sharding_rules (list[tuple]): Rules for sharding parameters.\n        \"\"\"\n        self._deployer = deployer\n        self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n        self._params_sharding_rules = params_sharding_rules\n        self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n        self._p_pred_step = None\n\n        if output_fn is None:\n            self._output_fn = default_output_fn\n        else:\n            self._output_fn = output_fn\n\n    def setup_running_step(self, dummy_batch, params_shape_or_params):\n        \"\"\"Sets up the prediction step function for distributed inference.\n\n        Args:\n            dummy_batch (PyTree): A dummy batch used to determine data shapes.\n            params_shape_or_params (dict): The shape of params or actual params.\n        \"\"\"\n        pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n        if self.mesh is None:\n            self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n        else:\n            data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n            params_spec = self._deployer.get_params_spec(\n                params_shape_or_params=params_shape_or_params,\n                params_sharding_rules=self._params_sharding_rules)\n            self._p_pred_step = pjit(\n                pred_step_fn,\n                in_shardings=(None, params_spec, data_spec),\n                out_shardings=None)\n\n    def predict(self,\n                examples,\n                per_device_batch_size,\n                params,\n                params_replicated=False,\n                params_sharded=False,\n                desc=None):\n        \"\"\"Runs distributed prediction on a list of examples.\n\n        Args:\n            examples (list): Input examples for prediction.\n            per_device_batch_size (int): Batch size per device.\n            params (dict): Model parameters in a dict/FrozenDict.\n            params_replicated (bool): if the params are already replicated.\n            params_sharded (bool): if the parameters are already sharded.\n            desc (str): Description to show in the progress bar.\n\n        Returns:\n            (list): A list of predictions corresponding to the input examples.\n        \"\"\"\n        raw_n_inputs = len(examples)\n        _, global_micro_batch_size = \\\n            self._deployer.get_local_global_micro_batch_size(\n                per_device_batch_size=per_device_batch_size)\n        examples = examples + [examples[0]] * 
(global_micro_batch_size - 1)\n        examples = add_idxes(examples=examples)\n\n        data_batches = self._deployer.get_model_input_batches(\n            examples=examples,\n            per_device_batch_size=per_device_batch_size,\n            collate_fn=self._collate_fn,\n            shuffle=False,\n            shuffle_rng=None,\n            desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n        params = freeze(params)\n        if (self.mesh is None) and (not params_replicated):\n            params = replicate(params)\n        if (self.mesh is not None) and (not params_sharded):\n            params_spec = self._deployer.get_params_spec(\n                params_shape_or_params=params,\n                params_sharding_rules=self._params_sharding_rules)\n            params = self._deployer.shard_params(\n                params=params, params_spec=params_spec)\n\n        preds = []\n        for batch in data_batches:\n            if self._p_pred_step is None:\n                self.setup_running_step(\n                    dummy_batch=batch, params_shape_or_params=params)\n\n            pred_rng = self._deployer.gen_rng()\n            if self.mesh is None:\n                pred_rng = jax.random.split(\n                    pred_rng, num=jax.process_count())[jax.process_index()]\n                pred_rng = shard_prng_key(pred_rng)\n\n            batch_preds_with_idxes = self._deployer.run_model_step(\n                step_fn=self._p_pred_step,\n                input_args=(pred_rng, params, batch))\n            batch_preds = process_batch_preds(\n                batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n            batch_preds = self._output_fn(batch_preds)\n\n            assert isinstance(batch_preds, list) and \\\n                   len(batch_preds) == global_micro_batch_size\n            preds.extend(batch_preds)\n\n        return preds[:raw_n_inputs]\n\n    @property\n    def mesh(self):\n        \"\"\"Returns the mesh used for distributed inference.\"\"\"\n        return self._deployer.mesh\n
    "},{"location":"predictor/#redco.predictors.predictor.Predictor.mesh","title":"mesh property","text":"

    Returns the mesh used for distributed inference.

    "},{"location":"predictor/#redco.predictors.predictor.Predictor.__init__","title":"__init__(deployer, collate_fn, pred_fn, output_fn=None, params_sharding_rules=None)","text":"

    Initializes a Predictor instance.

    Parameters:

    Name Type Description Default deployer `redco.Deployer`

    A deployer for low-level operations.

    required collate_fn Callable

    A function converting a data batch to model inputs, e.g., tokenizing sentences into input_ids.

    required pred_fn Callable

    A function to produce model outputs with model inputs, e.g., running beam search with a language model.

    required output_fn Callable

    A function finalizing model outputs (on CPU), e.g., decoding generated ids to text.

    None params_sharding_rules list[tuple]

    Rules for sharding parameters.

    None Source code in redco/predictors/predictor.py
    def __init__(self,\n             deployer,\n             collate_fn,\n             pred_fn,\n             output_fn=None,\n             params_sharding_rules=None):\n    \"\"\"Initializes a Predictor instance.\n\n    Args:\n        deployer (`redco.Deployer`): A deployer for low-level operations.\n        collate_fn (Callable): A function converting a data batch to model inputs,\n            e.g., tokenizing sentences into input_ids.\n        pred_fn (Callable): A function to produce model outputs with model inputs,\n            e.g., running beam search with a language model.\n        output_fn (Callable): A function finalizing model outputs (on CPU),\n            e.g., decoding generated ids to text.\n        params_sharding_rules (list[tuple]): Rules for sharding parameters.\n    \"\"\"\n    self._deployer = deployer\n    self._collate_fn = partial(collate_fn_wrapper, collate_fn=collate_fn)\n    self._params_sharding_rules = params_sharding_rules\n    self._pred_fn = partial(pred_fn_wrapper, pred_fn=pred_fn)\n    self._p_pred_step = None\n\n    if output_fn is None:\n        self._output_fn = default_output_fn\n    else:\n        self._output_fn = output_fn\n
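A construction sketch, assuming collate_fn, pred_fn, and model follow the MNIST example above (pred_fn takes pred_rng, params, batch, and the bound model).

from functools import partial
from redco import Predictor

predictor = Predictor(
    deployer=deployer,
    collate_fn=collate_fn,
    pred_fn=partial(pred_fn, model=model))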
    "},{"location":"predictor/#redco.predictors.predictor.Predictor.predict","title":"predict(examples, per_device_batch_size, params, params_replicated=False, params_sharded=False, desc=None)","text":"

    Runs distributed prediction on a list of examples.

    Parameters:

    Name Type Description Default examples list

    Input examples for prediction.

    required per_device_batch_size int

    Batch size per device.

    required params dict

    Model parameters in a dict/FrozenDict.

    required params_replicated bool

    Whether the params are already replicated.

    False params_sharded bool

    Whether the parameters are already sharded.

    False desc str

    Description to show in the progress bar.

    None

    Returns:

    Type Description list

    A list of predictions corresponding to the input examples.

    Source code in redco/predictors/predictor.py
    def predict(self,\n            examples,\n            per_device_batch_size,\n            params,\n            params_replicated=False,\n            params_sharded=False,\n            desc=None):\n    \"\"\"Runs distributed prediction on a list of examples.\n\n    Args:\n        examples (list): Input examples for prediction.\n        per_device_batch_size (int): Batch size per device.\n        params (dict): Model parameters in a dict/FrozenDict.\n        params_replicated (bool): if the params are already replicated.\n        params_sharded (bool): if the parameters are already sharded.\n        desc (str): Description to show in the progress bar.\n\n    Returns:\n        (list): A list of predictions corresponding to the input examples.\n    \"\"\"\n    raw_n_inputs = len(examples)\n    _, global_micro_batch_size = \\\n        self._deployer.get_local_global_micro_batch_size(\n            per_device_batch_size=per_device_batch_size)\n    examples = examples + [examples[0]] * (global_micro_batch_size - 1)\n    examples = add_idxes(examples=examples)\n\n    data_batches = self._deployer.get_model_input_batches(\n        examples=examples,\n        per_device_batch_size=per_device_batch_size,\n        collate_fn=self._collate_fn,\n        shuffle=False,\n        shuffle_rng=None,\n        desc=f'Predicting ({desc})' if desc is not None else 'Predicting')\n\n    params = freeze(params)\n    if (self.mesh is None) and (not params_replicated):\n        params = replicate(params)\n    if (self.mesh is not None) and (not params_sharded):\n        params_spec = self._deployer.get_params_spec(\n            params_shape_or_params=params,\n            params_sharding_rules=self._params_sharding_rules)\n        params = self._deployer.shard_params(\n            params=params, params_spec=params_spec)\n\n    preds = []\n    for batch in data_batches:\n        if self._p_pred_step is None:\n            self.setup_running_step(\n                dummy_batch=batch, params_shape_or_params=params)\n\n        pred_rng = self._deployer.gen_rng()\n        if self.mesh is None:\n            pred_rng = jax.random.split(\n                pred_rng, num=jax.process_count())[jax.process_index()]\n            pred_rng = shard_prng_key(pred_rng)\n\n        batch_preds_with_idxes = self._deployer.run_model_step(\n            step_fn=self._p_pred_step,\n            input_args=(pred_rng, params, batch))\n        batch_preds = process_batch_preds(\n            batch_preds_with_idxes=batch_preds_with_idxes, mesh=self.mesh)\n        batch_preds = self._output_fn(batch_preds)\n\n        assert isinstance(batch_preds, list) and \\\n               len(batch_preds) == global_micro_batch_size\n        preds.extend(batch_preds)\n\n    return preds[:raw_n_inputs]\n
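A prediction sketch, assuming test_examples is a list of raw examples and params is a trained parameter dict; inputs are internally padded to a full batch, and the returned list is trimmed back and aligned with the inputs.

preds = predictor.predict(
    examples=test_examples,
    per_device_batch_size=64,
    params=params,
    desc='test')
assert len(preds) == len(test_examples)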
    "},{"location":"predictor/#redco.predictors.predictor.Predictor.setup_running_step","title":"setup_running_step(dummy_batch, params_shape_or_params)","text":"

    Sets up the prediction step function for distributed inference.

    Parameters:

    Name Type Description Default dummy_batch PyTree

    A dummy batch used to determine data shapes.

    required params_shape_or_params dict

    The shape of params or actual params.

    required Source code in redco/predictors/predictor.py
    def setup_running_step(self, dummy_batch, params_shape_or_params):\n    \"\"\"Sets up the prediction step function for distributed inference.\n\n    Args:\n        dummy_batch (PyTree): A dummy batch used to determine data shapes.\n        params_shape_or_params (dict): The shape of params or actual params.\n    \"\"\"\n    pred_step_fn = partial(pred_step, pred_fn=self._pred_fn, mesh=self.mesh)\n\n    if self.mesh is None:\n        self._p_pred_step = jax.pmap(pred_step_fn, axis_name='dp')\n    else:\n        data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n        params_spec = self._deployer.get_params_spec(\n            params_shape_or_params=params_shape_or_params,\n            params_sharding_rules=self._params_sharding_rules)\n        self._p_pred_step = pjit(\n            pred_step_fn,\n            in_shardings=(None, params_spec, data_spec),\n            out_shardings=None)\n
    "},{"location":"trainer/","title":"Trainer","text":""},{"location":"trainer/#redco.trainers.trainer.Trainer","title":"Trainer","text":"

    Trainer class managing the distributed training process.

    Attributes:

    Name Type Description step int

    Current training step.

    workdir str

    Working directory for saving checkpoints and logs.

    mesh jax Mesh

    Mesh used for distributed training.

    state flax TrainState

    Current training state.

    Source code in redco/trainers/trainer.py
    class Trainer:\n    \"\"\"Trainer class managing distributed training process.\n\n    Attributes:\n        step (int): Current training step.\n        workdir (str): Working directory for saving checkpoints and logs.\n        mesh (jax Mesh): Mesh used for distributed training.\n        state (flax TrainState): Current training state.\n    \"\"\"\n    def __init__(self,\n                 deployer,\n                 collate_fn,\n                 apply_fn,\n                 loss_fn,\n                 params,\n                 optimizer,\n                 opt_state=None,\n                 compute_dtype=jnp.float32,\n                 last_ckpt_info=None,\n                 lr_schedule_fn=None,\n                 accumulate_grad_batches=None,\n                 params_sharding_rules=None):\n        \"\"\"Initializes the Trainer with initial parameters, etc.\n\n        Args:\n            deployer (Deployer): A deployer supporting low-level operations.\n            collate_fn (Callable): The function converting a data batch to model\n                inputs, e.g., tokenizing sentences into input_ids.\n            apply_fn (Callable): The function to apply the model, such as\n                model.apply for Flax modules, or model itself for HuggingFace\n                models. It would be set as state.apply_fn, and used in loss_fn.\n            loss_fn (Callable): The loss function converting model inputs to a\n                scalar loss, e.g., computing cross-entropy loss from input_ids.\n            params (dict): Initial model parameters.\n            optimizer (`optax.optimizer`): The optimizer used for training.\n            opt_state (dict): optimizer state.\n            compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n                independent of param dtypes. 
(for mixed-precision training)\n            last_ckpt_info (dict): the beginning step and epoch.\n            lr_schedule_fn (Callable): The learning rate schedule\n                function converting step to learning rate.\n            accumulate_grad_batches (int): Gradient accumulation step.\n            params_sharding_rules (list): Sharding rules.\n        \"\"\"\n        self._deployer = deployer\n        self._collate_fn = collate_fn\n        self._apply_fn = apply_fn\n        self._loss_fn = loss_fn\n        self._optimizer = optimizer\n        self._compute_dtype = compute_dtype\n        self._lr_schedule_fn = lr_schedule_fn\n        self._accumulate_grad_batches = accumulate_grad_batches\n        self._params_sharding_rules = params_sharding_rules\n\n        self._state = None\n        self._state_spec = None\n        self._p_train_step = None\n        self._p_eval_step = None\n\n        self._init_step = 0\n        self._init_epoch_idx = 0\n        if last_ckpt_info is not None:\n            self._init_step = last_ckpt_info.get('step', 0)\n            self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n        n_params = sum([param.size for param in jax.tree.leaves(params)])\n        self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n        self.set_train_state(\n            apply_fn=self._apply_fn,\n            params=params,\n            optimizer=self._optimizer,\n            step=self._init_step,\n            opt_state=opt_state)\n\n    def set_train_state(\n            self, apply_fn, params, optimizer, step, opt_state=None):\n        \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n        Args:\n            apply_fn (Callable): The function to apply the model.\n            params (dict): Model parameters.\n            optimizer (dict): The optimizer used for training.\n            step (int): The training step.\n            opt_state (dict): The state of the optimizer.\n        \"\"\"\n        self._deployer.log_info('Setting train_state ...')\n        params = freeze(params)\n\n        if self.mesh is None:\n            params = jax.device_put(params, jax.local_devices()[0])\n            if opt_state is None:\n                self._deployer.log_info('Initializing opt_state ...')\n                opt_state = optimizer.init(params)\n            else:\n                opt_state = jax.device_put(opt_state, jax.local_devices()[0])\n\n            self._state = train_state.TrainState(\n                step=step,\n                apply_fn=apply_fn,\n                params=params,\n                tx=optimizer,\n                opt_state=opt_state)\n            self._state = replicate(self._state)\n        else:\n            params_spec = self._deployer.get_params_spec(\n                params_shape_or_params=params,\n                params_sharding_rules=self._params_sharding_rules)\n            params = self._deployer.shard_params(\n                params=params, params_spec=params_spec)\n\n            if opt_state is None:\n                self._deployer.log_info('Initializing opt_state ...')\n                opt_state = optimizer.init(params)\n\n            opt_state_spec = self._deployer.get_opt_state_spec(\n                params_shape_or_params=params,\n                params_spec=params_spec,\n                optimizer=optimizer)\n            opt_state = self._deployer.shard_params(\n                params=opt_state,\n                params_spec=opt_state_spec,\n                desc='opt_state')\n\n            
self._state = train_state.TrainState(\n                apply_fn=apply_fn,\n                params=params,\n                tx=optimizer,\n                opt_state=opt_state,\n                step=step)\n\n            self._state_spec = train_state.TrainState(\n                apply_fn=apply_fn,\n                params=params_spec,\n                tx=optimizer,\n                opt_state=opt_state_spec,\n                step=None)\n\n    def setup_running_step(self, dummy_batch):\n        \"\"\"Sets up the running step functions for training and evaluation.\n\n        Args:\n            dummy_batch (PyTree): A dummy batch of data.\n        \"\"\"\n        train_step_fn = partial(\n            train_step,\n            loss_fn=self._loss_fn,\n            lr_schedule_fn=self._lr_schedule_fn,\n            mesh=self.mesh,\n            compute_dtype=self._compute_dtype)\n        eval_step_fn = partial(\n            eval_step,\n            loss_fn=self._loss_fn,\n            mesh=self.mesh,\n            compute_dtype=self._compute_dtype)\n\n        if self.mesh is None:\n            self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n            self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n        else:\n            data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n            self._p_train_step = pjit(\n                train_step_fn,\n                in_shardings=(None, self._state_spec, data_spec),\n                out_shardings=(self._state_spec, None),\n                donate_argnums=(1, ))\n            self._p_eval_step = pjit(\n                eval_step_fn,\n                in_shardings=(self._state_spec, data_spec),\n                out_shardings=None)\n\n    def train(self, examples, per_device_batch_size, desc=None):\n        \"\"\"Trains the model on the provided examples.\n\n        Args:\n            examples (list): Training examples in python list.\n            per_device_batch_size (int): The batch size per device.\n            desc (str): Description in the progress bar.\n        \"\"\"\n        data_batches = self._deployer.get_model_input_batches(\n            examples=examples,\n            per_device_batch_size=per_device_batch_size,\n            collate_fn=self._collate_fn,\n            shuffle=True,\n            shuffle_rng=self._deployer.gen_rng(),\n            desc=f'Training ({desc})' if desc is not None else 'Training',\n            is_train=True,\n            accumulate_grad_batches=self._accumulate_grad_batches)\n\n        for batch in data_batches:\n            if self._p_train_step is None:\n                self.setup_running_step(dummy_batch=batch)\n\n            train_rng = self._deployer.gen_rng()\n            if self.mesh is None:\n                train_rng = jax.random.split(\n                    train_rng, num=jax.process_count())[jax.process_index()]\n                train_rng = shard_prng_key(train_rng)\n            self._state, metrics = self._deployer.run_model_step(\n                step_fn=self._p_train_step,\n                input_args=(train_rng, self._state, batch))\n\n            if self.mesh is None:\n                metrics = unreplicate(metrics)\n            data_batches.set_postfix(**metrics)\n            self._deployer.log_metrics(metrics=metrics, step=self.step)\n\n    def eval_loss(self, examples, per_device_batch_size, desc=None):\n        \"\"\"Evaluates the loss on the provided examples.\n\n        Args:\n            examples (list): Evaluation examples in list.\n            per_device_batch_size (int): The 
batch size per device.\n            desc (str): Description in the progress bar.\n\n        Returns:\n            (float): The average loss over the evaluation examples.\n        \"\"\"\n        data_batches = self._deployer.get_model_input_batches(\n            examples=examples,\n            per_device_batch_size=per_device_batch_size,\n            collate_fn=self._collate_fn,\n            shuffle=False,\n            shuffle_rng=None,\n            desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n        losses = []\n        for batch in data_batches:\n            if self._p_eval_step is None:\n                self.setup_running_step(dummy_batch=batch)\n\n            metrics = self._deployer.run_model_step(\n                step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n            if self.mesh is None:\n                metrics = unreplicate(metrics)\n\n            losses.append(metrics['loss'].item())\n            data_batches.set_postfix(**metrics)\n\n        return np.mean(losses).item()\n\n    def fit(self,\n            train_examples,\n            per_device_batch_size,\n            n_epochs,\n            eval_examples=None,\n            eval_per_device_batch_size=None,\n            eval_loss=True,\n            eval_predictor=None,\n            eval_metric_fn=None,\n            eval_sanity_check=True,\n            save_every_ckpt=False,\n            save_last_ckpt=False,\n            save_argmin_ckpt_by_metrics=None,\n            save_argmax_ckpt_by_metrics=None,\n            save_opt_states=True,\n            save_float_dtype=None):\n        \"\"\"Fits the model on the training data for a given number of epochs,\n        optionally evaluating and saving checkpoints.\n\n        Args:\n            train_examples (list or Callable): Training examples, can be a\n                list or a function of epoch_idx (for assigning different\n                examples in separate epochs/chunks),\n                e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n            per_device_batch_size (int): The batch size per device.\n            n_epochs (int): Number of epochs to train.\n            eval_examples (list): Examples for evaluation and prediction.\n            eval_per_device_batch_size (int): Batch size for evaluation\n            eval_loss (bool): Whether to evaluate loss.\n            eval_predictor (`redco.Predictor`): Predicting on `eval_examples`.\n            eval_metric_fn (Callable): Metric function for prediction.\n            eval_sanity_check (bool): if to run a sanity check for\n                evaluation & predict functions before training.\n            save_every_ckpt (bool): if to save a ckpt after every epoch.\n            save_last_ckpt (bool): Whether to save the last checkpoint.\n            save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n                based on minimum values.\n            save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n                based on maximum values.\n            save_opt_states (bool): of to save optimizer states in ckpts.\n            save_float_dtype (bool): The data type for saving checkpoints.\n        \"\"\"\n        if eval_per_device_batch_size is None:\n            eval_per_device_batch_size = per_device_batch_size\n\n        if save_argmax_ckpt_by_metrics is None:\n            save_argmax_ckpt_by_metrics = []\n        if save_argmin_ckpt_by_metrics is None:\n            save_argmin_ckpt_by_metrics = []\n        min_metrics, max_metrics 
= {}, {}\n\n        if os.path.exists(f'{self.workdir}/min_metrics.json'):\n            min_metrics = json.load(open(\n                f'{self.workdir}/min_metrics.json'))\n            self._deployer.log_info(\n                json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n        if os.path.exists(f'{self.workdir}/max_metrics.json'):\n            max_metrics = json.load(open(\n                f'{self.workdir}/max_metrics.json'))\n            self._deployer.log_info(\n                json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n        if eval_sanity_check and eval_examples is not None:\n            rng_backup = self._deployer._rng\n            _, eval_global_micro_batch_size = \\\n                self._deployer.get_local_global_micro_batch_size(\n                    per_device_batch_size=eval_per_device_batch_size)\n\n            if eval_loss:\n                self.eval_loss(\n                    examples=eval_examples[:eval_global_micro_batch_size],\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'Sanity check')\n                self._deployer.log_info(\n                    'Sanity check (for evaluation loss) passed.')\n\n            if eval_predictor is not None:\n                preds = eval_predictor.predict(\n                    examples=eval_examples[:eval_global_micro_batch_size],\n                    params=self._state.params,\n                    params_replicated=(self.mesh is None),\n                    params_sharded=(self.mesh is not None),\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'Sanity check')\n                self._deployer.log_info(\n                    'Sanity check (for prediction) passed.')\n\n                if eval_metric_fn is not None:\n                    json.dumps(eval_metric_fn(\n                        examples=eval_examples[:eval_global_micro_batch_size],\n                        preds=preds))\n                    self._deployer.log_info(\n                        'Sanity check (for evaluation metrics) passed.')\n\n            self._deployer._rng = rng_backup\n\n        for epoch_idx in range(self._init_epoch_idx, n_epochs):\n            if isinstance(train_examples, list):\n                epoch_train_examples = train_examples\n            else:\n                epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n            self.train(\n                examples=epoch_train_examples,\n                per_device_batch_size=per_device_batch_size,\n                desc=f'epoch {epoch_idx} / {n_epochs}')\n\n            save_ckpt_kwargs = {\n                'epoch_idx': epoch_idx,\n                'save_opt_state': save_opt_states,\n                'float_dtype': save_float_dtype\n            }\n\n            if eval_examples is None:\n                self._deployer.log_info(\n                    'No evaluation cuz \\'eval_examples\\' is None.')\n            else:\n                eval_metrics = {}\n\n                if eval_loss:\n                    loss = self.eval_loss(\n                        examples=eval_examples,\n                        per_device_batch_size=eval_per_device_batch_size,\n                        desc=f'epoch {epoch_idx} / {n_epochs}')\n                    eval_metrics['loss'] = loss\n\n                if eval_predictor is not None:\n                    preds = eval_predictor.predict(\n                        examples=eval_examples,\n                        
params=self._state.params,\n                        params_replicated=(self.mesh is None),\n                        params_sharded=(self.mesh is not None),\n                        per_device_batch_size=eval_per_device_batch_size,\n                        desc=f'epoch {epoch_idx} / {n_epochs}')\n\n                    if eval_metric_fn is not None:\n                        eval_metrics.update(eval_metric_fn(\n                            examples=eval_examples, preds=preds))\n\n                    eval_outputs = [\n                        {'example': example, 'pred': pred}\n                        for example, pred in zip(eval_examples, preds)]\n\n                    self._deployer.save_outputs(\n                        outputs=eval_outputs,\n                        desc=f'epoch{epoch_idx}',\n                        step=self.step)\n\n                self._deployer.log_info(\n                    info=json.dumps(eval_metrics, indent=4),\n                    title=f'Eval results',\n                    step=self.step)\n                self._deployer.log_metrics(metrics={\n                    f'eval_{key}': value\n                    for key, value in eval_metrics.items()\n                }, step=self.step)\n\n                if self.workdir is not None:\n                    result_filepath = \\\n                        f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n                    json.dump(\n                        eval_metrics, open(result_filepath, 'w'), indent=4)\n                    self._deployer.log_info(\n                        f'eval_results saved into {result_filepath}.')\n\n                for key in save_argmin_ckpt_by_metrics:\n                    assert self.workdir is not None\n                    if eval_metrics[key] < min_metrics.get(key, float('inf')):\n                        min_metrics[key] = eval_metrics[key]\n\n                        if jax.process_index() == 0:\n                            self._deployer.log_info(\n                                f'minimal {key} updated to {min_metrics[key]}')\n                            json.dump(min_metrics, open(\n                                f'{self.workdir}/min_metrics.json', 'w'))\n\n                        self.save_ckpt(\n                            ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n                for key in save_argmax_ckpt_by_metrics:\n                    assert self.workdir is not None\n                    if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n                        max_metrics[key] = eval_metrics[key]\n\n                        if jax.process_index() == 0:\n                            self._deployer.log_info(\n                                f'maximal {key} updated to {max_metrics[key]}')\n                            json.dump(max_metrics, open(\n                                f'{self.workdir}/max_metrics.json', 'w'))\n\n                        self.save_ckpt(\n                            ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n            if save_every_ckpt:\n                self.save_ckpt(\n                    ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n            elif save_last_ckpt:\n                self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n\n    def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n        \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n        Args:\n            epoch_idx (int): The current epoch index.\n            ckpt_name (str): The name of the checkpoint.\n            save_opt_state 
(bool): Whether to save the optimizer state.\n            float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n        \"\"\"\n        if self.mesh is None:\n            params = jax.tree.map(\n                fully_replicated_host_local_array_to_global_array,\n                self._state.params)\n        else:\n            params = self._state.params\n\n        opt_state = None\n        if save_opt_state:\n            if self.mesh is None:\n                opt_state = jax.tree.map(\n                    fully_replicated_host_local_array_to_global_array,\n                    self._state.opt_state)\n            else:\n                opt_state = self._state.opt_state\n\n        ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n        self._deployer.save_ckpt(\n            ckpt_dir=ckpt_dir,\n            params=params,\n            opt_state=opt_state,\n            float_dtype=float_dtype,\n            step=self.step,\n            epoch_idx=epoch_idx)\n\n        if jax.process_index() == 0:\n            open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n            self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n\n    @property\n    def step(self):\n        \"\"\"Returns the current training step.\"\"\"\n        if self.mesh is None:\n            return unreplicate(self._state.step).item()\n        else:\n            return self._state.step.item()\n\n    @property\n    def workdir(self):\n        \"\"\"Returns the working directory for saving checkpoints and logs.\"\"\"\n        return self._deployer.workdir\n\n    @property\n    def mesh(self):\n        \"\"\"Returns the mesh used for distributed training.\"\"\"\n        return self._deployer.mesh\n\n    @property\n    def state(self):\n        \"\"\"Returns the current training state.\"\"\"\n        return self._state\n
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.mesh","title":"mesh property","text":"

    Returns the mesh used for distributed training.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.state","title":"state property","text":"

    Returns the current training state.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.step","title":"step property","text":"

    Returns the current training step.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.workdir","title":"workdir property","text":"

    Returns the working directory for saving checkpoints and logs.

    "},{"location":"trainer/#redco.trainers.trainer.Trainer.__init__","title":"__init__(deployer, collate_fn, apply_fn, loss_fn, params, optimizer, opt_state=None, compute_dtype=jnp.float32, last_ckpt_info=None, lr_schedule_fn=None, accumulate_grad_batches=None, params_sharding_rules=None)","text":"

    Initializes the Trainer with initial parameters, etc.

    Parameters:

    Name Type Description Default deployer Deployer

    A deployer supporting low-level operations.

    required collate_fn Callable

    The function converting a data batch to model inputs, e.g., tokenizing sentences into input_ids.

    required apply_fn Callable

    The function to apply the model, such as model.apply for Flax modules, or model itself for HuggingFace models. It would be set as state.apply_fn, and used in loss_fn.

    required loss_fn Callable

    The loss function converting model inputs to a scalar loss, e.g., computing cross-entropy loss from input_ids.

    required params dict

    Initial model parameters.

    required optimizer `optax.optimizer`

    The optimizer used for training.

    required opt_state dict

    Optimizer state; initialized from the optimizer if not given.

    None compute_dtype dtype

    Computation dtype, e.g., jnp.bfloat16, independent of param dtypes (for mixed-precision training).

    float32 last_ckpt_info dict

    The beginning step and epoch, e.g., restored from the last checkpoint.

    None lr_schedule_fn Callable

    The learning rate schedule function converting step to learning rate.

    None accumulate_grad_batches int

    Number of batches over which to accumulate gradients before each update.

    None params_sharding_rules list

    Sharding rules for model parameters.

    None Source code in redco/trainers/trainer.py
    def __init__(self,\n             deployer,\n             collate_fn,\n             apply_fn,\n             loss_fn,\n             params,\n             optimizer,\n             opt_state=None,\n             compute_dtype=jnp.float32,\n             last_ckpt_info=None,\n             lr_schedule_fn=None,\n             accumulate_grad_batches=None,\n             params_sharding_rules=None):\n    \"\"\"Initializes the Trainer with initial parameters, etc.\n\n    Args:\n        deployer (Deployer): A deployer supporting low-level operations.\n        collate_fn (Callable): The function converting a data batch to model\n            inputs, e.g., tokenizing sentences into input_ids.\n        apply_fn (Callable): The function to apply the model, such as\n            model.apply for Flax modules, or model itself for HuggingFace\n            models. It would be set as state.apply_fn, and used in loss_fn.\n        loss_fn (Callable): The loss function converting model inputs to a\n            scalar loss, e.g., computing cross-entropy loss from input_ids.\n        params (dict): Initial model parameters.\n        optimizer (`optax.optimizer`): The optimizer used for training.\n        opt_state (dict): optimizer state.\n        compute_dtype (dtype): Computation dtype, e.g., `jnp.bfloat16`,\n            independent of param dtypes. (for mixed-precision training)\n        last_ckpt_info (dict): the beginning step and epoch.\n        lr_schedule_fn (Callable): The learning rate schedule\n            function converting step to learning rate.\n        accumulate_grad_batches (int): Gradient accumulation step.\n        params_sharding_rules (list): Sharding rules.\n    \"\"\"\n    self._deployer = deployer\n    self._collate_fn = collate_fn\n    self._apply_fn = apply_fn\n    self._loss_fn = loss_fn\n    self._optimizer = optimizer\n    self._compute_dtype = compute_dtype\n    self._lr_schedule_fn = lr_schedule_fn\n    self._accumulate_grad_batches = accumulate_grad_batches\n    self._params_sharding_rules = params_sharding_rules\n\n    self._state = None\n    self._state_spec = None\n    self._p_train_step = None\n    self._p_eval_step = None\n\n    self._init_step = 0\n    self._init_epoch_idx = 0\n    if last_ckpt_info is not None:\n        self._init_step = last_ckpt_info.get('step', 0)\n        self._init_epoch_idx = last_ckpt_info.get('epoch_idx', -1) + 1\n\n    n_params = sum([param.size for param in jax.tree.leaves(params)])\n    self._deployer.log_info(f'{n_params:,}', title='Parameters')\n\n    self.set_train_state(\n        apply_fn=self._apply_fn,\n        params=params,\n        optimizer=self._optimizer,\n        step=self._init_step,\n        opt_state=opt_state)\n
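For orientation, a minimal construction sketch (not part of the generated reference). It assumes top-level exports `Deployer` and `Trainer` from the `redco` package, a small Flax module, an optax optimizer, and user-defined `collate_fn`/`loss_fn`; the `loss_fn` signature `(rng, state, params, batch, is_training)` follows the convention used in RedCoast examples and should be adapted to your setup. `MLP` is an illustrative placeholder.

```python
# Hypothetical sketch: wiring up a Trainer from the documented arguments.
# `MLP`, `collate_fn`, and `loss_fn` are placeholder user code, not RedCoast APIs.
import jax
import flax.linen as nn
import numpy as np
import optax
from redco import Deployer, Trainer


class MLP(nn.Module):  # placeholder model
    @nn.compact
    def __call__(self, x):
        return nn.Dense(2)(nn.relu(nn.Dense(32)(x)))


def collate_fn(examples):
    # Convert a list of raw examples into batched model inputs (numpy arrays).
    return {
        'inputs': np.stack([ex['inputs'] for ex in examples]),
        'labels': np.array([ex['label'] for ex in examples]),
    }


def loss_fn(rng, state, params, batch, is_training):
    # Assumed signature; returns a scalar loss for one batch.
    logits = state.apply_fn({'params': params}, batch['inputs'])
    return optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=batch['labels']).mean()


model = MLP()
params = model.init(jax.random.PRNGKey(0), np.zeros((1, 8)))['params']

deployer = Deployer(jax_seed=42, workdir='./workdir')
trainer = Trainer(
    deployer=deployer,
    collate_fn=collate_fn,
    apply_fn=model.apply,
    loss_fn=loss_fn,
    params=params,
    optimizer=optax.adamw(learning_rate=1e-3))
```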
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.eval_loss","title":"eval_loss(examples, per_device_batch_size, desc=None)","text":"

    Evaluates the loss on the provided examples.

    Parameters:

    Name Type Description Default examples list

    Evaluation examples as a Python list.

    required per_device_batch_size int

    The batch size per device.

    required desc str

    Description in the progress bar.

    None

    Returns:

    Type Description float

    The average loss over the evaluation examples.

    Source code in redco/trainers/trainer.py
    def eval_loss(self, examples, per_device_batch_size, desc=None):\n    \"\"\"Evaluates the loss on the provided examples.\n\n    Args:\n        examples (list): Evaluation examples in list.\n        per_device_batch_size (int): The batch size per device.\n        desc (str): Description in the progress bar.\n\n    Returns:\n        (float): The average loss over the evaluation examples.\n    \"\"\"\n    data_batches = self._deployer.get_model_input_batches(\n        examples=examples,\n        per_device_batch_size=per_device_batch_size,\n        collate_fn=self._collate_fn,\n        shuffle=False,\n        shuffle_rng=None,\n        desc=f'Evaluating ({desc})' if desc is not None else 'Evaluating')\n\n    losses = []\n    for batch in data_batches:\n        if self._p_eval_step is None:\n            self.setup_running_step(dummy_batch=batch)\n\n        metrics = self._deployer.run_model_step(\n            step_fn=self._p_eval_step, input_args=(self._state, batch))\n\n        if self.mesh is None:\n            metrics = unreplicate(metrics)\n\n        losses.append(metrics['loss'].item())\n        data_batches.set_postfix(**metrics)\n\n    return np.mean(losses).item()\n
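A brief usage sketch (illustrative only; `trainer` and `dev_examples` are assumed to exist, e.g., from a construction sketch like the one above):

```python
# Hypothetical usage: average the loss over a held-out list of examples.
dev_loss = trainer.eval_loss(
    examples=dev_examples,        # a Python list of raw examples
    per_device_batch_size=16,
    desc='dev')
print(f'dev loss = {dev_loss:.4f}')
```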
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.fit","title":"fit(train_examples, per_device_batch_size, n_epochs, eval_examples=None, eval_per_device_batch_size=None, eval_loss=True, eval_predictor=None, eval_metric_fn=None, eval_sanity_check=True, save_every_ckpt=False, save_last_ckpt=False, save_argmin_ckpt_by_metrics=None, save_argmax_ckpt_by_metrics=None, save_opt_states=True, save_float_dtype=None)","text":"

    Fits the model on the training data for a given number of epochs, optionally evaluating and saving checkpoints.

    Parameters:

    Name Type Description Default train_examples list or Callable

    Training examples, can be a list or a function of epoch_idx (for assigning different examples in separate epochs/chunks), e.g., train_examples=lambda epoch_idx: load_data(chunk_idx)

    required per_device_batch_size int

    The batch size per device.

    required n_epochs int

    Number of epochs to train.

    required eval_examples list

    Examples for evaluation and prediction.

    None eval_per_device_batch_size int

    Batch size for evaluation; defaults to per_device_batch_size if not given.

    None eval_loss bool

    Whether to evaluate loss.

    True eval_predictor `redco.Predictor`

    Predictor used to run prediction on eval_examples.

    None eval_metric_fn Callable

    Metric function computing a dict of metrics from eval_examples and the predictions.

    None eval_sanity_check bool

    Whether to run a sanity check of the evaluation and prediction functions before training.

    True save_every_ckpt bool

    Whether to save a checkpoint after every epoch.

    False save_last_ckpt bool

    Whether to save the last checkpoint.

    False save_argmin_ckpt_by_metrics list[str]

    Metrics to save checkpoints based on minimum values.

    None save_argmax_ckpt_by_metrics list[str]

    Metrics to save checkpoints based on maximum values.

    None save_opt_states bool

    Whether to save optimizer states in checkpoints.

    True save_float_dtype dtype

    The float data type for saving checkpoints (e.g., jnp.bfloat16).

    None Source code in redco/trainers/trainer.py
    def fit(self,\n        train_examples,\n        per_device_batch_size,\n        n_epochs,\n        eval_examples=None,\n        eval_per_device_batch_size=None,\n        eval_loss=True,\n        eval_predictor=None,\n        eval_metric_fn=None,\n        eval_sanity_check=True,\n        save_every_ckpt=False,\n        save_last_ckpt=False,\n        save_argmin_ckpt_by_metrics=None,\n        save_argmax_ckpt_by_metrics=None,\n        save_opt_states=True,\n        save_float_dtype=None):\n    \"\"\"Fits the model on the training data for a given number of epochs,\n    optionally evaluating and saving checkpoints.\n\n    Args:\n        train_examples (list or Callable): Training examples, can be a\n            list or a function of epoch_idx (for assigning different\n            examples in separate epochs/chunks),\n            e.g., `train_examples=lambda epoch_idx: load_data(chunk_idx)`\n        per_device_batch_size (int): The batch size per device.\n        n_epochs (int): Number of epochs to train.\n        eval_examples (list): Examples for evaluation and prediction.\n        eval_per_device_batch_size (int): Batch size for evaluation\n        eval_loss (bool): Whether to evaluate loss.\n        eval_predictor (`redco.Predictor`): Predicting on `eval_examples`.\n        eval_metric_fn (Callable): Metric function for prediction.\n        eval_sanity_check (bool): if to run a sanity check for\n            evaluation & predict functions before training.\n        save_every_ckpt (bool): if to save a ckpt after every epoch.\n        save_last_ckpt (bool): Whether to save the last checkpoint.\n        save_argmin_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n            based on minimum values.\n        save_argmax_ckpt_by_metrics (list[str]): Metrics to save checkpoints\n            based on maximum values.\n        save_opt_states (bool): of to save optimizer states in ckpts.\n        save_float_dtype (bool): The data type for saving checkpoints.\n    \"\"\"\n    if eval_per_device_batch_size is None:\n        eval_per_device_batch_size = per_device_batch_size\n\n    if save_argmax_ckpt_by_metrics is None:\n        save_argmax_ckpt_by_metrics = []\n    if save_argmin_ckpt_by_metrics is None:\n        save_argmin_ckpt_by_metrics = []\n    min_metrics, max_metrics = {}, {}\n\n    if os.path.exists(f'{self.workdir}/min_metrics.json'):\n        min_metrics = json.load(open(\n            f'{self.workdir}/min_metrics.json'))\n        self._deployer.log_info(\n            json.dumps(min_metrics, indent=4), title='Detected min_metrics')\n\n    if os.path.exists(f'{self.workdir}/max_metrics.json'):\n        max_metrics = json.load(open(\n            f'{self.workdir}/max_metrics.json'))\n        self._deployer.log_info(\n            json.dumps(max_metrics, indent=4), title='Detected max_metrics')\n\n    if eval_sanity_check and eval_examples is not None:\n        rng_backup = self._deployer._rng\n        _, eval_global_micro_batch_size = \\\n            self._deployer.get_local_global_micro_batch_size(\n                per_device_batch_size=eval_per_device_batch_size)\n\n        if eval_loss:\n            self.eval_loss(\n                examples=eval_examples[:eval_global_micro_batch_size],\n                per_device_batch_size=eval_per_device_batch_size,\n                desc=f'Sanity check')\n            self._deployer.log_info(\n                'Sanity check (for evaluation loss) passed.')\n\n        if eval_predictor is not None:\n            preds = eval_predictor.predict(\n   
             examples=eval_examples[:eval_global_micro_batch_size],\n                params=self._state.params,\n                params_replicated=(self.mesh is None),\n                params_sharded=(self.mesh is not None),\n                per_device_batch_size=eval_per_device_batch_size,\n                desc=f'Sanity check')\n            self._deployer.log_info(\n                'Sanity check (for prediction) passed.')\n\n            if eval_metric_fn is not None:\n                json.dumps(eval_metric_fn(\n                    examples=eval_examples[:eval_global_micro_batch_size],\n                    preds=preds))\n                self._deployer.log_info(\n                    'Sanity check (for evaluation metrics) passed.')\n\n        self._deployer._rng = rng_backup\n\n    for epoch_idx in range(self._init_epoch_idx, n_epochs):\n        if isinstance(train_examples, list):\n            epoch_train_examples = train_examples\n        else:\n            epoch_train_examples = train_examples(epoch_idx=epoch_idx)\n\n        self.train(\n            examples=epoch_train_examples,\n            per_device_batch_size=per_device_batch_size,\n            desc=f'epoch {epoch_idx} / {n_epochs}')\n\n        save_ckpt_kwargs = {\n            'epoch_idx': epoch_idx,\n            'save_opt_state': save_opt_states,\n            'float_dtype': save_float_dtype\n        }\n\n        if eval_examples is None:\n            self._deployer.log_info(\n                'No evaluation cuz \\'eval_examples\\' is None.')\n        else:\n            eval_metrics = {}\n\n            if eval_loss:\n                loss = self.eval_loss(\n                    examples=eval_examples,\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'epoch {epoch_idx} / {n_epochs}')\n                eval_metrics['loss'] = loss\n\n            if eval_predictor is not None:\n                preds = eval_predictor.predict(\n                    examples=eval_examples,\n                    params=self._state.params,\n                    params_replicated=(self.mesh is None),\n                    params_sharded=(self.mesh is not None),\n                    per_device_batch_size=eval_per_device_batch_size,\n                    desc=f'epoch {epoch_idx} / {n_epochs}')\n\n                if eval_metric_fn is not None:\n                    eval_metrics.update(eval_metric_fn(\n                        examples=eval_examples, preds=preds))\n\n                eval_outputs = [\n                    {'example': example, 'pred': pred}\n                    for example, pred in zip(eval_examples, preds)]\n\n                self._deployer.save_outputs(\n                    outputs=eval_outputs,\n                    desc=f'epoch{epoch_idx}',\n                    step=self.step)\n\n            self._deployer.log_info(\n                info=json.dumps(eval_metrics, indent=4),\n                title=f'Eval results',\n                step=self.step)\n            self._deployer.log_metrics(metrics={\n                f'eval_{key}': value\n                for key, value in eval_metrics.items()\n            }, step=self.step)\n\n            if self.workdir is not None:\n                result_filepath = \\\n                    f'{self.workdir}/eval_results_epoch{epoch_idx}.json'\n                json.dump(\n                    eval_metrics, open(result_filepath, 'w'), indent=4)\n                self._deployer.log_info(\n                    f'eval_results saved into {result_filepath}.')\n\n            for key in 
save_argmin_ckpt_by_metrics:\n                assert self.workdir is not None\n                if eval_metrics[key] < min_metrics.get(key, float('inf')):\n                    min_metrics[key] = eval_metrics[key]\n\n                    if jax.process_index() == 0:\n                        self._deployer.log_info(\n                            f'minimal {key} updated to {min_metrics[key]}')\n                        json.dump(min_metrics, open(\n                            f'{self.workdir}/min_metrics.json', 'w'))\n\n                    self.save_ckpt(\n                        ckpt_name=f'min_{key}', **save_ckpt_kwargs)\n\n            for key in save_argmax_ckpt_by_metrics:\n                assert self.workdir is not None\n                if eval_metrics[key] > max_metrics.get(key, float('-inf')):\n                    max_metrics[key] = eval_metrics[key]\n\n                    if jax.process_index() == 0:\n                        self._deployer.log_info(\n                            f'maximal {key} updated to {max_metrics[key]}')\n                        json.dump(max_metrics, open(\n                            f'{self.workdir}/max_metrics.json', 'w'))\n\n                    self.save_ckpt(\n                        ckpt_name=f'max_{key}', **save_ckpt_kwargs)\n\n        if save_every_ckpt:\n            self.save_ckpt(\n                ckpt_name=f'epoch_{epoch_idx}', **save_ckpt_kwargs)\n        elif save_last_ckpt:\n            self.save_ckpt(ckpt_name='last', **save_ckpt_kwargs)\n
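As a usage sketch (illustrative; `trainer`, `train_examples`, and `dev_examples` are assumed to exist; an optional `redco.Predictor` could be passed via `eval_predictor` for generation-style evaluation):

```python
# Hypothetical usage: train for a few epochs with per-epoch loss evaluation and
# metric-based checkpointing, using only arguments documented above.
trainer.fit(
    train_examples=train_examples,
    per_device_batch_size=32,
    n_epochs=3,
    eval_examples=dev_examples,
    eval_loss=True,
    save_argmin_ckpt_by_metrics=['loss'],  # keep the ckpt with the lowest eval loss
    save_last_ckpt=True)
```

Note that the metric-based checkpoint options require the Deployer to have a workdir, per the assertions in the source above.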
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.save_ckpt","title":"save_ckpt(epoch_idx, ckpt_name, save_opt_state, float_dtype)","text":"

    Saves a checkpoint into {self.workdir}/ckpts.

    Parameters:

    Name Type Description Default epoch_idx int

    The current epoch index.

    required ckpt_name str

    The name of the checkpoint.

    required save_opt_state bool

    Whether to save the optimizer state.

    required float_dtype `jax.numpy.dtype`

    Data type for saving float params.

    required Source code in redco/trainers/trainer.py
    def save_ckpt(self, epoch_idx, ckpt_name, save_opt_state, float_dtype):\n    \"\"\"Saves a checkpoint into `{self.workdir}/ckpts`.\n\n    Args:\n        epoch_idx (int): The current epoch index.\n        ckpt_name (str): The name of the checkpoint.\n        save_opt_state (bool): Whether to save the optimizer state.\n        float_dtype (`jax.numpy.dtype`): Data type for saving float params.\n    \"\"\"\n    if self.mesh is None:\n        params = jax.tree.map(\n            fully_replicated_host_local_array_to_global_array,\n            self._state.params)\n    else:\n        params = self._state.params\n\n    opt_state = None\n    if save_opt_state:\n        if self.mesh is None:\n            opt_state = jax.tree.map(\n                fully_replicated_host_local_array_to_global_array,\n                self._state.opt_state)\n        else:\n            opt_state = self._state.opt_state\n\n    ckpt_dir = f'{self.workdir}/ckpts/{ckpt_name}'\n    self._deployer.save_ckpt(\n        ckpt_dir=ckpt_dir,\n        params=params,\n        opt_state=opt_state,\n        float_dtype=float_dtype,\n        step=self.step,\n        epoch_idx=epoch_idx)\n\n    if jax.process_index() == 0:\n        open(f'{self.workdir}/ckpts/last_ckpt.txt', 'w').write(ckpt_name)\n        self._deployer.log_info(f'last ckpt updated -- {ckpt_dir}')\n
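A usage sketch (illustrative; `trainer` is assumed to exist and its Deployer to have a workdir):

```python
# Hypothetical usage: manually snapshot the current state into
# {workdir}/ckpts/manual, keeping the optimizer state and saving floats as bfloat16.
import jax.numpy as jnp

trainer.save_ckpt(
    epoch_idx=0,
    ckpt_name='manual',
    save_opt_state=True,
    float_dtype=jnp.bfloat16)
```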
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.set_train_state","title":"set_train_state(apply_fn, params, optimizer, step, opt_state=None)","text":"

    Sets/Resets the training state with given parameters and optimizer.

    Parameters:

    Name Type Description Default apply_fn Callable

    The function to apply the model.

    required params dict

    Model parameters.

    required optimizer `optax.optimizer`

    The optimizer used for training.

    required step int

    The training step.

    required opt_state dict

    The state of the optimizer.

    None Source code in redco/trainers/trainer.py
    def set_train_state(\n        self, apply_fn, params, optimizer, step, opt_state=None):\n    \"\"\"Sets/Resets the training state with given parameters and optimizer.\n\n    Args:\n        apply_fn (Callable): The function to apply the model.\n        params (dict): Model parameters.\n        optimizer (dict): The optimizer used for training.\n        step (int): The training step.\n        opt_state (dict): The state of the optimizer.\n    \"\"\"\n    self._deployer.log_info('Setting train_state ...')\n    params = freeze(params)\n\n    if self.mesh is None:\n        params = jax.device_put(params, jax.local_devices()[0])\n        if opt_state is None:\n            self._deployer.log_info('Initializing opt_state ...')\n            opt_state = optimizer.init(params)\n        else:\n            opt_state = jax.device_put(opt_state, jax.local_devices()[0])\n\n        self._state = train_state.TrainState(\n            step=step,\n            apply_fn=apply_fn,\n            params=params,\n            tx=optimizer,\n            opt_state=opt_state)\n        self._state = replicate(self._state)\n    else:\n        params_spec = self._deployer.get_params_spec(\n            params_shape_or_params=params,\n            params_sharding_rules=self._params_sharding_rules)\n        params = self._deployer.shard_params(\n            params=params, params_spec=params_spec)\n\n        if opt_state is None:\n            self._deployer.log_info('Initializing opt_state ...')\n            opt_state = optimizer.init(params)\n\n        opt_state_spec = self._deployer.get_opt_state_spec(\n            params_shape_or_params=params,\n            params_spec=params_spec,\n            optimizer=optimizer)\n        opt_state = self._deployer.shard_params(\n            params=opt_state,\n            params_spec=opt_state_spec,\n            desc='opt_state')\n\n        self._state = train_state.TrainState(\n            apply_fn=apply_fn,\n            params=params,\n            tx=optimizer,\n            opt_state=opt_state,\n            step=step)\n\n        self._state_spec = train_state.TrainState(\n            apply_fn=apply_fn,\n            params=params_spec,\n            tx=optimizer,\n            opt_state=opt_state_spec,\n            step=None)\n
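A usage sketch (illustrative; `trainer`, `model`, and `params` are assumed to exist as in the construction sketch earlier):

```python
# Hypothetical usage: reset the train state, e.g., to restart training from
# freshly initialized parameters with a different optimizer on the same Trainer.
import optax

trainer.set_train_state(
    apply_fn=model.apply,
    params=params,
    optimizer=optax.sgd(learning_rate=1e-2),
    step=0)  # opt_state is re-initialized by the optimizer when omitted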
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.setup_running_step","title":"setup_running_step(dummy_batch)","text":"

    Sets up the running step functions for training and evaluation.

    Parameters:

    Name Type Description Default dummy_batch PyTree

    A dummy batch of data.

    required Source code in redco/trainers/trainer.py
    def setup_running_step(self, dummy_batch):\n    \"\"\"Sets up the running step functions for training and evaluation.\n\n    Args:\n        dummy_batch (PyTree): A dummy batch of data.\n    \"\"\"\n    train_step_fn = partial(\n        train_step,\n        loss_fn=self._loss_fn,\n        lr_schedule_fn=self._lr_schedule_fn,\n        mesh=self.mesh,\n        compute_dtype=self._compute_dtype)\n    eval_step_fn = partial(\n        eval_step,\n        loss_fn=self._loss_fn,\n        mesh=self.mesh,\n        compute_dtype=self._compute_dtype)\n\n    if self.mesh is None:\n        self._p_train_step = jax.pmap(train_step_fn, axis_name='dp')\n        self._p_eval_step = jax.pmap(eval_step_fn, axis_name='dp')\n    else:\n        data_spec = jax.tree.map(lambda x: P('dp'), dummy_batch)\n        self._p_train_step = pjit(\n            train_step_fn,\n            in_shardings=(None, self._state_spec, data_spec),\n            out_shardings=(self._state_spec, None),\n            donate_argnums=(1, ))\n        self._p_eval_step = pjit(\n            eval_step_fn,\n            in_shardings=(self._state_spec, data_spec),\n            out_shardings=None)\n
    "},{"location":"trainer/#redco.trainers.trainer.Trainer.train","title":"train(examples, per_device_batch_size, desc=None)","text":"

    Trains the model on the provided examples.

    Parameters:

    Name Type Description Default examples list

    Training examples as a Python list.

    required per_device_batch_size int

    The batch size per device.

    required desc str

    Description in the progress bar.

    None Source code in redco/trainers/trainer.py
    def train(self, examples, per_device_batch_size, desc=None):\n    \"\"\"Trains the model on the provided examples.\n\n    Args:\n        examples (list): Training examples in python list.\n        per_device_batch_size (int): The batch size per device.\n        desc (str): Description in the progress bar.\n    \"\"\"\n    data_batches = self._deployer.get_model_input_batches(\n        examples=examples,\n        per_device_batch_size=per_device_batch_size,\n        collate_fn=self._collate_fn,\n        shuffle=True,\n        shuffle_rng=self._deployer.gen_rng(),\n        desc=f'Training ({desc})' if desc is not None else 'Training',\n        is_train=True,\n        accumulate_grad_batches=self._accumulate_grad_batches)\n\n    for batch in data_batches:\n        if self._p_train_step is None:\n            self.setup_running_step(dummy_batch=batch)\n\n        train_rng = self._deployer.gen_rng()\n        if self.mesh is None:\n            train_rng = jax.random.split(\n                train_rng, num=jax.process_count())[jax.process_index()]\n            train_rng = shard_prng_key(train_rng)\n        self._state, metrics = self._deployer.run_model_step(\n            step_fn=self._p_train_step,\n            input_args=(train_rng, self._state, batch))\n\n        if self.mesh is None:\n            metrics = unreplicate(metrics)\n        data_batches.set_postfix(**metrics)\n        self._deployer.log_metrics(metrics=metrics, step=self.step)\n
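A usage sketch (illustrative; `trainer`, `train_examples`, and `dev_examples` are assumed to exist): `train()` and `eval_loss()` can be combined in a manual loop when `fit()` is too coarse, e.g., to interleave custom logic between epochs.

```python
# Hypothetical usage: a manual epoch loop instead of fit().
for epoch_idx in range(3):
    trainer.train(
        examples=train_examples,
        per_device_batch_size=32,
        desc=f'epoch {epoch_idx}')
    dev_loss = trainer.eval_loss(
        examples=dev_examples, per_device_batch_size=32)
    print(f'epoch {epoch_idx}: dev loss = {dev_loss:.4f}')
```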
    "}]} \ No newline at end of file