Module seg_mask_modifs.download_models

Expand source code
import os
import torch
import torchvision
from google_drive_downloader import GoogleDriveDownloader as gdd


def maskrcnn_coco(save_path='models/maskrcnn_restnet50_fpn.pt'):
    """ Download and save maskrcnn model

    Args:
        save_path (str, optional): Path to save maskrcnn model. Must end with '.pt' or '.pth'.
                                   Default: 'models/maskrcnn_restnet50_fpn.pt'
    """

    if save_path[-3:] != '.pt' and save_path[-4:] != '.pth':
        raise ValueError('Save path should end with .pt or .pth')

    if save_path == 'models/maskrcnn_restnet50_fpn.pt' and not os.path.exists('models'):
        os.makedirs('models')

    # getting base model from pytorch torchvision
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    torch.save(model, 'models/maskrcnn_resnet50_fpn.pt')


def deeplab_pascal(save_path='models/deeplab_restnet101.pt'):
    """ Download and save deeplab model

    Args:
        save_path (str, optional): Path to save deeplab model. Must end with '.pt' or '.pth'.
                                   Default: 'models/deeplab_restnet101.pt'
    """

    if save_path[-3:] != '.pt' and save_path[-4:] != '.pth':
        raise ValueError('Save path should end with .pt or .pth')

    if save_path == 'models/deeplab_restnet101.pt' and not os.path.exists('models'):
        os.makedirs('models')

    # getting base model from pytorch torchvision
    model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    torch.save(model, 'models/deeplab_restnet101.pt')


def face(save_path='models/face.pth'):
    """ Download and save face model

    Args:
        save_path (str, optional): Path to save face model. Must end with '.pth'. Default: 'models/face.pth'
    """

    if save_path[-4:] != '.pth':
        raise ValueError('Save path should end with .pth')

    if save_path == 'models/face.pth' and not os.path.exists('models'):
        os.makedirs('models')

    gdd.download_file_from_google_drive(file_id='154JgKpzCPW82qINcVieuPH3fZ2e0P812',
                                        dest_path='models/face.pth')

    print('Downloading resnet18 backbone to torch cache')
    import torch.utils.model_zoo as modelzoo

    resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    _ = modelzoo.load_url(resnet18_url)


def download_all():
    """ Function to download all models with their default names"""

    print('Downloading maskrcnn model to models/maskrcnn_resnet50_fpn.pt')
    maskrcnn_coco()

    print('Downloading deeplab model to models/deeplab_restnet101.pt')
    deeplab_pascal()

    print('Downloading face model to models/face.pth')
    face()

Functions

def deeplab_pascal(save_path='models/deeplab_restnet101.pt')

Download and save deeplab model

Args

save_path : str, optional
Path to save deeplab model. Must end with '.pt' or '.pth'. Default: 'models/deeplab_restnet101.pt'
Expand source code
def deeplab_pascal(save_path='models/deeplab_restnet101.pt'):
    """ Download and save deeplab model

    Args:
        save_path (str, optional): Path to save deeplab model. Must end with '.pt' or '.pth'.
                                   Default: 'models/deeplab_restnet101.pt'
    """

    if save_path[-3:] != '.pt' and save_path[-4:] != '.pth':
        raise ValueError('Save path should end with .pt or .pth')

    if save_path == 'models/deeplab_restnet101.pt' and not os.path.exists('models'):
        os.makedirs('models')

    # getting base model from pytorch torchvision
    model = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    torch.save(model, 'models/deeplab_restnet101.pt')
def download_all()

Function to download all models with their default names

Expand source code
def download_all():
    """ Function to download all models with their default names"""

    print('Downloading maskrcnn model to models/maskrcnn_resnet50_fpn.pt')
    maskrcnn_coco()

    print('Downloading deeplab model to models/deeplab_restnet101.pt')
    deeplab_pascal()

    print('Downloading face model to models/face.pth')
    face()
def face(save_path='models/face.pth')

Download and save face model

Args

save_path : str, optional
Path to save face model. Must end with '.pth'. Default: 'models/face.pth'
Expand source code
def face(save_path='models/face.pth'):
    """ Download and save face model

    Args:
        save_path (str, optional): Path to save face model. Must end with '.pth'. Default: 'models/face.pth'
    """

    if save_path[-4:] != '.pth':
        raise ValueError('Save path should end with .pth')

    if save_path == 'models/face.pth' and not os.path.exists('models'):
        os.makedirs('models')

    gdd.download_file_from_google_drive(file_id='154JgKpzCPW82qINcVieuPH3fZ2e0P812',
                                        dest_path='models/face.pth')

    print('Downloading resnet18 backbone to torch cache')
    import torch.utils.model_zoo as modelzoo

    resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    _ = modelzoo.load_url(resnet18_url)
def maskrcnn_coco(save_path='models/maskrcnn_restnet50_fpn.pt')

Download and save maskrcnn model

Args

save_path : str, optional
Path to save maskrcnn model. Must end with '.pt' or '.pth'. Default: 'models/maskrcnn_restnet50_fpn.pt'
Expand source code
def maskrcnn_coco(save_path='models/maskrcnn_restnet50_fpn.pt'):
    """ Download and save maskrcnn model

    Args:
        save_path (str, optional): Path to save maskrcnn model. Must end with '.pt' or '.pth'.
                                   Default: 'models/maskrcnn_restnet50_fpn.pt'
    """

    if save_path[-3:] != '.pt' and save_path[-4:] != '.pth':
        raise ValueError('Save path should end with .pt or .pth')

    if save_path == 'models/maskrcnn_restnet50_fpn.pt' and not os.path.exists('models'):
        os.makedirs('models')

    # getting base model from pytorch torchvision
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    torch.save(model, 'models/maskrcnn_resnet50_fpn.pt')