Redis tutorial: Text to Image AI assistant with Redis Search

Tuesday, April 25, 2023 by jakub.misilo
Introduction

In recent months, both text-to-image models and vector databases have grown significantly in popularity. These two technologies are powerful on their own, and combining them makes them even more useful. In this tutorial, I will show you how to build a simple application that helps you find similar prompts and images for text-to-image models. I encourage you to join lablab.ai’s community and learn more about how to use Redis during our artificial intelligence hackathon!

RediSearch

RediSearch is a Redis module that adds indexing and querying on top of the data stored in a Redis database. It can be used in many different ways; in this tutorial we will use it to index data and find similar prompts and images with vector similarity search.
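
To make "vector similarity search" more concrete, here is a minimal pure-Python sketch of a brute-force nearest-neighbour lookup by cosine distance. Conceptually this is what a FLAT index with the COSINE metric does: it compares the query with every stored vector. The function names and the toy three-dimensional vectors are made up for illustration; in the tutorial the vectors come from CLIP and RediSearch performs the search.

```python
import math
from typing import Dict, List, Tuple


def cosine_distance(a: List[float], b: List[float]) -> float:
    """Return 1 - cosine similarity (0 means the vectors point the same way)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)


def knn(
    query: List[float], items: Dict[str, List[float]], k: int = 1
) -> List[Tuple[str, float]]:
    """Compare the query with every stored vector and keep the k closest."""
    scored = [(key, cosine_distance(query, vec)) for key, vec in items.items()]
    return sorted(scored, key=lambda pair: pair[1])[:k]


# toy "embeddings" (hypothetical values, not real CLIP output)
toy_index = {
    "dog.jpg": [0.9, 0.1, 0.0],
    "cat.jpg": [0.1, 0.9, 0.0],
    "car.jpg": [0.0, 0.1, 0.9],
}

nearest = knn([1.0, 0.2, 0.0], toy_index, k=2)
print(nearest)
```

The closest entry comes first, so taking the first element of the result gives the single best match.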

CLIP

CLIP is a neural network that learns visual concepts from natural language supervision. It is trained on a variety of image-text pairs, and can be used to predict the most likely image for a given text description, or the most likely text description for a given image. We will use this to find similar prompts and images based on the description we entered or the image we provided.

Coding

Okay, we can start coding. Our application will consist of two parts:

  • API
  • Streamlit Application (UI).

Redis Database

First, we need a Redis database. I will use Redis Cloud, but you can also run Redis locally, for example from a Docker image that includes the RediSearch module (such as Redis Stack). Either way, you can get started with Redis for free.

Data

For the purpose of this project, we will use the popular Flickr8k dataset. It is readily available online, for example on Kaggle.

Dependencies

To start our project, I suggest creating a proper file structure. Let's create the main directory.

mkdir t2i-assistant-redis


cd t2i-assistant-redis

Now we can create a virtual environment and install all the necessary dependencies.

python3 -m venv venv


# for linux/mac
source venv/bin/activate


# for windows
./venv/Scripts/activate

Let's create a requirements.txt file with the following content:

redis
fastapi
python-multipart
uvicorn[standard]
Pillow
transformers
open_clip_torch
torch
streamlit
requests
pandas
numpy

Install all the dependencies:

pip install -r requirements.txt

Now we can prepare the rest of the files. My folder structure looks like this:

├── data
│   ├── captions.csv
│   ├── Images
│       ├── <image_name>.jpg
├── src
│   ├── model
│       ├── __init__.py
│       ├── clip.py
│   ├── utils
│       ├── __init__.py
│       ├── data.py
│   ├── main.py
│   ├── streamlit.py
├── venv (virtual environment)
├── requirements.txt

Let's start coding!

Model

I suggest starting with the model that will embed our images and captions. Let's do it in the src/model/clip.py file. First, we need to import all the necessary dependencies.

from typing import List


import open_clip
import torch
from PIL import Image

We can prepare a class for our model, and then implement some methods that will allow us to use its functionalities in a simpler way. I will use LAION AI's implementation of CLIP. You can find it on Hugging Face.

class CLIP:
    def __init__(
        self, model_name="hf-hub:laion/CLIP-ViT-H-14-laion2B-s32B-b79K", device="cpu"
    ):
        model, _, preprocess_img = open_clip.create_model_and_transforms(model_name)


        self.device = device


        self.model = model.to(self.device)
        self.preprocess_img = preprocess_img
        self.tokenizer = open_clip.get_tokenizer(model_name)


    def encode_image(self, image: Image.Image | List[Image.Image], normalize=True):
        processed_img = (
            torch.stack([self.preprocess_img(img).to(self.device) for img in image])
            if isinstance(image, list)
            else self.preprocess_img(image).to(self.device)
        )


        if processed_img.dim() == 3:
            processed_img = processed_img.unsqueeze(0)


        image_features = self.model.encode_image(processed_img)


        if normalize:
            image_features /= image_features.norm(dim=-1, keepdim=True)


        return image_features


    def encode_text(self, text: str | List[str], normalize=True):
        text = self.tokenizer(text).to(self.device)


        text_features = self.model.encode_text(text)


        if normalize:
            text_features /= text_features.norm(dim=-1, keepdim=True)


        return text_features
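
The normalize=True option above divides each embedding by its length. A small sketch of why that is useful: once two vectors have length 1, their plain dot product is already their cosine similarity. The helper names and the two-dimensional vectors below are illustrative assumptions, not part of the CLIP class.

```python
import math
from typing import List


def l2_normalize(vec: List[float]) -> List[float]:
    """Scale a vector so that its Euclidean length becomes 1."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def dot(a: List[float], b: List[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


# toy vectors standing in for a text and an image embedding
text_vec = l2_normalize([3.0, 4.0])   # becomes [0.6, 0.8]
image_vec = l2_normalize([4.0, 3.0])  # becomes [0.8, 0.6]

# for unit vectors the dot product equals the cosine similarity
similarity = dot(text_vec, image_vec)
print(round(similarity, 2))
```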

Utils

Now we can move on to the utility functions that will be needed to index our data in the Redis database. They go in the src/utils/data.py file. I will start by importing the dependencies.

import os
from uuid import uuid4


import pandas as pd
from redis.commands.search.field import TextField, VectorField

I will also define a constant, EMBEDDING_DIM. It sets the size of the vectors stored in the index, so it must match the size of the embeddings returned by the CLIP model (you can read it from the model itself or from the Hugging Face docs).

EMBEDDING_DIM = 1024

Next, we need a function that embeds our captions.

def embed_record(clip, caption):
    caption_features = clip.encode_text(caption).squeeze()


    return caption_features.cpu().detach().numpy()

Now we can create a function that will index our data in the Redis database.

def index_data(redis_client, clip):
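    # drop the index left over from a previous run; on the very first run
    # no index exists yet and Redis raises an error, which we can ignore
    from redis.exceptions import ResponseError

    try:
        redis_client.ft().dropindex()
    except ResponseError:
        pass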


    DATA_DIR = os.path.join("data")
    df = pd.read_csv(os.path.join(DATA_DIR, "captions.csv"))


    redis_client.ft().create_index(
        [
            TextField("image"),
            TextField("caption"),
            VectorField(
                "caption_features",
                "FLAT",
                {
                    "TYPE": "FLOAT32",
                    "DIM": EMBEDDING_DIM,
                    "DISTANCE_METRIC": "COSINE",
                },
            ),
        ]
    )


    selected_data = (
        # select every 5th row (each image has 5 similar captions)
        df.iloc[::5, :]
        .apply(
            lambda x: (x["image"], x["caption"], embed_record(clip, x["caption"])),
            axis=1,
        )
        .to_numpy()
    )


    # queue all HSET commands and send them to Redis together
    pipe = redis_client.pipeline()
    for img_filename, caption, caption_features in selected_data:
        pipe.hset(
            uuid4().hex,
            mapping={
                "image": img_filename,
                "caption": caption,
                "caption_features": caption_features.tobytes(),
            },
        )
    pipe.execute()
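
The vector field is stored as raw bytes, and for a FLOAT32 index each dimension takes 4 bytes, so every blob must be exactly DIM × 4 bytes long. The sketch below illustrates that round trip with the standard library only; in the tutorial code, caption_features.tobytes() plays the same role for a float32 NumPy array. The helper names and the toy dimension of 4 are assumptions made for the example.

```python
from array import array
from typing import List, Sequence

EMBEDDING_DIM = 4  # toy size for this example; the tutorial uses 1024


def to_float32_blob(values: Sequence[float]) -> bytes:
    """Pack floats as raw 32-bit values and check the resulting size."""
    blob = array("f", values).tobytes()
    expected = EMBEDDING_DIM * 4  # FLOAT32 means 4 bytes per dimension
    if len(blob) != expected:
        raise ValueError(f"expected {expected} bytes, got {len(blob)}")
    return blob


def from_float32_blob(blob: bytes) -> List[float]:
    """Unpack raw 32-bit floats back into a Python list."""
    restored = array("f")
    restored.frombytes(blob)
    return list(restored)


blob = to_float32_blob([0.5, 0.25, -1.0, 2.0])
restored = from_float32_blob(blob)
print(len(blob), restored)
```

A size check like this can catch a mismatch between the model's embedding size and the DIM declared in the index before any data is written.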

API

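Let's move on to our API. We will implement it in the src/main.py file. We need to create two endpoints: one for image-based search and one for description-based search. But let's start with the necessary dependencies. Note that the imports below assume that src/model/__init__.py re-exports CLIP (from .clip import CLIP) and that src/utils/__init__.py re-exports index_data (from .data import index_data).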

import numpy as np
import redis
from fastapi import FastAPI, HTTPException, UploadFile, status
from PIL import Image
from pydantic import BaseModel
from redis.commands.search.query import Query


from src.model import CLIP
from src.utils import index_data

Now we can initialize the model and the Redis client. This is also a good place to index our data. Replace the host, port and password below with the values of your own database.

clip = CLIP()


redis_client = redis.Redis(
    host="redis-10292.c23738.us-east-1-mz.ec2.cloud.rlrcp.com",
    port=10292,
    password="newnative",
)


index_data(redis_client, clip)
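
Hard-coding credentials is fine for a quick demo, but you may prefer to keep them out of the source code. Here is a minimal sketch that reads the connection settings from environment variables. The variable names REDIS_HOST, REDIS_PORT and REDIS_PASSWORD and the helper redis_settings are my own assumptions, not something Redis or redis-py defines.

```python
import os
from typing import Any, Dict, Mapping


def redis_settings(env: Mapping[str, str] = os.environ) -> Dict[str, Any]:
    """Collect keyword arguments for redis.Redis from environment variables."""
    return {
        "host": env.get("REDIS_HOST", "localhost"),
        "port": int(env.get("REDIS_PORT", "6379")),
        "password": env.get("REDIS_PASSWORD"),  # None when the variable is unset
    }


# example with an explicit mapping instead of the real environment
example = redis_settings(
    {"REDIS_HOST": "example.invalid", "REDIS_PORT": "10292", "REDIS_PASSWORD": "secret"}
)
print(example["host"], example["port"])

# in src/main.py this could then be used as:
# redis_client = redis.Redis(**redis_settings())
```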

The last thing to prepare before implementing the API is a function that queries the index for the most similar records.

def query_image(caption_features: np.ndarray, n=1):
    if caption_features.dtype != np.float32:
        raise TypeError("caption_features must be of type float32")


    query = (
        Query(f"*=>[KNN {n} @caption_features $caption_features]")
        .return_fields("image", "caption")
        .dialect(2)
    )


    result = redis_client.ft().search(
        query=query, query_params={"caption_features": caption_features.tobytes()}
    )


    return result.docs
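
The query string above has three parts: * selects all documents as candidates, KNN n asks for the n nearest neighbours, and $caption_features is a placeholder filled in from query_params. As a small sketch, the helper below (a hypothetical name) builds such a string and also adds an AS alias, which makes the distance available as a field you can return or sort by.

```python
def knn_query_string(
    field: str, k: int, param: str, score_alias: str = "score"
) -> str:
    """Build a RediSearch KNN query string (requires query dialect 2)."""
    if k < 1:
        raise ValueError("k must be at least 1")
    # "*" = no pre-filter, "$param" = placeholder resolved via query_params
    return f"*=>[KNN {k} @{field} ${param} AS {score_alias}]"


query_string = knn_query_string("caption_features", 3, "caption_features")
print(query_string)

# possible use with redis-py (not executed here):
# Query(query_string).sort_by("score").return_fields("image", "caption", "score").dialect(2)
```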

The time has come to implement the API. We need to create two endpoints:

  • one for processing an image
  • one for processing a description.

Both should return the caption and the image path of the record most similar to the input.

My code will look like the following:

class SearchBody(BaseModel):
    description: str




app = FastAPI()




@app.post("/search/image/")
async def search_by_image(image: UploadFile):
    # check if image is valid
    if not image.content_type.startswith("image/"):
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail="File is not an image",
        )


    image = Image.open(image.file)


    # embed image using CLIP
    img_features = clip.encode_image(image)


    img_features = img_features.squeeze().cpu().detach().numpy().astype(np.float32)


    # search for similar images/prompts
    result = query_image(img_features)


    result = result[0]


    return {
        "image": result["image"],
        "caption": result["caption"],
    }




@app.post("/search/description/")
async def search_description(body: SearchBody):
    # embed description using CLIP
    caption_features = clip.encode_text(body.description)


    # cast to float32
    caption_features = (
        caption_features.squeeze().cpu().detach().numpy().astype(np.float32)
    )


    # search for similar images/prompts
    result = query_image(caption_features)


    result = result[0]


    return {
        "image": result["image"],
        "caption": result["caption"],
    }
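
Both endpoints take result[0] directly, which fails with an IndexError if the index is empty. A minimal sketch of a guard, assuming the returned documents expose their fields as attributes (as the redis-py result documents do); best_match is a hypothetical helper, and SimpleNamespace only stands in for a real document here.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional, Sequence


def best_match(docs: Sequence[Any]) -> Optional[Dict[str, str]]:
    """Return the top document as a plain dict, or None if nothing was found."""
    if not docs:
        return None
    top = docs[0]
    return {"image": top.image, "caption": top.caption}


# stand-in for a search result with a single document
fake_docs = [SimpleNamespace(image="dogs.jpg", caption="Two dogs are playing")]

found = best_match(fake_docs)
missing = best_match([])
print(found, missing)
```

Inside an endpoint, a None result could then be turned into an HTTPException with status 404 instead of an unhandled error.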

To run our API, we can use the following command in the terminal:

uvicorn src.main:app --host 0.0.0.0 --port 8000

UI

The last part of our application is the UI implementation. For this we will use Streamlit. We will create a simple interface that will consist of Text Input, File Input (for images) and submit button.

Let's do it!

import json
import os


import requests


import streamlit as st


# Add a prompt to the app
prompt = st.text_input("Prompt")


# Add file uploader to the app
image = st.file_uploader("Upload an image")


# Add a button to the app
button = st.button("Find similar images/prompts")


# when the button is clicked
if button:
    # if the user uploaded an image
    if image:
        URL = "http://localhost:8000/search/image"
        IMG_EXT = ["jpg", "jpeg", "png"]


        file_extension = image.name.split(".")[-1].lower()

        # stop early instead of sending an unsupported file to the API
        if file_extension not in IMG_EXT:
            st.error("Invalid file extension")
            st.stop()


        # send the image to the server (form data)
        files = {
            "image": (
                image.name,
                image.read(),
                f"image/{file_extension}",
            ),
        }


        response = requests.post(
            URL,
            files=files,
        )


        # display the response
        res = response.json()


        caption = res["caption"]
        image = os.path.join("data", "Images", res["image"])


        st.image(image, caption=caption)


    if prompt and not image:
        URL = "http://localhost:8000/search/description"
        # json= serializes the body and sets the Content-Type header for us
        response = requests.post(
            URL,
            json={"description": prompt},
        )


        res = response.json()


        caption = res["caption"]
        image = os.path.join("data", "Images", res["image"])


        st.image(image, caption=caption)
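
One detail in the upload code: building the content type as image/ plus the extension yields image/jpg for .jpg files, while the registered MIME type is image/jpeg. The API only checks the image/ prefix, so it works, but here is a small sketch of a stricter alternative based on the standard mimetypes module. ALLOWED_TYPES and image_mime_type are hypothetical names.

```python
import mimetypes
from typing import Optional

ALLOWED_TYPES = {"image/jpeg", "image/png"}


def image_mime_type(filename: str) -> Optional[str]:
    """Return the MIME type for a supported image file name, else None."""
    mime, _encoding = mimetypes.guess_type(filename)
    return mime if mime in ALLOWED_TYPES else None


print(image_mime_type("dog.jpg"))
print(image_mime_type("notes.txt"))
```

The returned value could be passed as the third element of the files tuple, and a None result could trigger the same error message as an invalid extension.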

Okay, I think we are ready to go.

Let's run our application.

streamlit run src/streamlit.py

Conclusion

Let's check how our application works. We can do so by entering a description or uploading an image.

A black dog and a spotted dog are fighting

As you can see the result is pretty cool!

If you have made it to this point, well done! I hope you learnt a lot. I encourage you to explore other technologies. Maybe you want to build a GPT-3 app? Or just upgrade your project with it? Or maybe you want to get inspired and build a Cohere app. The potential is limitless with the power of AI!

Project repository

Thank you for your time! - Jakub Misiło @newnative