PyTorch
PyTorch is a tensor and autograd library widely used for machine learning.
The class-resolver package
provides several class resolvers and function resolvers
that make it easier to parametrize models and training loops.
- activation_resolver = <class_resolver.api.ClassResolver object>
A resolver for
torch.nn.modules.activation
classes.

import torch
from class_resolver import Hint
from class_resolver.contrib.torch import activation_resolver
from more_itertools import pairwise
from torch import nn
from torch.nn import functional as F


class TwoLayerPerceptron(nn.Module):
    def __init__(
        self,
        dims: list[int],
        activation: Hint[nn.Module] = None,
    ):
        super().__init__()
        layers = []
        for in_features, out_features in pairwise(dims):
            layers.extend((
                nn.Linear(in_features, out_features),
                activation_resolver.make(activation),
            ))
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        return self.layers(x)
- aggregation_resolver = <class_resolver.func.FunctionResolver object>
A resolver for common aggregation functions in PyTorch including the following functions:
The default value is
torch.mean()
. This resolver can be used as follows:

import torch
from class_resolver.contrib.torch import aggregation_resolver

# Lookup with string
func = aggregation_resolver.lookup("max")
arr = torch.tensor([1, 2, 3, 10], dtype=torch.float)
assert 10.0 == func(arr).item()

# Default lookup gives mean
func = aggregation_resolver.lookup(None)
arr = torch.tensor([1.0, 2.0, 3.0, 10.0], dtype=torch.float)
assert 4.0 == func(arr).item()


def first(x):
    return x[0]


# Custom functions pass through
func = aggregation_resolver.lookup(first)
arr = torch.tensor([1.0, 2.0, 3.0, 10.0], dtype=torch.float)
assert 1.0 == func(arr).item()
- initializer_resolver = <class_resolver.func.FunctionResolver object>
A resolver for
torch.nn.init
functions.

import torch
from class_resolver.contrib.torch import initializer_resolver
from torch import nn
from torch.nn import functional as F


class TwoLayerPerceptron(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        initializer=nn.init.xavier_normal_,
    ):
        super().__init__()
        self.layer_1 = nn.Linear(in_features, hidden_features)
        self.layer_2 = nn.Linear(hidden_features, out_features)
        initializer = initializer_resolver.lookup(initializer)
        initializer(self.layer_1.weight)
        initializer(self.layer_2.weight)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        return x
- lr_scheduler_resolver = <class_resolver.api.ClassResolver object>
A resolver for learning rate schedulers.
Borrowing from the PyTorch documentation’s example on how to adjust the learning rate, the following example shows how a training loop can first be turned into a function and then parametrized to accept an LRScheduler hint.
from class_resolver import Hint, OptionalKwargs
from class_resolver.contrib.torch import lr_scheduler_resolver
from torch import Parameter, nn
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler

dataset = ...
loss_fn = ...


def train(
    model: nn.Module,
    scheduler: Hint[LRScheduler] = "exponential",
    # ExponentialLR requires a ``gamma``, e.g. scheduler_kwargs=dict(gamma=0.9)
    scheduler_kwargs: OptionalKwargs = None,
):
    optimizer = SGD(params=model.parameters(), lr=0.1)
    scheduler = lr_scheduler_resolver.make(scheduler, scheduler_kwargs, optimizer=optimizer)
    for epoch in range(20):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
        # step the learning rate scheduler once per epoch, after the optimizer updates
        scheduler.step()
    return model
- margin_activation_resolver = <class_resolver.api.ClassResolver object>
A resolver for a subset of
torch.nn.modules.activation
classes. This resolver fulfills the same idea as
activation_resolver
but it is explicitly limited to torch.nn.ReLU
and torch.nn.Softplus
for certain scenarios where a margin-style activation is appropriate.
- optimizer_resolver = <class_resolver.api.ClassResolver object>
A resolver for
torch.optim.Optimizer
classes.

from class_resolver import Hint, OptionalKwargs
from class_resolver.contrib.torch import optimizer_resolver
from torch import Parameter, nn
from torch.optim import Optimizer

dataset = ...
loss_fn = ...


def train(
    model: nn.Module,
    optimizer: Hint[Optimizer] = "adam",
    optimizer_kwargs: OptionalKwargs = None,
):
    optimizer = optimizer_resolver.make(
        optimizer,
        optimizer_kwargs,
        params=model.parameters(),
    )
    for epoch in range(20):
        for input, target in dataset:
            optimizer.zero_grad()
            output = model(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()
    return model