kaishi.image.model
Definition of the PyTorch model abstraction.
Module Contents
-
class kaishi.image.model.Model(n_classes: int = 6, model_arch: str = 'resnet18')
Abstraction for working with PyTorch models.
-
vgg16_bn(self, n_classes: int)
Basic VGG16 model with batch normalization and the specified number of output classes.
- Parameters
n_classes (int) – number of classes at the output layer
- Returns
PyTorch VGG16 model object with batch normalization
- Return type
torchvision.models.VGG
-
resnet18(self, n_classes: int)
Basic ResNet18 model with the specified number of output classes.
- Parameters
n_classes (int) – number of classes at the output layer
- Returns
PyTorch ResNet18 model object
- Return type
torchvision.models.ResNet
-
resnet50(self, n_classes: int)
Basic ResNet50 model with the specified number of output classes.
- Parameters
n_classes (int) – number of classes at the output layer
- Returns
PyTorch ResNet50 model object
- Return type
torchvision.models.ResNet
-
predict(self, numpy_array)
Make predictions from a NumPy array whose dimensions are (batch, channel, x, y).
- Parameters
numpy_array (numpy.ndarray) – input array on which to make predictions
- Returns
predictions, where the dimensions are (batch, output)
- Return type
numpy.ndarray
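Prediction from a NumPy array typically means converting to a tensor, switching the network to evaluation mode, running the forward pass without gradient tracking, and converting back. A minimal sketch under those assumptions; `predict_numpy` is a hypothetical function, and it returns raw network outputs because the documentation does not say whether kaishi applies a softmax:

```python
import numpy as np
import torch
import torch.nn as nn


def predict_numpy(model: nn.Module, numpy_array: np.ndarray) -> np.ndarray:
    """Forward a (batch, channel, x, y) array; return a (batch, output) array (hypothetical)."""
    was_training = model.training
    model.eval()  # use running BatchNorm statistics and disable Dropout
    device = next(model.parameters()).device
    # .float() casts e.g. float64 input to the float32 the weights expect.
    inputs = torch.from_numpy(numpy_array).float().to(device)
    with torch.no_grad():  # no autograd graph is needed for inference
        outputs = model(inputs)
    model.train(was_training)  # restore the caller's mode
    return outputs.cpu().numpy()


# Toy network standing in for a real classifier with 6 outputs.
toy = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 6),
)
preds = predict_numpy(toy, np.random.rand(2, 3, 8, 8))
print(preds.shape)  # (2, 6)
```

Restoring the previous train/eval mode keeps a call to the sketch from silently changing the model's behaviour during a training loop.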