PyTorch

PyTorch is a tensor and autograd library widely used for machine learning. The class-resolver package provides several class resolvers and function resolvers that make it easier to parametrize models and training loops.

activation_resolver = <class_resolver.api.ClassResolver object>

A resolver for torch.nn.modules.activation classes.

import torch
from class_resolver import Hint
from class_resolver.contrib.torch import activation_resolver
from more_itertools import pairwise
from torch import nn

class TwoLayerPerceptron(nn.Module):
    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,
    ):
        super().__init__()
        layers = []
        for in_features, out_features in pairwise(dims):
            layers.extend((
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation),
            ))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        return self.layers(x)
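
With this pattern, the activation can be chosen when the model is constructed. A minimal, hypothetical usage sketch (the dimensions are arbitrary; a string hint, a class, or None all resolve):

# Resolve the activation from a string hint
model = TwoLayerPerceptron(dims=[64, 32, 16], activation="relu")

# ...or pass the class itself; None would fall back to the resolver's default
model = TwoLayerPerceptron(dims=[64, 32, 16], activation=nn.LeakyReLU)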
aggregation_resolver = <class_resolver.func.FunctionResolver object>

A resolver for common aggregation functions in PyTorch, such as torch.mean() and torch.max().

The default is torch.mean(). This resolver can be used as follows:

import torch
from class_resolver.contrib.torch import aggregation_resolver

# Lookup with string
func = aggregation_resolver.lookup("max")
arr = torch.tensor([1, 2, 3, 10], dtype=torch.float)
assert 10.0 == func(arr).item()

# Default lookup gives mean
func = aggregation_resolver.lookup(None)
arr = torch.tensor([1.0, 2.0, 3.0, 10.0], dtype=torch.float)
assert 4.0 == func(arr).item()

def first(x):
    return x[0]

# Custom functions pass through
func = aggregation_resolver.lookup(first)
arr = torch.tensor([1.0, 2.0, 3.0, 10.0], dtype=torch.float)
assert 1.0 == func(arr).item()
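
The same hint pattern can parametrize a module. A minimal sketch, continuing from the imports above; the HintedPooling class is purely illustrative and not part of the library:

from typing import Callable, Optional

from torch import nn

class HintedPooling(nn.Module):
    """Reduce a tensor with a resolvable aggregation function."""

    def __init__(self, aggregation: Optional[str] = None):
        super().__init__()
        # None falls back to the resolver's default (torch.mean);
        # a custom callable would also pass through
        self.aggregation: Callable = aggregation_resolver.lookup(aggregation)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        return self.aggregation(x)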
initializer_resolver = <class_resolver.func.FunctionResolver object>

A resolver for torch.nn.init functions.

import torch
from class_resolver.contrib.torch import initializer_resolver
from torch import nn
from torch.nn import functional as F

class TwoLayerPerceptron(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        initializer=nn.init.xavier_normal_,
    ):
        super().__init__()
        self.layer_1 = nn.Linear(in_features, hidden_features)
        self.layer_2 = nn.Linear(hidden_features, out_features)

        initializer = initializer_resolver.lookup(initializer)
        initializer(self.layer_1.weight)
        initializer(self.layer_2.weight)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        return x
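
The initializer can then be swapped at construction time, either with the function itself or by name. A hypothetical usage sketch (the string form relies on the resolver's usual name normalization):

# Pass the init function directly; as in the aggregation example, callables pass through lookup()
model = TwoLayerPerceptron(2, 10, 1, initializer=nn.init.kaiming_uniform_)

# ...or refer to it by name
model = TwoLayerPerceptron(2, 10, 1, initializer="kaiming_uniform")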
lr_scheduler_resolver = <class_resolver.api.ClassResolver object>

A resolver for learning rate schedulers.

Borrowing from the PyTorch documentation’s example on how to adjust the learning rate, the following example shows how a training loop can first be turned into a function and then parametrized to accept an LRScheduler hint.

from class_resolver import Hint, OptionalKwargs
from class_resolver.contrib.torch import lr_scheduler_resolver
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

dataset = ...
loss_fn = ...

def train(
    model: nn.Module,
    scheduler: Hint[LRScheduler] = "exponential",
    scheduler_kwargs: OptionalKwargs = None,
):
    optimizer = SGD(params=model.parameters(), lr=0.1)
    scheduler = lr_scheduler_resolver.make(scheduler, scheduler_kwargs, optimizer=optimizer)

    for epoch in range(20):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        scheduler.step()

    return model
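
A hypothetical call then selects the scheduler by name and forwards its keyword arguments (torch.optim.lr_scheduler.ExponentialLR requires gamma):

model = nn.Linear(8, 1)
trained = train(model, scheduler="exponential", scheduler_kwargs=dict(gamma=0.9))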
margin_activation_resolver = <class_resolver.api.ClassResolver object>

A resolver for a subset of torch.nn.modules.activation classes.

This resolver serves the same purpose as activation_resolver but is explicitly limited to torch.nn.ReLU and torch.nn.Softplus, for scenarios where a margin-style activation is appropriate.
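
For example, a minimal sketch of resolving one of the two permitted activations:

from class_resolver.contrib.torch import margin_activation_resolver
from torch import nn

activation = margin_activation_resolver.make("softplus")
assert isinstance(activation, nn.Softplus)

# Keyword arguments are forwarded to the class's constructor
activation = margin_activation_resolver.make("softplus", beta=0.5)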

optimizer_resolver = <class_resolver.api.ClassResolver object>

A resolver for torch.optim.Optimizer classes.

from class_resolver import Hint, OptionalKwargs
from class_resolver.contrib.torch import optimizer_resolver
from torch import nn
from torch.optim import Optimizer

dataset = ...
loss_fn = ...

def train(
    model: nn.Module,
    optimizer: Hint[Optimizer] = "adam",
    optimizer_kwargs: OptionalKwargs = None,
):
    optimizer = optimizer_resolver.make(
        optimizer, optimizer_kwargs, params=model.parameters(),
    )

    for epoch in range(20):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

    return model
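
Since string lookups are normalized, a hypothetical call can name the optimizer case-insensitively and pass its keyword arguments through:

model = nn.Linear(8, 1)
trained = train(model, optimizer="Adam", optimizer_kwargs=dict(lr=1e-3))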