Introducing fasttransform, a Python library that makes data transformations reversible and extensible through the power of multiple dispatch.
“How did this image get misclassified?”
If you’ve ever trained a machine learning model, you know what comes next: the frustrating journey of trying to understand what your model actually saw. You dig through layers of transformations - normalizations, resizes, augmentations - only to realize you’ll need to write inverse functions just to see your data again. It’s so painful that many of us skip it altogether, debugging our models based on abstract numbers rather than actual data.
Or as OpenAI’s Greg Brockman puts it:
Greg Brockman tweet: “Manual inspection of data has probably the highest value-to-prestige ratio of any activity in machine learning.”
Let’s look at what you might be missing. Here’s a simple example using fastai:
from fastai.vision.all import *

dls = ImageDataLoaders.from_folder(
    Path("./huskies_vs_wolves/"),
    item_tfms=RandomResizedCrop(128, min_scale=0.35),
    batch_tfms=Normalize.from_stats(*imagenet_stats)
)
dls.show_batch()  # One line to see our data

learn = Learner(dls, xresnet34(n_out=2), metrics=accuracy)
learn.fit_one_cycle(5, 0.0015)
learn.show_results()  # One line to see predictions

# Two lines to see the model's biggest mistakes
interp = Interpretation.from_learner(learn)
interp.plot_top_losses(9)

With just these four lines, we discover something fascinating: our “wolf detector” isn’t detecting wolves at all - it’s detecting snow! Look at the training data: wolves in snow, huskies in forests. Then look at the predictions: the model fails whenever we flip the backgrounds. Without being able to easily visualize our data, we might never have caught this obvious flaw.
While sophisticated interpretability techniques like LIME¹ can beautifully visualize what parts of the image your model is focusing on (as shown above), often the most valuable insights come from simply being able to look at your data with your own eyes. In this case, a quick visual inspection revealed an obvious dataset bias just as well.
How does fastai do this? Well, it uses Transform – a deceptively simple yet powerful idea that’s been hiding inside fastcore’s codebase. Today, we’re excited to announce that we’ve moved it to its own library: fasttransform, because we believe its applications may go beyond machine learning.
Whether you’re working with images, text, time series, or any other data that needs processing, fasttransform offers a simple promise: if you can transform your data one way, you should be able to transform it back just as easily. No more writing inverse functions, no more losing sight of your data.
Let’s see how it works.
Ever tried to debug a machine learning pipeline by looking at your data? It usually goes something like this:
Let’s make this concrete with a simple example: normalizing an image with PyTorch:
from torchvision import transforms as T

transforms_pt = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(*imagenet_stats)
])

# Load and transform an image
img = Image.open("./huskies_vs_wolves/train/husky/husky_0.jpeg")
img_transformed = transforms_pt(img)

# Try to look at what we did...
show_image(img_transformed);

Clipping input data to the valid range for imshow with RGB data (0..1 for floats or 0..255 for integers). Got range -2.1007793..2.2489083.

Normalization is a crucial preprocessing step that scales pixel values to have similar ranges (typically mean=0 and standard deviation=1), which helps neural networks train more effectively.
However, the normalization doesn’t really make this picture suitable for inspection with human eyes. To fix this, we need to manually write an inverse transform:
def decode_pt(tensor, mean, std):
    """Decode a normalized PyTorch tensor back to RGB range"""
    out = tensor.clone()  # Clone to avoid modifying original
    for t, m, s in zip(out, mean, std): t.mul_(s).add_(m)  # Denormalize
    out = out.mul(255).clamp(0, 255).byte()  # Scale back to RGB
    return out

img_decoded = decode_pt(img_transformed, *imagenet_stats)
show_image(img_decoded);

And this is not some obscure problem. It has been a pain point for many ML practitioners for years.
And this was just for a simple normalization. In real projects, you’re probably dealing with:
Each transformation adds another layer of complexity to unwind. And here’s the worst part: because it’s so painful to look at your transformed data, many of us just… don’t. We end up debugging our models based on abstract numbers rather than actual data, hoping our transformations are doing what we think they’re doing.
Remember how easy it was in our fastai example to see exactly what the model was seeing? That’s not magic - it’s the power of reversible transforms. Let’s see how fasttransform makes this possible.
Here’s how fastai handles the same pipeline as the PyTorch example from the previous section:
from fastai.vision.all import *

transforms_ft = Pipeline([
    PILImage.create,
    Resize(256,method="squish"),
    Resize(224,method="crop"),
    ToTensor(),
    IntToFloatTensor(),
    Normalize.from_stats(*imagenet_stats, cuda=False)
])

# Transform our image
fpath = Path("./huskies_vs_wolves/train/husky/husky_0.jpeg")
img_transformed = transforms_ft(fpath)
show_image(img_transformed[0]);  # Still looks wrong...

Clipping input data to the valid range for imshow with RGB data (0..1 for floats or 0..255 for integers). Got range -2.0836544..2.2317834.

# But now the magic:
img_decoded = transforms_ft.decode(img_transformed)
show_image(img_decoded[0]);  # That's better!

That’s it. No manual inverse functions. No remembering means and standard deviations. Just .decode() and we’re back to something we can actually look at.
fasttransform brings this same capability to your own code. The key insight is that for any transformation you want to apply, you probably already know how to undo it. Let’s look at how this works internally.
The core idea behind fasttransform is simple: pair a transformation with its inverse.
Here’s how you write a reversible normalization transform:
class Normalize(Transform):
    def __init__(self, mean=None, std=None):
        self.mean = mean
        self.std = std
    def encodes(self, x): return (x-self.mean) / self.std  # forward transform
    def decodes(self, x): return x*self.std + self.mean    # inverse transform

That’s all it takes.
By defining both encodes and decodes, fasttransform automatically knows how to reverse your transformations. Compare this to our earlier PyTorch example - instead of writing separate forward and inverse functions, we keep them together where they belong.
You might notice the peculiar naming - encodes and decodes with an ‘s’. We’ll explain why later, but it has everything to do with how fasttransform handles different types of data automatically.
When you call decode(), fasttransform is smart about which transforms to reverse. Some transforms, like loading an image or resizing it, don’t need to be undone: you actually want to see what the model sees! Others, like normalization, need to be reversed to be human-readable.
How do you do this? Well, only define a .decodes method if the transform needs to be inverted!
The introduction’s plotting functions used exactly this functionality to turn the transformed inputs back into a human interpretable state.
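As a rough illustration of the idea in this section, here is a minimal sketch that uses only the standard library. It pairs each forward step with an optional inverse and, when decoding, walks the steps in reverse while skipping any step that has no inverse. The names `MiniPipeline`, `Scale`, and `Normalize1D` are hypothetical, and this is not fasttransform's actual implementation; it only mirrors the encodes/decodes pairing described above.

```python
class Scale:
    """Hypothetical transform with no inverse: maps 0..255 ints to 0..1 floats."""
    def encodes(self, xs): return [x / 255 for x in xs]
    # No `decodes` here: the scaled values are fine to inspect as they are.

class Normalize1D:
    """Hypothetical reversible transform: standardize with a fixed mean and std."""
    def __init__(self, mean, std): self.mean, self.std = mean, std
    def encodes(self, xs): return [(x - self.mean) / self.std for x in xs]
    def decodes(self, xs): return [x * self.std + self.mean for x in xs]

class MiniPipeline:
    """Apply transforms in order; decode walks them in reverse order."""
    def __init__(self, tfms): self.tfms = list(tfms)

    def __call__(self, x):
        for t in self.tfms: x = t.encodes(x)
        return x

    def decode(self, x):
        for t in reversed(self.tfms):
            # Transforms without a `decodes` method are skipped (identity).
            if hasattr(t, "decodes"): x = t.decodes(x)
        return x

pipe = MiniPipeline([Scale(), Normalize1D(mean=0.5, std=0.25)])
encoded = pipe([0, 255])
decoded = pipe.decode(encoded)
print(encoded)  # [-2.0, 2.0]: normalized, outside the 0..1 range
print(decoded)  # [0.0, 1.0]: normalization undone, scaling left in place
```

Note that `decode` does not return the original 0..255 integers: only the step that defines an inverse is reversed, which matches the "see what the model sees, but in a viewable range" behavior described above.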
We’ve seen how making transforms reversible makes it easier to look at your data. But there’s another challenge when working with transforms: different types of data need different transformations.
You see this most often where your inputs and your labels need different transforms. Here the same principle applies. We’d like to keep all those transforms in one place together because we want to be able to undo both of them. For example, we want to transform our categorical labels from strings to integers and back to strings again for human readability. But we don’t want to maintain separate transform pipelines for the inputs and the outputs.
To understand why this is a problem, let’s look at how PyTorch - one of the most popular deep learning frameworks - handles this situation. Here’s an example from the tutorial showing a typical custom dataset:
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform  # <- separate target transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:  # <- separate target transform
            label = self.target_transform(label)
        return image, label

The transforms for images and labels are separately defined and provided to the dataset class. This separation might seem reasonable at first, but it creates two problems:
Let’s see how fasttransform makes this easier.
Here’s where fasttransform’s approach shines: instead of juggling separate pipelines, it handles both your image and its label in a single pipeline. When you pass a tuple through it, each transform is applied only to the elements whose types it handles. This might sound like a small thing, but it’s a game-changer for real-world machine learning work.
Let’s see this in action.
First, we’ll create a function that loads both an image and its label:
def load_img_and_label(fp): return PILImage.create(fp), parent_label(fp)

load_img_and_label(fpath)
# Output: (PILImage mode=RGB size=375x500, 'husky')

Now for the cool part - we can use this function in our transform pipeline with just one small change. Look how clean this is:
transforms_ft = Pipeline([
    load_img_and_label,  # <-- Load img and label as a tuple
    Resize(256,method="squish"),
    Resize(224,method="crop"),
    ToTensor(),
    IntToFloatTensor(),
    Normalize.from_stats(*imagenet_stats,cuda=False)
])

out = transforms_ft(fpath)
print((out[0][0,:2,:2,:2], out[1]))
# Output: (TensorImage([[[-0.2856, -0.2856],
#                        [-0.2856, -0.2856]],
#
#                       [[ 0.5553,  0.5553],
#                        [ 0.5553,  0.5553]]]), 'husky')

But we’re not done yet! Those string labels (“husky”, “wolf”) need to be converted to numbers for our model. In PyTorch, we’d need a separate transform pipeline for this. With fasttransform, we just add another transform that only applies to strings:
class StrCategorize(Transform):
    def __init__(self, vocab):
        self.vocab = vocab
        self.s2i = {s:i for i,s in enumerate(vocab)}
        self.i2s = {i:s for i,s in enumerate(vocab)}
    def encodes(self, s:str): return self.s2i[s]
    def decodes(self, i:int): return self.i2s[i]

transforms_ft = Pipeline([
    load_img_and_label,
    Resize(256,method="squish"),
    Resize(224,method="crop"),
    ToTensor(),
    IntToFloatTensor(),
    Normalize.from_stats(*imagenet_stats, cuda=False),
    StrCategorize(vocab=['husky','wolf']),  # <-- Transform is just for the target label
])

out = transforms_ft(fpath)
print((out[0][0,:2,:2,:2], out[1]))
# Output: (TensorImage([[[-0.2856, -0.2856],
#                        [-0.2856, -0.2856]],
#
#                       [[ 0.5553,  0.5553],
#                        [ 0.5553,  0.5553]]]), 0)

And you might think “Okay, keeping transforms in one pipeline is nice, but is it really that important?”
Well, one benefit is that now you can also reverse both transforms again in one go:
rev = transforms_ft.decode(out)
print((rev[0][0,:2,:2,:2], rev[1]))
# Output: (TensorImage([[[107, 107],
#                        [107, 107]],
#
#                       [[148, 148],
#                        [148, 148]]]), 'husky')

Next, we’ll look at another example that shows why it’s crucial to keep those transforms in one place: image segmentation.
In segmentation, you’re trying to identify specific regions in an image - like finding a husky in a photo. But here’s the tricky part: both your input image AND your target mask need to be transformed in exactly the same way. And that gets tricky when you use random transforms as a form of data augmentation. To illustrate, if you apply a randomized crop to your image, then you better crop that mask in the exact same way!
Let’s see what this looks like in practice. First, we define a new function which loads both images and their corresponding mask:
fnames = list(Path("./segment_huskies/img/").glob("*"))
fn = fnames[0]

def load_img_msk(fn):
    return PILImage.create(fn), PILMask.create(fn.parent.parent / "msk" / fn.name)

img, msk = load_img_msk(fn)
show_images([img,msk])

Now, if we want to randomly crop both the image and the mask (a common augmentation technique), they need to be cropped in exactly the same way. If they’re not aligned then your whole training data becomes nonsense.
Here’s how fasttransform handles this:
transforms_ft = Pipeline([
    load_img_msk,            # <-- New load func for img and mask
    RandomResizedCrop(200),  # Applied to both img and mask
    ToTensor(),              # Applied to both img and mask
    IntToFloatTensor(),      # Only applied to img
    Normalize.from_stats(*imagenet_stats,cuda=False)  # Only applied to img
])

out = transforms_ft(fn)
out
show_images((out[0][0], out[1]))

Clipping input data to the valid range for imshow with RGB data (0..1 for floats or 0..255 for integers). Got range -1.8096584..2.64.
And voila, both the source image and the target mask have been transformed in identical ways.
If these transforms were stored in different pipelines, it would have been a lot harder to keep them in sync, especially because the crop has a randomized element.
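To make the synchronization point concrete, here is a small standard-library sketch of the underlying principle: draw the random crop parameters once, then apply the same parameters to both the image and the mask. The helper names (`sample_crop`, `crop`, `random_crop_pair`) and the toy list-of-lists data are illustrative assumptions; this is not how fastai's `RandomResizedCrop` is implemented.

```python
import random

def sample_crop(height, width, size, rng):
    """Draw the top-left corner of a size x size crop exactly once."""
    top = rng.randint(0, height - size)
    left = rng.randint(0, width - size)
    return top, left

def crop(grid, top, left, size):
    """Crop a 2D list-of-lists using already-sampled parameters."""
    return [row[left:left + size] for row in grid[top:top + size]]

def random_crop_pair(img, msk, size, rng):
    """Sample crop parameters once, then apply them to both inputs."""
    top, left = sample_crop(len(img), len(img[0]), size, rng)
    return crop(img, top, left, size), crop(msk, top, left, size)

# Toy 4x4 "image" where each pixel stores its own (row, col) coordinates,
# and a mask that marks the pixels in the right half of the image.
img = [[(r, c) for c in range(4)] for r in range(4)]
msk = [[int(c >= 2) for c in range(4)] for r in range(4)]

rng = random.Random(0)
img_crop, msk_crop = random_crop_pair(img, msk, size=2, rng=rng)

# Every mask value still describes the pixel at the same position.
aligned = all(
    msk_crop[i][j] == int(img_crop[i][j][1] >= 2)
    for i in range(2) for j in range(2)
)
print(aligned)  # True
```

Had the two crops been sampled independently, as would happen with two separate random pipelines, `aligned` would only hold by chance.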
Also, note that reversing is still just as easy:
rev = transforms_ft.decode(out)
show_images((rev[0][0], rev[1]))

At this point you might be thinking: “This is pretty great - one pipeline handling different types of data, applying only the relevant transforms where needed. But how does it actually work?”
Well, let’s dive into that next!
The secret sauce that makes Transforms only apply to relevant data types is something called multiple dispatch. Don’t worry if you haven’t heard of it before - it’s a powerful programming concept that’s popular in languages like Julia², but relatively unknown in Python.
Think of multiple dispatch like having different versions of the same function, each designed to handle specific types of data. When you call the function, Python automatically picks the right version based on what you give it.
Out of the box, Python provides a limited form of this in functools.singledispatch, which dispatches on the type of the first argument only:
from functools import singledispatch

@singledispatch
def greet(x): return "Hello stranger!"

@greet.register
def _(x:str): return f"Hello {x}!"

@greet.register
def _(x:int): return f"Hello number {x}!"

greet(None), greet("Alice"), greet(42)
# Output: ('Hello stranger!', 'Hello Alice!', 'Hello number 42!')

Multiple dispatch extends this idea to functions with multiple arguments. While Python’s built-in tools only handle single argument dispatch, the plum library provides true multiple dispatch for any number of arguments. Here’s a simple example to illustrate the concept:
from plum import dispatch

class Dog: pass
class Cat: pass

@dispatch
def greet(a: Cat, b: Dog):
    return "Hiss!"

@dispatch
def greet(a: Dog, b: Cat):
    return "Grrrr..."

# Let's try it out
cat, dog = Cat(), Dog()
print(greet(cat, dog))
print(greet(dog, cat))
# Output:
# Hiss!
# Grrrr...

Transform uses plum’s multiple dispatch capabilities internally, but the core idea is the same: the right function is called based on the runtime data types it receives. This is what allows a single pipeline to handle images, labels, masks, and other types of data.
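To connect dispatch back to pipelines, the following is a simplified sketch of how a transform could apply itself only to the tuple elements whose type it knows about, leaving everything else untouched. It uses a plain dictionary keyed by type and dispatches on a single value, so it is far simpler than plum; `TypedTransform` and its methods are hypothetical names, not part of fasttransform.

```python
class TypedTransform:
    """Hypothetical helper: apply a function only to values of a matching type."""
    def __init__(self): self.registry = {}

    def register(self, typ, fn):
        self.registry[typ] = fn
        return self

    def _lookup(self, value):
        # Walk the method resolution order so subclasses reuse a parent's handler.
        for klass in type(value).__mro__:
            if klass in self.registry: return self.registry[klass]
        return None

    def __call__(self, x):
        if isinstance(x, tuple):
            # Tuples are handled element by element.
            return tuple(self(item) for item in x)
        fn = self._lookup(x)
        # Values with no registered handler pass through unchanged.
        return x if fn is None else fn(x)

# One transform only knows about strings, the other only about floats.
categorize = TypedTransform().register(str, {"husky": 0, "wolf": 1}.__getitem__)
halve = TypedTransform().register(float, lambda v: v / 2)

sample = (0.5, "wolf")
out = halve(categorize(sample))
print(out)  # (0.25, 1)
```

Chaining the two calls behaves like a two-step pipeline: the label is categorized, the float is halved, and neither transform touches the other's data.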
There are three different ways you can define type-specific behavior in your transforms, each suited to different situations. Let’s look at each one in turn.
The simplest way to create a transform is to pass it functions directly. This is great for quick experiments or one-off transforms:
# Method 1: Direct functions
def enc_str(x:str): return f"encoded str: {x=}"
def enc_int(x:int): return f"encoded int: {x=}"

my_transform = Transform(enc=(enc_str,enc_int))
my_transform(("hello", 42))
# Output: ("encoded str: x='hello'", 'encoded int: x=42')

You might use this approach when you’re prototyping or when you don’t need to reuse the transform elsewhere in your code. But for more structured code, you’ll probably want to create a proper class…
Subclassing Transform gives you a more organized way to handle different types:
# Method 2: Create a Transform subclass
class MyTransform(Transform):
    def encodes(self, x:str): return f"encoded str: {x=}"
    def encodes(self, x:int): return f"encoded int: {x=}"

my_transform = MyTransform()
my_transform(("my str", 42))
# Output: ("encoded str: x='my str'", 'encoded int: x=42')

Notice something interesting here: in a regular Python class, you can’t define the same method multiple times. But when subclassing from Transform, you can!
The encodes method is automatically set up for multiple dispatch, so Python knows which version to call based on the input type.
But there’s one more way to define transforms, which is particularly useful when you want to extend an existing transform…
# Method 3: Extend with decorators
@MyTransform
def encodes(self, x: float): return f"encoded float: {x=}"

# Now our transform handles three types!
my_transform(("hello", 42, 6.28))
# Output: ("encoded str: x='hello'", 'encoded int: x=42', 'encoded float: x=6.28')

This decorator syntax is incredibly useful in real-world applications.
For instance, in fastai, the Normalize transform is defined in the core library to handle images, but other modules can extend it to work with new types:
# In fastai.data.transforms:
class Normalize(Transform): ...  # handles image normalization

# In fastai.tabular.core:
@Normalize
def encodes(self, x: pd.DataFrame): ...  # adds DataFrame support

This plugin-like architecture means anyone can extend existing transforms to work with new types of data, without modifying the original code. That’s the power of multiple dispatch in action!
The real power shows up when code is reused and extended in the ecosystem around fastai. Libraries like fastxtend add support for new data types without modifying the original code. Without multiple dispatch, they’d face a classic inheritance problem. Instead, with fasttransform, they can simply register new behaviors for existing transforms.
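For readers curious how Methods 2 and 3 above are even possible in Python, here is one way such behavior could be built with a metaclass. This is a sketch under stated assumptions, not fasttransform's code: `MiniTransform`, `_Meta`, and `_CollectEncodes` are hypothetical names, lookup is by exact type only (no subclass matching), and only `encodes` is handled.

```python
import inspect

class _CollectEncodes(dict):
    """Class namespace that keeps every `encodes` definition, not just the last."""
    def __init__(self):
        super().__init__()
        self.encoders = []
    def __setitem__(self, key, value):
        if key == "encodes": self.encoders.append(value)
        else: super().__setitem__(key, value)

class _Meta(type):
    @classmethod
    def __prepare__(mcls, name, bases):
        # The class body is executed against this custom namespace.
        return _CollectEncodes()

    def __new__(mcls, name, bases, namespace):
        cls = super().__new__(mcls, name, bases, dict(namespace))
        # Start from the parents' tables so subclasses inherit handlers.
        cls._table = {}
        for base in bases: cls._table.update(getattr(base, "_table", {}))
        for fn in namespace.encoders: cls._add(fn)
        return cls

    def _add(cls, fn):
        # Key each handler by the annotation of its first parameter after `self`.
        params = list(inspect.signature(fn).parameters.values())
        cls._table[params[1].annotation] = fn

    def __call__(cls, *args, **kwargs):
        # `@SomeClass` applied to a function named `encodes` registers it
        # instead of creating an instance.
        first = args[0] if len(args) == 1 and not kwargs else None
        if inspect.isfunction(first) and first.__name__ == "encodes":
            cls._add(first)
            return first
        return super().__call__(*args, **kwargs)

class MiniTransform(metaclass=_Meta):
    def __call__(self, x):
        if isinstance(x, tuple): return tuple(self(item) for item in x)
        fn = type(self)._table.get(type(x))
        return x if fn is None else fn(self, x)

class MyMini(MiniTransform):
    def encodes(self, x: str): return f"str:{x}"
    def encodes(self, x: int): return f"int:{x}"

@MyMini
def encodes(self, x: float): return f"float:{x}"

result = MyMini()(("hello", 42, 6.28, None))
print(result)  # ('str:hello', 'int:42', 'float:6.28', None)
```

The custom namespace is what lets two `encodes` definitions coexist in one class body, and the metaclass `__call__` is what turns the class itself into a registration decorator, so new types can be added from outside the class without editing it.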
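We’ve seen how fasttransform solves two fundamental problems in data processing: making transforms reversible so you can look at your data again, and applying the right transforms to each type of data within a single pipeline.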
While these ideas grew out of fastai’s deep learning needs, their applications extend far beyond. Whether you’re processing images, text, time series, or quantum states, fasttransform offers a simple promise: if you can transform your data one way, you should be able to transform it back just as easily.
Ready to try it yourself? Install fasttransform with:
pip install fasttransform

Check out our documentation for more examples and detailed API references. If you were already using fastcore’s dispatch and transform modules, then you might want to take a look at our migration guide.
We’d love to hear how you’re using fasttransform in your own projects!
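Original content statement: this article is published in the Tencent Cloud Developer Community with the author's authorization and may not be reproduced without permission.
In case of infringement, please contact cloudcommunity@tencent.com for removal.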