kaishi.image.model

Defines an abstraction for working with PyTorch models.

Module Contents

class kaishi.image.model.Model(n_classes: int = 6, model_arch: str = 'resnet18')

Abstraction for working with PyTorch models.

vgg16_bn(self, n_classes: int)

Basic VGG16 model with batch normalization and a specified number of output classes.

Parameters

n_classes (int) – number of classes at the output layer

Returns

PyTorch VGG16 model object with batch normalization

Return type

torchvision.models.VGG (as built by torchvision.models.vgg16_bn)

resnet18(self, n_classes: int)

Basic ResNet18 model with specified number of output classes.

Parameters

n_classes (int) – number of classes at the output layer

Returns

PyTorch ResNet18 model object

Return type

torchvision.models.ResNet (as built by torchvision.models.resnet18)

resnet50(self, n_classes: int)

Basic ResNet50 model with specified number of output classes.

Parameters

n_classes (int) – number of classes at the output layer

Returns

PyTorch ResNet50 model object

Return type

torchvision.models.ResNet (as built by torchvision.models.resnet50)

predict(self, numpy_array)

Make predictions from a NumPy array with dimensions (batch, channel, x, y).

Parameters

numpy_array (numpy.ndarray) – input array to run predictions on

Returns

predictions with dimensions (batch, output)

Return type

numpy.ndarray
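
The shape contract described for predict, (batch, channel, x, y) in and (batch, output) out, can be illustrated with a standalone function. This is only a sketch of the general NumPy-to-PyTorch round trip, not kaishi's code: the free function predict and the tiny toy_model below are hypothetical stand-ins, and device handling is omitted.

```python
import numpy as np
import torch
import torch.nn as nn


def predict(model: nn.Module, numpy_array: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in: run a model on a (batch, channel, x, y) array."""
    model.eval()  # disable dropout and use running batch-norm statistics
    with torch.no_grad():  # inference only, no gradients needed
        batch = torch.from_numpy(numpy_array).float()
        output = model(batch)
    return output.cpu().numpy()


# Toy model used only to keep the example small; 6 output classes.
toy_model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(4, 6),
)

images = np.random.rand(2, 3, 32, 32).astype(np.float32)
predictions = predict(toy_model, images)
print(predictions.shape)
```

With a batch of two images and a six-unit output layer, the returned array has one row per image and one column per output unit.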