Models API Reference
API documentation for the model architectures.
Model Classes
BaselineCNN
mlops_project.model.BaselineCNN
Bases: LightningModule
Model 1: A simple baseline CNN model.
forward(x)
Forward pass of the BaselineCNN model.
training_step(batch, batch_idx)
Training step for the BaselineCNN model.
validation_step(batch)
Validation step for the BaselineCNN model.
predict_step(batch, batch_idx, dataloader_idx=0)
Prediction step for the BaselineCNN model.
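The reference does not say what predict_step returns. For a classifier such as BaselineCNN(num_classes=2) from the usage examples, a common pattern is to turn each row of logits into softmax probabilities and an argmax label. The stdlib-only sketch below shows that post-processing on plain Python lists. The function name logits_to_predictions is hypothetical, and whether the real predict_step returns probabilities, labels, or raw logits is an assumption to check against the source.

```python
import math
from typing import List, Tuple


def logits_to_predictions(
    logits: List[List[float]],
) -> Tuple[List[List[float]], List[int]]:
    """Convert a batch of raw class scores into probabilities and labels.

    Hypothetical post-processing: the reference does not document what
    predict_step returns; this shows one common choice for a classifier.
    """
    probs: List[List[float]] = []
    labels: List[int] = []
    for row in logits:
        m = max(row)  # subtract the row max so math.exp cannot overflow
        exps = [math.exp(v - m) for v in row]
        total = sum(exps)
        p = [e / total for e in exps]  # softmax: non-negative, sums to 1
        probs.append(p)
        labels.append(max(range(len(p)), key=p.__getitem__))  # argmax index
    return probs, labels


probs, labels = logits_to_predictions([[2.0, -1.0], [0.1, 0.3]])
print(labels)
```

Subtracting the row maximum before exponentiating leaves the softmax unchanged but keeps large logits from overflowing.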
configure_optimizers()
Configure optimizers for the BaselineCNN model.
ResNet
mlops_project.model.ResNet
Bases: LightningModule
A miniature, flexible ResNet model.
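Strides: each entry in strides sets the stride of one residual block, so len(strides) is the number of blocks. A stride of 2 roughly halves the spatial resolution, and every stride-2 block starts a new stage, so begin each stage after the first with stride 2 to downsample the image. Example stride patterns:
- 3 blocks: strides=[1, 2, 2] → stages: [1] | [2] | [2]
- 4 blocks: strides=[1, 1, 2, 2] → stages: [1,1] | [2] | [2]
- 5 blocks: strides=[1, 1, 2, 1, 2] → stages: [1,1] | [2,1] | [2]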
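The stage grouping in these stride patterns can be reproduced with a short stdlib sketch. It assumes, as the patterns suggest, that every block with stride greater than 1 opens a new stage, and that each block's first convolution is 3x3 with padding 1 (as in the d2l residual block credited under ResidualBlock), so the spatial size after each block follows standard convolution arithmetic. group_stages and output_size are hypothetical helpers, not part of mlops_project, and output_size ignores any stem layers that may run before the residual blocks.

```python
from typing import List


def group_stages(strides: List[int]) -> List[List[int]]:
    """Split a per-block stride list into stages.

    Assumption: a downsampling block (stride > 1) opens a new stage, and
    leading stride-1 blocks form the first stage.
    """
    stages: List[List[int]] = []
    for s in strides:
        if not stages or s > 1:
            stages.append([s])  # start a new stage
        else:
            stages[-1].append(s)  # stride-1 block joins the current stage
    return stages


def output_size(size: int, strides: List[int]) -> int:
    """Spatial size after the blocks, assuming 3x3 convs with padding 1.

    Only each block's first conv carries the stride; a following 3x3 conv
    with stride 1 and padding 1 keeps the size unchanged.
    """
    for s in strides:
        # Conv arithmetic: floor((n + 2 * padding - kernel) / stride) + 1
        size = (size + 2 * 1 - 3) // s + 1
    return size


for strides in ([1, 2, 2], [1, 1, 2, 2], [1, 1, 2, 1, 2]):
    print(strides, "->", group_stages(strides), "| 64 px ->", output_size(64, strides))
```

For a 64-pixel input all three patterns end at 16 pixels: each contains exactly two stride-2 blocks, and the extra stride-1 blocks only add depth.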
forward(x)
Forward pass of the ResNet model.
training_step(batch, batch_idx)
Training step for the ResNet model.
validation_step(batch)
Validation step for the ResNet model.
predict_step(batch, batch_idx, dataloader_idx=0)
Prediction step for the ResNet model.
configure_optimizers()
Configure optimizers for the ResNet model.
EfficientNet
mlops_project.model.EfficientNet
Bases: LightningModule
EfficientNet model from torchvision. Based on this paper: https://arxiv.org/abs/1905.11946.
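The linked paper scales network depth, width, and input resolution together with a single compound coefficient φ: depth d = α^φ, width w = β^φ, resolution r = γ^φ, with α·β²·γ² ≈ 2 so that FLOPs grow by roughly 2^φ. The stdlib sketch below evaluates those formulas with the paper's grid-searched constants α = 1.2, β = 1.1, γ = 1.15. It only illustrates the scaling rule: the released B1 to B7 configurations (and torchvision's variant builders) use their own per-variant multipliers, which do not match these exact powers, so the output is not the setting of any specific variant such as "b3". The names compound_scale and flops_factor are hypothetical.

```python
from typing import NamedTuple


class ScaleFactors(NamedTuple):
    depth: float       # multiplier on the number of layers (alpha ** phi)
    width: float       # multiplier on the number of channels (beta ** phi)
    resolution: float  # multiplier on the input image side (gamma ** phi)


# Constants reported in the EfficientNet paper (grid search at phi = 1).
ALPHA, BETA, GAMMA = 1.2, 1.1, 1.15


def compound_scale(phi: float) -> ScaleFactors:
    """Depth/width/resolution multipliers for compound coefficient phi."""
    return ScaleFactors(ALPHA ** phi, BETA ** phi, GAMMA ** phi)


def flops_factor(phi: float) -> float:
    """Approximate FLOPs growth: conv FLOPs scale with d * w**2 * r**2."""
    f = compound_scale(phi)
    return f.depth * f.width ** 2 * f.resolution ** 2


for phi in range(4):
    f = compound_scale(phi)
    print(
        f"phi={phi}: depth x{f.depth:.2f}, width x{f.width:.2f}, "
        f"resolution x{f.resolution:.2f}, FLOPs x{flops_factor(phi):.2f}"
    )
```

At φ = 1 the FLOPs factor is about 1.92, close to the paper's target of 2; width and resolution enter squared because a convolution's cost grows with both channel count and feature-map area.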
forward(x)
Forward pass of the EfficientNet model.
training_step(batch, batch_idx)
Training step for the EfficientNet model.
validation_step(batch)
Validation step for the EfficientNet model.
predict_step(batch, batch_idx, dataloader_idx=0)
Prediction step for the EfficientNet model.
configure_optimizers()
Configure optimizers for the EfficientNet model.
Helper Classes
ConvBlock
mlops_project.model.ConvBlock
Bases: Module
Building block for the baseline CNN model.
forward(x)
Forward pass of the ConvBlock.
ResidualBlock
mlops_project.model.ResidualBlock
Bases: Module
Residual block for the ResNet model. Credits: https://d2l.ai/chapter_convolutional-modern/resnet.html.
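In the d2l design, the main path is a 3x3 convolution (carrying the block's stride), BatchNorm, ReLU, a second 3x3 convolution and BatchNorm; the shortcut is either the identity or a 1x1 convolution with the same stride, and the two are added before a final ReLU. The addition only works when both paths have the same shape, so a 1x1 projection is needed whenever the block changes the channel count or downsamples. The stdlib sketch below traces shapes to show that rule. The helpers are hypothetical, and whether mlops_project's ResidualBlock exposes the same use_1x1conv switch as d2l's Residual class is an assumption.

```python
from typing import Tuple

Shape = Tuple[int, int, int]  # (channels, height, width)


def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def needs_projection(in_shape: Shape, out_channels: int, stride: int) -> bool:
    """The identity shortcut only fits if channels and resolution are unchanged."""
    return stride != 1 or in_shape[0] != out_channels


def residual_shapes(
    in_shape: Shape, out_channels: int, stride: int = 1
) -> Tuple[Shape, Shape]:
    """Return (main-path shape, shortcut shape) for a d2l-style residual block."""
    _, h, w = in_shape
    # Main path: 3x3 conv with the block's stride, then a 3x3 conv with stride 1.
    h1, w1 = conv_out(h, 3, stride, 1), conv_out(w, 3, stride, 1)
    main = (out_channels, conv_out(h1, 3, 1, 1), conv_out(w1, 3, 1, 1))
    if needs_projection(in_shape, out_channels, stride):
        # 1x1 conv shortcut with the same stride (d2l's use_1x1conv=True).
        shortcut = (out_channels, conv_out(h, 1, stride, 0), conv_out(w, 1, stride, 0))
    else:
        shortcut = in_shape  # identity shortcut
    return main, shortcut


for args in (((64, 56, 56), 64, 1), ((64, 56, 56), 128, 2)):
    main, shortcut = residual_shapes(*args)
    print(args, "main:", main, "shortcut:", shortcut, "addable:", main == shortcut)
```

A 3x3 convolution with padding 1 and a 1x1 convolution with padding 0 both map a length n to (n - 1) // stride + 1, so once the projection is in place the two paths always agree spatially; only the channel count has to be matched explicitly.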
forward(x)
Forward pass of the ResidualBlock.
Usage Examples
from mlops_project.model import EfficientNet, ResNet, BaselineCNN
# Create model instances
model = EfficientNet(variant="b3", num_classes=2, pretrained=True)
model = ResNet(num_classes=2)
model = BaselineCNN(num_classes=2)
# Load from checkpoint
model = EfficientNet.load_from_checkpoint("path/to/checkpoint.ckpt")