April 27, 2026
Image Embeddings: Tutorial & Examples

Learn about the concept of image embeddings, their various use cases, and best practices for handling them in data processing workflows.

by Daft Team

With the advent of vision-language models, analyzing image and text data together to extract deeper insights has become increasingly popular. Combining multiple data types helps organizations uncover complex patterns that single-modality analysis cannot reveal.

That said, data processing frameworks that support multiple modalities through a unified interface are rare and come with a steep learning curve. A key element of multimodal analysis is image embeddings, which represent the meaning of images as numerical vectors. This article explains what image embeddings are, their various use cases, and how to handle them in a data processing workflow.

Summary of key image embeddings concepts

| Concept | Description |
| --- | --- |
| Image embeddings | An image embedding is a compact vector representation of an image that captures semantic meaning, typically in a lower-dimensional space than the original raw pixel data. |
| Image embedding use cases | Image embeddings are used in image search, product recommendations, caption generation, image tagging, facial recognition, anomaly detection, and visual similarity search, among other applications. |
| Handling image embeddings | Image embeddings are commonly generated and consumed using machine learning frameworks such as PyTorch and TensorFlow. |
| PyTorch for working with images | PyTorch is an open-source deep learning framework for building neural networks. Together with its torchvision library, it provides functions to manipulate images, prepare them for inference, and run PyTorch-based models. |
| Challenges with integrating PyTorch in typical data engineering workflows | Integrating PyTorch with typical data processing frameworks like Apache Spark is complex because of serialization overhead, differences in GPU management, and dependency management. |
| Best practices while dealing with image embeddings | Use pretrained embedding models that match the task at hand, evaluate embeddings quantitatively, monitor embedding drift, and cache model weights. Select a processing framework that provides a unified API for multimodal data to reduce developer effort and time to production. |

Understanding image embeddings

Embeddings, in general, are vectors in continuous high-dimensional spaces. Machine learning models consume images as pixel tensors, but embeddings provide a learned representation that is usually more compact and semantically useful for downstream tasks.

Consider images of a dog, a cat, and a kitten. The distance between the vectors representing the dog and the cat should be larger than the distance between the vectors representing the cat and the kitten, because a cat and a kitten are semantically closer.

Image embeddings thus numerically represent images, capturing their features and semantic context.
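To make the distance comparison concrete, here is a minimal sketch using cosine distance. The four-dimensional vectors are hypothetical, hand-picked values standing in for real embeddings, which typically have hundreds or thousands of dimensions (the ResNet-50 features used later in this article have 2,048, compared with 150,528 raw values in a 224x224 RGB image).

```python
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Return 1 - cosine similarity: 0 for identical directions, larger for dissimilar vectors."""
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


# Hypothetical toy "embeddings", chosen by hand for illustration only.
dog = np.array([0.9, 0.1, 0.8, 0.2])
cat = np.array([0.2, 0.9, 0.7, 0.3])
kitten = np.array([0.1, 0.95, 0.75, 0.4])

dog_cat = cosine_distance(dog, cat)
cat_kitten = cosine_distance(cat, kitten)
print(f"dog-cat distance:    {dog_cat:.3f}")
print(f"cat-kitten distance: {cat_kitten:.3f}")
```

With these values the cat-kitten distance is far smaller than the dog-cat distance; a well-trained embedding model is expected to produce the same ordering for real photos.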

Image embedding use cases

Image embeddings are used in many applications.

Image search

A user can search with a text query or an example image to fetch matching images from a repository. This is common in retail e-commerce, for example to improve apparel search. The query is converted into an embedding (text queries require a model that maps text and images into the same vector space, such as CLIP), compared against the stored embeddings of all images in the system, and the closest matches are returned.
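The lookup step can be sketched as a brute-force nearest-neighbor search over stored embeddings. This is a minimal NumPy illustration under stated assumptions: the catalog is random data standing in for real image embeddings, the 512-dimension size is arbitrary, and top_k_similar is a hypothetical helper. Systems with millions of images typically use an approximate nearest-neighbor index or a vector database rather than scoring every item.

```python
import numpy as np


def top_k_similar(query: np.ndarray, index: np.ndarray, k: int = 5) -> np.ndarray:
    """Return the row indices of the k index embeddings most similar to the query."""
    # L2-normalize both sides so that a dot product equals cosine similarity.
    q = query / np.linalg.norm(query)
    normed = index / np.linalg.norm(index, axis=1, keepdims=True)
    scores = normed @ q
    # Sort by descending similarity and keep the first k positions.
    return np.argsort(-scores)[:k]


rng = np.random.default_rng(0)
# Stand-in for the stored embeddings of 1,000 catalog images.
catalog = rng.normal(size=(1000, 512)).astype(np.float32)
# A query that is a slightly perturbed copy of item 42, e.g. a new photo of the same product.
query = catalog[42] + 0.05 * rng.normal(size=512).astype(np.float32)

print(top_k_similar(query, catalog, k=5))
```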

Caption generation

Caption generation is widely used on social media for automated content creation. It is also used in entertainment, media, and healthcare image processing. Caption generation works by combining an image embedding model with a language model. The output of the image embedding model is fed to a language model that generates text token by token.

Recommendation systems

Image embeddings are also used in product recommendation systems, where they help models find products that look similar to a customer's previous purchases. They can also serve as item features for recommendation techniques such as content-based filtering.
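A minimal content-based filtering sketch: average a customer's purchased-item embeddings into a profile vector and rank the remaining items by cosine similarity to it. The six three-dimensional item vectors and the recommend helper are hypothetical; a real system would use model-generated embeddings and would usually blend this score with other signals such as sales data.

```python
import numpy as np


def recommend(item_embeddings: np.ndarray, purchased: list[int], k: int = 3) -> np.ndarray:
    """Rank items by cosine similarity to the mean of the purchased items' embeddings."""
    normed = item_embeddings / np.linalg.norm(item_embeddings, axis=1, keepdims=True)
    # The profile is the mean of the purchased unit vectors; its length does not affect ranking.
    profile = normed[purchased].mean(axis=0)
    scores = normed @ profile
    scores[purchased] = -np.inf  # never recommend something already bought
    return np.argsort(-scores)[:k]


# Hypothetical embeddings: items 0-2 look alike (e.g. sneakers), 3-4 are bags, 5 is a hat.
items = np.array([
    [1.0, 0.1, 0.0],
    [0.9, 0.2, 0.1],
    [0.95, 0.0, 0.2],
    [0.1, 1.0, 0.0],
    [0.0, 0.9, 0.3],
    [0.1, 0.0, 1.0],
])
print(recommend(items, purchased=[0], k=2))
```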

Working with image embeddings

A typical enterprise data pipeline consists of numerous image-related operations, such as image ingestion, preprocessing, embedding generation, model inference, and storing the results to support downstream use cases and analytics. Traditionally, these operations are handled by a combination of data processing and machine learning frameworks.

Preprocessing

A data processing framework such as Spark can ingest image data at scale, while preprocessing steps like resizing and format conversion are often implemented with Python image libraries such as PIL inside UDFs.

Embedding generation

The next step would be to use a machine learning framework such as PyTorch or TensorFlow to generate embeddings with a model like ResNet or a Vision Transformer. The embeddings can then be stored in distributed data storage systems or vector databases. PyTorch can also be used for downstream tasks like classification, object detection, and caption generation that use image embeddings as input.

So the pipeline will begin with PySpark for reading and preparing data, then hand off image batches to PyTorch for image embedding generation, and then back to PySpark for writing and storage.

Example

To understand this better, let us consider an e-commerce platform that uses image embeddings for generating product recommendations. Such use cases require generating embeddings for millions of images. Let's assume the technology stack for this organization is on the AWS cloud.

Engineers start the pipeline by using PySpark to read images from Amazon S3 and downscale them to a smaller, fixed resolution.

PySpark is then used to join the image IDs with sales data stored in another S3 location. Sales data matters here because recommendation systems are trained on purchase history in addition to image embeddings.

The combined results are stored as a Delta Lake table. PyTorch is then used to generate image embeddings, and the results are written back to Delta Lake. The whole sequence looks as follows.

Data pipeline operation sequence

Implementing the pipeline using PySpark and PyTorch

Note: The related code examples are available in this Google Colab notebook, which shows the complete image embedding pipeline for local images. The scripts below adapt that pipeline to images stored in Amazon S3; the core concepts remain the same.

Implement the following steps

Initialize PySpark and write the preprocessing and join logic for the sales data

Initialize PySpark session using the code block below.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import BinaryType
from PIL import Image
import io

# The Delta settings below require the Delta Lake package (delta-spark) on the cluster
spark = SparkSession.builder \
    .appName("ImageDownscaleToDelta") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .getOrCreate()

Write a utility function to downscale images, and execute the downscaling. We also generate an image identifier from the image file name. This ID will be used to join with the sales data.

from pyspark.sql.functions import regexp_extract

def downscale_image(content, size=(224, 224)):
    """Decode raw image bytes, resize to `size`, and re-encode as JPEG."""
    try:
        img = Image.open(io.BytesIO(content)).convert("RGB")
        img = img.resize(size)
        buf = io.BytesIO()
        img.save(buf, format="JPEG")
        return buf.getvalue()
    except Exception:
        # Corrupt or unreadable images become nulls instead of failing the job
        return None

downscale_udf = udf(downscale_image, BinaryType())

# binaryFile yields one row per file with path, modificationTime, length, and content
images_df = spark.read.format("binaryFile") \
    .option("pathGlobFilter", "*.jpg") \
    .option("recursiveFileLookup", "true") \
    .load("s3://my-bucket/raw-images/")

# Use the file name without its .jpg extension as the image identifier
images_df = images_df.withColumn("image_id", regexp_extract("path", r"([^/]+)\.jpg$", 1))

images_df = images_df.withColumn("downscaled", downscale_udf(images_df["content"])) \
                      .select("image_id", "downscaled")

Join the sales data from another S3 folder and write the results

sales_df = spark.read.format("parquet").load("s3://my-bucket/sales-data/")
 
joined_df = images_df.join(sales_df, on="image_id", how="inner")
 
joined_df.write.format("delta") \
     .mode("overwrite") \
     .save("s3://my-bucket/delta/images-sales-combined")

The above script can be run on a Spark cluster that has the Delta Lake and Python dependencies (including Pillow) installed and access to the target S3 bucket. It does not require a GPU: decoding, resizing, and joining are CPU- and memory-bound operations that Spark can run on commodity hardware.

Read the data from S3

Start the second script with the following imports.

from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, FloatType
import pandas as pd
import torch
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
import io

Create a Spark session and read the combined Delta table that the first script wrote to S3.

spark = SparkSession.builder \
    .appName("ImageEmbeddingsToDelta") \
    .config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension") \
    .config("spark.sql.catalog.spark_catalog", "org.apache.spark.sql.delta.catalog.DeltaCatalog") \
    .getOrCreate()
 
df = spark.read.format("delta").load("s3://my-bucket/delta/images-sales-combined")

Generate embeddings with PyTorch

Initialize the PyTorch model and move it to CUDA if available.

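Note that the final fully connected (classification) layer of ResNet-50 is removed here, so the model outputs the 2048-dimensional pooled feature vector instead of class scores; that vector is the embedding.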

device = "cuda" if torch.cuda.is_available() else "cpu"
 
resnet = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])
resnet = resnet.to(device)
resnet.eval()

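Next, define the transform that converts each image into a tensor and normalizes it with the ImageNet channel means and standard deviations expected by the pretrained weights. No resize step is needed because the first script already downscaled the images to 224x224.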

transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])
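To see what this transform does numerically, the sketch below reproduces it in NumPy: T.ToTensor() turns an H x W x C uint8 image into a C x H x W float array scaled to [0, 1], and T.Normalize subtracts each channel's mean and divides by its standard deviation. The helper name to_normalized_chw is hypothetical, and the sketch is meant for illustration, not as a replacement for torchvision.

```python
import numpy as np

# ImageNet statistics, matching the values passed to T.Normalize above.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_normalized_chw(image_hwc: np.ndarray) -> np.ndarray:
    """Mimic T.ToTensor() followed by T.Normalize(mean, std) for an RGB uint8 image."""
    x = image_hwc.astype(np.float32) / 255.0  # ToTensor: scale 0..255 to 0.0..1.0
    x = (x - IMAGENET_MEAN) / IMAGENET_STD    # Normalize: per channel, broadcast over H and W
    return x.transpose(2, 0, 1)               # ToTensor: H x W x C -> C x H x W


# A black 224x224 RGB image: every value becomes -mean / std for its channel.
black = np.zeros((224, 224, 3), dtype=np.uint8)
out = to_normalized_chw(black)
print(out.shape)     # (3, 224, 224)
print(out[:, 0, 0])  # approximately -2.118, -2.036, -1.804
```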

Create a pandas user-defined function (UDF) that generates one embedding per image.

@pandas_udf(ArrayType(FloatType()))
def generate_embeddings_udf(content_series: pd.Series) -> pd.Series:
    embeddings = []
    with torch.no_grad():  # inference only, so skip gradient tracking
        for content in content_series:
            try:
                img = Image.open(io.BytesIO(content)).convert("RGB")
                img_t = transform(img).unsqueeze(0).to(device)  # (1, 3, 224, 224)
                emb = resnet(img_t).flatten().cpu().numpy()     # (1, 2048, 1, 1) -> (2048,)
                embeddings.append(emb.tolist())
            except Exception:
                # Null or undecodable images get a null embedding
                embeddings.append(None)
    return pd.Series(embeddings)

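In this snippet, the model is created on the driver and shipped to the executors as part of the UDF's closure. In production Spark deployments, model initialization and GPU placement usually need to be handled on the executor-side Python workers instead (for example, by loading the model once per task inside an iterator-style pandas UDF), so additional packaging and execution setup is often required.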

Generate the embeddings and write to Delta Lake using the snippet below.

embed_df = df.withColumn("embedding", generate_embeddings_udf(df["downscaled"]))
 
embed_df.write.format("delta") \
    .mode("overwrite") \
    .save("s3://my-bucket/delta/image-embeddings")

Save this second script separately and run it on a cluster with GPU-equipped executors for the best performance.

Challenges with integrating PyTorch in data engineering workflows

The above approach uses two different scripts: one that preprocesses the images and joins them with the sales data, and another that generates the embeddings. PySpark is well suited to the first part: its native connectors make it efficient at ingesting data from distributed storage such as S3, and its SQL-like DataFrame API makes joins with other datasets straightforward.

While this approach is feasible, it has limitations when deployed in production.

  • Because of the use of two scripts, the pipeline is inherently complex. Since PySpark and PyTorch have very different hardware requirements and configurations, these scripts must be handled separately during execution.
  • Data is serialized as it moves between Spark's JVM executors and the Python workers that run PyTorch, and it is written to and read back from Delta Lake between the two scripts, which adds overhead.
  • PyTorch and PySpark have distinct dependency management requirements, resulting in a complex setup and installation process.
  • Because the PyTorch stage is GPU-bound while the PySpark stages are mostly CPU- and memory-bound, it is very difficult to tune the entire pipeline so that the GPUs stay fully utilized.

Image embedding generation with Daft

An alternative to the split Spark-plus-PyTorch workflow is to use a unified data processing framework that can handle both data manipulation and model execution in one pipeline. Daft is an example of such a framework.

Initialize Daft using the code snippet below.

import daft
import torch
import torchvision.models as models
import torchvision.transforms as T
from daft import DataType, Series, col
from daft.functions import regexp_extract
 
device = "cuda" if torch.cuda.is_available() else "cpu"
 
# Optional: use Ray when you want distributed execution / GPU scheduling.
# daft.set_runner_ray()

Read images from S3, decode and resize them, and extract the image identifiers.

images_df = daft.from_glob_path("s3://my-bucket/raw-images/*.jpg")
 
# 1. Download images and resize
images_df = images_df.with_column(
    "image",
    col("path").download(on_error="null").decode_image(on_error="null", mode="RGB")
).where(col("image").not_null())
 
images_df = images_df.with_column("image", col("image").resize(224, 224))
 
# 2. Extract image_id from filename
images_df = images_df.with_column(
    "image_id",
    regexp_extract(col("path"), r"([^/]+)\.[^.]+$", 1)
)

Read sales data from S3 and join it with image identifiers.

# 3. Read sales data from S3
sales_df = daft.read_parquet("s3://my-bucket/sales-data/")
 
# 4. Join images with sales data
joined_df = images_df.join(sales_df, on="image_id", how="inner")

Define a class-based UDF that loads ResNet-50 once in its constructor and embeds images in batches.

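# 5. Define a class UDF; the model is loaded once in __init__ rather than per row
@daft.cls(gpus=1 if torch.cuda.is_available() else 0)
class ImageEmbedder:
    def __init__(self):
        weights = models.ResNet50_Weights.DEFAULT
        model = models.resnet50(weights=weights)
        model.fc = torch.nn.Identity()  # drop the classifier to output 2048-d features
        # Pick the device where the class is instantiated, which may not be the driver machine
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = model.eval().to(self.device)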
        self.transform = T.Compose(
            [
                T.ToTensor(),
                T.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )
 
    @daft.method.batch(
        return_dtype=DataType.embedding(DataType.float32(), 2048),
        batch_size=64,
    )
    def embed(self, batch: Series):
        tensors = []
        valid_idx = []
        outputs = [None] * len(batch)
 
        for idx, image in enumerate(batch.to_pylist()):
            if image is None:
                continue
            tensors.append(self.transform(image))
            valid_idx.append(idx)
 
        if not tensors:
            return outputs
 
        with torch.no_grad():
            embeddings = (
                self.model(torch.stack(tensors).to(self.device))
                .cpu()
                .numpy()
                .astype("float32")
            )
 
        for idx, embedding in zip(valid_idx, embeddings):
            outputs[idx] = embedding.tolist()
 
        return outputs
 
embedder = ImageEmbedder()

Compute the embeddings and write the final dataset to Delta Lake.

# 6. Compute embeddings
result_df = (
    joined_df
    .with_column("embedding", embedder.embed(col("image")))
    .select("image_id", "sales", "embedding")
)
 
# 7. Write final dataset to Delta Lake
result_df.write_deltalake("s3://my-bucket/delta/images-sales-embeddings", mode="overwrite")

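Developers can now save the script and run it on any infrastructure where Daft is installed; for distributed execution or GPU scheduling across a cluster, enable the Ray runner with daft.set_runner_ray().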

Benefits

This approach offers the following advantages over the two-script approach.

The Daft-based approach is simpler. Orchestration, preprocessing, joins, and embedding generation are expressed in a single Daft pipeline, with PyTorch still used for model inference inside that pipeline, so there is no context switching between separately managed Spark and PyTorch environments.

The Daft pipeline also reduces cross-framework data movement: the two-script approach writes intermediate results to Delta Lake and reads them back, whereas here preprocessing, joins, and model inference run in one pipeline.

The two-script approach runs separate scripts on different frameworks and hardware, which makes it difficult to keep GPUs busy. In the Daft pipeline, the GPU request and batch size are declared on the embedding step itself (gpus=1 in @daft.cls and batch_size=64 in @daft.method.batch), which can make resource planning and execution flow easier to manage than in a split multi-framework setup.

Dependency management is also simpler, because there is one Python environment to maintain instead of two frameworks with separate deployment and packaging requirements.

Recommendations

Images are now first-class citizens in many enterprise data platforms, and handling them efficiently helps reduce both costs and time-to-market.

Use pretrained models wherever possible

At this point, image models are fairly mature, and pretrained models are available for almost all purposes, including embedding generation, object detection, tracking, and caption generation. Pretrained models such as ResNet, Vision Transformer (ViT), CLIP, and DINOv2 are common choices for image embedding workloads.

Unless you have an unusual requirement, such as object categories the available models were never trained on or complex tracking scenarios, there is rarely a need to train a model from scratch. You can download pretrained models and use them as is, or fine-tune them when a use case demands it.

Ensure optimum utilization of GPUs

When inference or training requires high throughput, the pipeline feeding the GPUs must have enough resources to keep them busy; a GPU waiting for input wastes expensive capacity. In many cases, CPU-side work such as downloading, decoding, and preprocessing images becomes the bottleneck and cannot supply the GPUs with enough data, leaving them underutilized.
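One common way to keep a GPU fed is to overlap CPU-side preparation of the next batch with model execution on the current one. The sketch below shows the idea with a bounded producer-consumer buffer from the Python standard library; prefetch and slow_preprocess are hypothetical helpers, and the sleep stands in for real download and decode work. Threads help when that work is I/O-bound or runs in libraries that release the GIL; CPU-heavy pure-Python preprocessing would need processes instead. In practice, PyTorch's DataLoader (num_workers, prefetch_factor) and batch-oriented engines such as Daft provide their own mechanisms for this.

```python
import queue
import threading
import time
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
U = TypeVar("U")


def prefetch(items: Iterable[T], preprocess: Callable[[T], U], depth: int = 4) -> Iterator[U]:
    """Yield preprocess(item) for each item, computed ahead of time in a background thread.

    At most `depth` finished results wait in the buffer, so memory stays bounded while
    the consumer (for example, a GPU model call) works on the previous batch.
    """
    buffer: "queue.Queue[tuple[str, object]]" = queue.Queue(maxsize=depth)

    def producer() -> None:
        try:
            for item in items:
                buffer.put(("item", preprocess(item)))  # blocks while the buffer is full
        except BaseException as exc:  # hand errors to the consumer instead of losing them
            buffer.put(("error", exc))
            return
        buffer.put(("done", None))

    # Daemon thread: if the consumer stops early, a blocked producer won't keep the process alive.
    threading.Thread(target=producer, daemon=True).start()
    while True:
        kind, value = buffer.get()
        if kind == "done":
            return
        if kind == "error":
            raise value  # type: ignore[misc]
        yield value  # type: ignore[misc]


def slow_preprocess(batch_id: int) -> list[int]:
    time.sleep(0.01)  # stand-in for download + decode + resize
    return [batch_id] * 4


for batch in prefetch(range(3), slow_preprocess, depth=2):
    print(batch)  # a GPU model call would go here
```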

Analyze frameworks according to their strengths

Many frameworks exist in the data processing and machine learning space, each with its own strengths and weaknesses, so evaluate them carefully against your workload. For example, PyTorch can handle both training and inference, but inference- or serving-specific runtimes such as ONNX Runtime can perform better for inference-only use cases. It also helps to choose frameworks that run on top of these ML frameworks to enable distributed processing and further optimization. Daft, for instance, provides unified multimodal APIs and natively supports vectorized processing with data-type-specific batch configurations.

Use task-specific quantitative evaluation of embedding models

Open-source embedding models differ in accuracy, latency, parameter count, and other properties, so choosing one that suits your use case is important for controlling the cost and time of embedding generation. Embeddings can be evaluated quantitatively by measuring pairwise similarity and distance between objects: pairs known to be related (for example, two photos of the same product) should score as more similar than unrelated pairs. Another method is to cluster the embeddings and calculate the silhouette score to check that the clusters are well separated.
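The silhouette check can be sketched in NumPy. For each embedding, a is the mean distance to the other members of its own cluster and b is the mean distance to the nearest other cluster; the score (b - a) / max(a, b) is averaged over all points, ranges from -1 to 1, and approaches 1 for well-separated clusters. The data below is synthetic, and the labels could come from a clustering algorithm such as k-means or from known categories. In practice you would likely call scikit-learn's sklearn.metrics.silhouette_score, which implements the same definition.

```python
import numpy as np


def silhouette_score(X: np.ndarray, labels: np.ndarray) -> float:
    """Mean silhouette coefficient with Euclidean distance; needs at least two clusters.

    Builds the full n x n distance matrix, so it is only suitable for modest sample sizes.
    """
    X = np.asarray(X, dtype=np.float64)
    labels = np.asarray(labels)
    dist = np.sqrt(((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1))
    clusters = np.unique(labels)
    scores = np.zeros(len(X))  # singleton clusters keep a score of 0 by convention
    for i in range(len(X)):
        own = labels == labels[i]
        n_own = int(own.sum()) - 1  # exclude the point itself
        if n_own == 0:
            continue
        a = dist[i, own].sum() / n_own  # the self-distance is 0, so summing it is harmless
        b = min(dist[i, labels == c].mean() for c in clusters if c != labels[i])
        scores[i] = (b - a) / max(a, b)
    return float(scores.mean())


rng = np.random.default_rng(0)
# Two synthetic groups of 16-dimensional "embeddings" with clearly different centers.
X = np.vstack([rng.normal(0.0, 0.1, size=(20, 16)), rng.normal(1.0, 0.1, size=(20, 16))])
labels = np.array([0] * 20 + [1] * 20)

good = silhouette_score(X, labels)
bad = silhouette_score(X, rng.permutation(labels))  # labels that ignore the structure
print(f"true groups: {good:.2f}, shuffled labels: {bad:.2f}")
```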

Monitor embedding drift

Embedding drift refers to changes in the distribution of embeddings over time. It can be caused by shifts in the input data or the environment in which images are captured, or by concept drift, where the semantics of your classes or categories change over time. Similarity measures, such as the cosine similarity between the mean embedding of a reference period and that of recent data, and population-level statistics, such as the mean and variance of each dimension, can be used to check whether embeddings are drifting.
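A minimal drift check along these lines compares a reference window of embeddings with a recent one. The drift_report helper, the synthetic data, and the 64-dimension size are hypothetical; alert thresholds would need to be calibrated on your own data, for example by measuring how much these numbers vary between windows known to be stable.

```python
import numpy as np


def drift_report(reference: np.ndarray, current: np.ndarray, eps: float = 1e-8) -> dict:
    """Summarize how a current batch of embeddings differs from a reference batch (rows = samples)."""
    ref_mean, cur_mean = reference.mean(axis=0), current.mean(axis=0)
    # Cosine similarity between the two centroids: 1.0 means the average direction is unchanged.
    centroid_cosine = float(ref_mean @ cur_mean / (np.linalg.norm(ref_mean) * np.linalg.norm(cur_mean)))
    # Average per-dimension shift of the mean, in units of the reference standard deviation.
    mean_shift = float(np.mean(np.abs(cur_mean - ref_mean) / (reference.std(axis=0) + eps)))
    # Average ratio of current to reference variance per dimension (about 1.0 when stable).
    variance_ratio = float(np.mean(current.var(axis=0) / (reference.var(axis=0) + eps)))
    return {"centroid_cosine": centroid_cosine, "mean_shift_std": mean_shift, "variance_ratio": variance_ratio}


rng = np.random.default_rng(1)
reference = rng.normal(0.5, 0.2, size=(500, 64))
stable = rng.normal(0.5, 0.2, size=(500, 64))
shifted = rng.normal(0.5, 0.2, size=(500, 64))
shifted[:, :32] += 0.3  # simulate a change in the input data affecting half the dimensions

print(drift_report(reference, stable))
print(drift_report(reference, shifted))
```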

Last thoughts

Image embeddings are key to representing the semantic meaning of images. They are critical for many computer vision use cases, as they let models work with what an image depicts rather than its raw pixels.

Working with image embeddings often involves machine learning frameworks such as PyTorch or TensorFlow, although teams may also use inference runtimes, hosted APIs, or vector databases depending on the workflow. Typical operations in an enterprise data pipeline include both data manipulation, such as joining and aggregation, and model inference. Data manipulation operations are generally executed using distributed processing frameworks like Spark or managed Spark platforms such as Databricks, which have very different infrastructure and dependency requirements compared to deep learning frameworks.

Frameworks like Daft, which handle both kinds of operations through a unified multimodal API with data-type-specific batch optimizations, can streamline the implementation of such heterogeneous AI data pipelines.
