It's straightforward to run Lightning's `LightningCLI` on Modal infrastructure.
This provides a couple of benefits:
- Automatic class composition from a YAML config file.
- Infrastructure (e.g. an A100 GPU) that spins up on demand.
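To make the first point concrete: in subclass mode, each config section names a concrete class via `class_path` and its constructor arguments via `init_args`, and the CLI instantiates it. The toy resolver below is only a standard-library sketch of that idea, not how Lightning or jsonargparse actually implement it. A dict stands in for the parsed YAML, and a standard-library class stands in for a `LightningModule` so the example runs without Lightning installed.

```python
import importlib
from typing import Any


def instantiate(spec: dict[str, Any]) -> Any:
    """Build an object from a {"class_path": ..., "init_args": {...}} spec.

    Simplified illustration of what subclass mode does with a config section;
    real implementations also validate types, handle nesting, defaults, etc.
    """
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


# Stand-in for a parsed YAML file such as:
#   model:
#     class_path: fractions.Fraction
#     init_args: {numerator: 3, denominator: 4}
config = {
    "model": {
        "class_path": "fractions.Fraction",
        "init_args": {"numerator": 3, "denominator": 4},
    },
}

model = instantiate(config["model"])
print(model)  # 3/4
```

In a real config, `class_path` would point at your own `LightningModule` or `LightningDataModule` subclass (for example a hypothetical `my_project.models.MyModel`), which has to be importable wherever the CLI runs, here inside the Modal container.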
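A minimal script looks like this:

```python
import sys

import lightning as L
import modal
from lightning.pytorch.cli import LightningCLI

# Hypothetical app, image, and volume names: adjust them to your project.
app = modal.App("lightning-cli")
image = modal.Image.debian_slim().pip_install("lightning")
volumes = {"/data": modal.Volume.from_name("lightning-data", create_if_missing=True)}


@app.function(
    image=image,
    gpu="A100-40GB",
    volumes=volumes,
)
def train(*args: str):
    # NOTE: clear sys.argv (keeping only the program name) so LightningCLI
    # takes its arguments from `args` alone instead of also seeing argv.
    del sys.argv[1:]

    LightningCLI(
        model_class=L.LightningModule,  # base class: the concrete subclass is chosen in the config
        datamodule_class=L.LightningDataModule,  # base class: the concrete subclass is chosen in the config
        subclass_mode_model=True,  # enable subclass mode for the model
        subclass_mode_data=True,  # enable subclass mode for the datamodule
        seed_everything_default=42,
        args=list(args),  # the arguments forwarded from the local command line
    )


@app.local_entrypoint()
def main(*args: str):
    # With a variadic signature, Modal forwards the raw CLI arguments as strings.
    train.remote(*args)
```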
NOTE: As mentioned in the comment in `train`, Lightning warns if arguments are present in `sys.argv` and are also passed through the `args` keyword argument of `LightningCLI`. To get around this and still allow specifying arguments from the command line, I forward them to `train` as `args` and then remove them from the `sys.argv` list.
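The same workaround can be useful when running `LightningCLI` locally without Modal. The hypothetical helper `forward_cli_args` below (not part of Lightning) moves everything after the program name out of an argv list, so a local script could call `LightningCLI(..., args=forward_cli_args(sys.argv))` while `sys.argv` is left holding only the program name. It works on a plain list here so it can be tried without Lightning installed.

```python
def forward_cli_args(argv: list[str]) -> list[str]:
    """Return everything after the program name and strip it from argv in place.

    The returned list is meant for LightningCLI(args=...); argv keeps only
    its first element (the program name).
    """
    forwarded = argv[1:]  # slicing copies the arguments
    del argv[1:]  # mutate in place so anyone holding this list (e.g. sys.argv) sees it
    return forwarded


# Simulate a process whose argv still carries the CLI arguments.
fake_argv = ["train.py", "fit", "--config", "config.yaml"]
cli_args = forward_cli_args(fake_argv)
print(cli_args)  # ['fit', '--config', 'config.yaml']
print(fake_argv)  # ['train.py']
```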