AI21 Labs + Stable Diffusion Tutorial: Beginner friendly Tutorial - How to build your app with AI21 Labs, adding Stable Diffusion integration.

Wednesday, July 26, 2023 by abdibrokhim
Introduction

Stable Diffusion is a text-to-image latent diffusion model that generates images from a text prompt by iteratively denoising random noise over a number of inference steps. AI21 Studio is a platform that gives developers and businesses access to natural language processing (NLP) tools powered by AI21 Labs' language models.

What are we going to do?

In this tutorial, we will build a simple, fun app that writes engaging tweets with AI21 Labs and creates cover images for them with Stable Diffusion. We will build the app with Streamlit, an open-source Python library that makes it easy to build custom web apps for machine learning and data science.

Prerequisites

Download the version of Visual Studio Code that matches your operating system, or use any other code editor, such as IntelliJ IDEA or PyCharm.

To use AI21 Studio, we need an API key. Go to AI21 Studio and sign up for an account. Once your account is created, go to the Dashboard and click your profile picture in the top right corner.

Select API Key from the dropdown menu.

Click the copy icon to copy your API key, and save it somewhere safe.

To use Stable Diffusion, we also need an API key. Go to DreamStudio and sign up for an account. Then open the Dashboard and click the + button to create a new API key.

Click the copy icon to copy your API key, and save it somewhere safe.

To deploy our app, we need a Streamlit account. Go to Streamlit and click the Sign up button.

I recommend signing up with GitHub, since it will let us deploy the app in seconds later on. Click the Sign up with GitHub button.

Getting Started

Create a brand new project

Open Visual Studio Code, then create a new folder named ai21-sd-tutorial and move into it:

mkdir ai21-sd-tutorial
cd ai21-sd-tutorial

Create a virtual environment and activate it

Next, create a virtual environment and activate it by running the following commands in your terminal:

python3 -m venv venv

# on MacOS and Linux:
source venv/bin/activate

# on Windows:
venv\Scripts\activate
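
To confirm that the virtual environment is really active before installing anything, a small check like the one below can help. It is only a minimal sketch: the function name in_virtualenv is our own and not part of any library, and it relies on the standard behavior that sys.prefix differs from sys.base_prefix inside a venv.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the interpreter runs inside a venv-style environment."""
    # Inside a venv, sys.prefix points at the environment, while
    # sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("Virtual environment active:", in_virtualenv())
```

If this prints False right after activation, the terminal is probably still using the system interpreter.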

Install all dependencies

Now, install the dependencies by running the following commands in your terminal:

pip install streamlit
pip install ai21
pip install stability-sdk
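
After installing, it can be worth checking that the modules the app imports are actually available in the active environment. The sketch below uses only the standard library; the helper name missing_modules is hypothetical. Note that import names can differ from pip package names (stability-sdk is imported as stability_sdk in the code that follows).

```python
from importlib.util import find_spec


def missing_modules(module_names):
    """Return the names from module_names that cannot be found for import."""
    return [name for name in module_names if find_spec(name) is None]


if __name__ == "__main__":
    # Import names used by the files in this tutorial.
    required = ["streamlit", "ai21", "stability_sdk", "PIL"]
    missing = missing_modules(required)
    print("Missing modules:", missing or "none")
```

An empty result suggests the environment is ready; any listed name points at a package that still needs to be installed.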

Implementing the app

First, let's create a file named stable_diffusion.py and implement a function that generates an image from text using a Stable Diffusion model. You can read more in the documentation for the text-to-image gRPC API Python client.

import os
import io
import warnings
from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation

# Our Host URL should not be prepended with "https" nor should it have a trailing slash.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'

# Sign up for an account at the following link to get an API Key.
# https://dreamstudio.ai/

# Click on the following link once you have created an account to be taken to your API Key.
# https://dreamstudio.ai/account

# The API key is passed to imagine() as an argument, so nothing is hard-coded here.


def imagine(prompt, stable_diffusion_api_key):

    try:
        img_path = ""  # Stays empty if no image artifact is produced.
        # Set up our connection to the API.
        stability_api = client.StabilityInference(
            # key=os.environ.get("STABILITY_KEY"), # API Key reference.
            key=stable_diffusion_api_key, # API Key reference.
            verbose=True, # Print debug messages.
            engine="stable-diffusion-xl-beta-v2-2-2", # Set the engine to use for generation.
            # Available engines: stable-diffusion-v1 stable-diffusion-v1-5 stable-diffusion-512-v2-0 stable-diffusion-768-v2-0
            # stable-diffusion-512-v2-1 stable-diffusion-768-v2-1 stable-diffusion-xl-beta-v2-2-2 stable-inpainting-v1-0 stable-inpainting-512-v2-0
        )

        # Set up our initial generation parameters.
        answers = stability_api.generate(
            prompt=prompt, # The prompt to generate from.
            seed=992446758, # If a seed is provided, the resulting generated image will be deterministic.
                            # What this means is that as long as all generation parameters remain the same, you can always recall the same image simply by generating it again.
                            # Note: This isn't quite the case for CLIP Guided generations, which we tackle in the CLIP Guidance documentation.
            steps=30, # Amount of inference steps performed on image generation. Defaults to 30.
            cfg_scale=8.0, # Influences how strongly your generation is guided to match your prompt.
                        # Setting this value higher increases the strength in which it tries to match your prompt.
                        # Defaults to 7.0 if not specified.
            width=512, # Generation width, defaults to 512 if not included.
            height=512, # Generation height, defaults to 512 if not included.
            samples=1, # Number of images to generate, defaults to 1 if not included.
            sampler=generation.SAMPLER_K_DPMPP_2M # Choose which sampler we want to denoise our generation with.
                                                        # Defaults to k_dpmpp_2m if not specified. Clip Guidance only supports ancestral samplers.
                                                        # (Available Samplers: ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_dpmpp_2s_ancestral, k_lms, k_dpmpp_2m, k_dpmpp_sde)
        )
        # Warn if the adult content classifier is tripped; otherwise save the generated image.
        for resp in answers:
            for artifact in resp.artifacts:
                if artifact.finish_reason == generation.FILTER:
                    warnings.warn(
                        "Your request activated the API's safety filters and could not be processed. "
                        "Please modify the prompt and try again.")
                if artifact.type == generation.ARTIFACT_IMAGE:
                    img = Image.open(io.BytesIO(artifact.binary))
                    # str.replace() returns a new string, so its result must be used.
                    # Name the file after the first 10 characters of the prompt, without dots.
                    img_path = prompt.replace(".", "")[:10] + '.png'
                    img.save(img_path)

        return img_path

    except Exception as e:
        print(e)
        return ""
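
Because the file name is derived from the user's prompt, characters such as slashes or colons could produce an invalid path on some systems. The helper below is a minimal sketch of a stricter approach; safe_filename and its parameters are hypothetical names for illustration and are not part of stability-sdk.

```python
import re


def safe_filename(prompt: str, max_length: int = 10, extension: str = ".png") -> str:
    """Build a file name from the start of a prompt, keeping only safe characters."""
    # Replace every run of characters other than letters, digits, "-" and "_"
    # with a single underscore, then trim underscores at both ends.
    cleaned = re.sub(r"[^A-Za-z0-9_-]+", "_", prompt).strip("_")
    # Fall back to a fixed stem when nothing usable is left.
    stem = cleaned[:max_length].rstrip("_") or "image"
    return stem + extension
```

Inside imagine(), the result of safe_filename(prompt) could then be used both for img.save() and as the returned path.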

Next, let's create a file named ai21_studio.py and implement a function that generates ideas for engaging tweets to promote a brand or product:

import ai21

def generate(prompt, ai21_api_key):
    ai21.api_key = ai21_api_key

    # Few-shot prompt: three worked examples, then the user's event with an open "Tweet:" slot.
    _prompt = (
        "Write a tweet to promote an event for Shmolt, a food and more delivery company.\n"
        "##\n"
        "Event: Christmas.\n"
        "Tweet: Oven broke in the last minute and you have no turkey? No need to panic, with Shmolt delivery it will be at your door step by Christmas eve!\n"
        "##\n"
        "Event: Valentine's Day.\n"
        "Tweet: Order her roses using Shmolt, so you can have time to stop and smell the roses #Valentines\n"
        "##\n"
        "Event: Summer Sale.\n"
        "Tweet: In this heat, why get out of the AC when you have 10% off of all Shmolt deliveries?\n"
        "##\n"
        f"Event: {prompt.strip()}\n"
        "Tweet:"
    )

    response = ai21.Completion.execute(
        model="j2-mid",  # The engine to use for generation.
        prompt=_prompt,
        numResults=1,
        maxTokens=64,
        temperature=0.84,
        topKReturn=0,
        topP=1,
        countPenalty={
            "scale": 0.1,
            "applyToNumbers": False,
            "applyToPunctuations": False,
            "applyToStopwords": False,
            "applyToWhitespaces": False,
            "applyToEmojis": False
        },
        frequencyPenalty={
            "scale": 0,
            "applyToNumbers": False,
            "applyToPunctuations": False,
            "applyToStopwords": False,
            "applyToWhitespaces": False,
            "applyToEmojis": False
        },
        presencePenalty={
            "scale": 0,
            "applyToNumbers": False,
            "applyToPunctuations": False,
            "applyToStopwords": False,
            "applyToWhitespaces": False,
            "applyToEmojis": False
        },
        stopSequences=["##"]
    )

    return response['completions'][0]['data']['text']
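
The prompt above follows a few-shot pattern: an instruction, several Event/Tweet examples separated by ##, and a final "Tweet:" slot left open for the model, with stopSequences=["##"] ending generation at the next separator. The sketch below shows one way such a prompt could be assembled from data; build_few_shot_prompt is a hypothetical helper, not part of the ai21 package.

```python
def build_few_shot_prompt(instruction, examples, event, separator="##"):
    """Assemble an instruction, worked examples, and a final open slot for the model."""
    parts = [instruction]
    for example_event, example_tweet in examples:
        parts.append(separator)
        parts.append(f"Event: {example_event}")
        parts.append(f"Tweet: {example_tweet}")
    # The final block names the new event and leaves "Tweet:" open for completion.
    parts.append(separator)
    parts.append(f"Event: {event}")
    parts.append("Tweet:")
    return "\n".join(parts)


if __name__ == "__main__":
    demo = build_few_shot_prompt(
        "Write a tweet to promote an event for Shmolt, a food and more delivery company.",
        [("Summer Sale.", "In this heat, why get out of the AC when you have 10% off of all Shmolt deliveries?")],
        "Mother's Day.",
    )
    print(demo)
```

Keeping the examples in a list makes it easy to add or swap them without editing one long string.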

Perfect! Now we have all the functions we need to generate engaging ideas and cool cover images.

Finally, create a file named app.py, where we will implement the logic for our Streamlit app. Streamlit lets us build web apps quickly in pure Python and share them with everyone.

Let's import all the necessary libraries and functions:

# Import from standard library
import os
import logging

# Import from 3rd party libraries
import streamlit as st

# Import modules from the local package
import stable_diffusion, ai21_studio

Basic configuration for the app:

# Configure logger
logging.basicConfig(format="\n%(asctime)s\n%(message)s", level=logging.INFO, force=True)

# Configure Streamlit page and state
st.set_page_config(page_title="AI21 Studio Tutorial", page_icon="🍩")

Initialize the app states:

# Store the initial value of widgets in session state
if "imagine" not in st.session_state:
    st.session_state.imagine = ""

if "img_path" not in st.session_state:
    st.session_state.img_path = ""

if "query" not in st.session_state:
    st.session_state.query = ""

if "im_query" not in st.session_state:
    st.session_state.im_query = ""

if "prompt_generate" not in st.session_state:
    st.session_state.prompt_generate = ""

if "text_error" not in st.session_state:
    st.session_state.text_error = ""

if "stable_diffusion_api_key" not in st.session_state:
    st.session_state.stable_diffusion_api_key = ""

if "ai21_api_key" not in st.session_state:
    st.session_state.ai21_api_key = ""

if "visibility" not in st.session_state:
    st.session_state.visibility = "visible"
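
The repeated "if key not in st.session_state" checks can also be expressed as a loop over a dictionary of defaults. The following is a sketch under the assumption that the state object supports the "in" operator and item assignment, as a plain dict does; init_state and DEFAULT_STATE are names introduced here for illustration.

```python
DEFAULT_STATE = {
    "imagine": "",
    "img_path": "",
    "query": "",
    "im_query": "",
    "prompt_generate": "",
    "text_error": "",
    "stable_diffusion_api_key": "",
    "ai21_api_key": "",
    "visibility": "visible",
}


def init_state(state, defaults=DEFAULT_STATE):
    """Set each default only if the key is not already present in state."""
    for key, value in defaults.items():
        if key not in state:
            state[key] = value
    return state
```

In app.py this would be called once near the top, for example as init_state(st.session_state), so that values already stored from earlier reruns are left untouched.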

Work on the app layout:

# Force responsive layout for columns also on mobile
st.write(
    """
    <style>
    [data-testid="column"] {
        width: calc(50% - 1rem);
        flex: 1 1 calc(50% - 1rem);
        min-width: calc(50% - 1rem);
    }
    </style>
    """,
    unsafe_allow_html=True,
)

Let's create a sidebar where users can enter their API keys. Why not simply use a .env file? Because we want anyone to be able to use the app with their own keys, without any hassle. Later, we can also deploy it to Streamlit Sharing Cloud.

# Render Streamlit page
with st.sidebar:
    st.session_state.ai21_api_key = st.text_input('AI21 Studio API Key', )
    st.session_state.stable_diffusion_api_key = st.text_input('Stable Diffusion API Key', )
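
For local development, it can still be convenient to fall back to an environment variable when the sidebar field is left empty. The sketch below shows one possible approach; resolve_api_key is a hypothetical helper, and the variable name AI21_API_KEY is an assumption (STABILITY_KEY is the name that appears in a comment in stable_diffusion.py).

```python
import os


def resolve_api_key(user_input: str, env_var: str) -> str:
    """Prefer the key typed into the sidebar; otherwise fall back to an environment variable."""
    typed = (user_input or "").strip()
    if typed:
        return typed
    # Returns "" when the variable is not set, which the app treats as a missing key.
    return os.environ.get(env_var, "").strip()


if __name__ == "__main__":
    # AI21_API_KEY is an assumed variable name, used here only for illustration.
    print(bool(resolve_api_key("", "AI21_API_KEY")))
```

With this in place, users of the deployed app still paste their own keys, while a developer running it locally could export the variables once.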

Quickly add a cool title and description for the app:

# title of the app
st.title("AI21 Studio + Stable Diffusion Tutorial")


st.markdown(
    "This is a demo of the AI21 Studio + Stable Diffusion Tutorial app."
)

Now, let's add a st.text_area() widget to get the prompt from the user and a st.button() widget to generate the ideas based on the prompt.

# textarea
st.session_state.query = st.text_area(
    label="Generate engaging ideas for tweets",
    placeholder="Elon Musk wants to fight with Mu", height=100)


# button
st.button(
    label="Generate ideas",
    help="Click to generate ideas",
    key="generate_prompt",
    type="primary",
    on_click=generate_ideas,
    )


Now, let's add another st.text_area() widget for image generation using Stable Diffusion. You may ask why we don't instantly generate the image after clicking the Generate ideas button along with the ideas. The reason is that we want to make it more interactive and fun.

# textarea
st.session_state.im_query = st.text_area(label="Cool description for tweets cover", placeholder="Elon Musk wants to fight with Mu", height=100)


# button
st.button(
    label="Generate image",
    help="Click to generate image",
    key="generate_image",
    type="primary",
    on_click=imagine,
    )


Output the generated ideas and image:

text_spinner_placeholder = st.empty()
if st.session_state.text_error:
    st.error(st.session_state.text_error)


if st.session_state.prompt_generate:
    st.markdown("""---""")
    st.text_area(label="Cool ideas", value=st.session_state.prompt_generate,)


if st.session_state.img_path:
    st.markdown("""---""")
    st.subheader("Cool image")
    st.image(st.session_state.img_path, use_column_width=True, caption="Image generated by Stable Diffusion", output_format="PNG")

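Last but not least, it's time to implement the main functions. In app.py, define them above the layout code, because the buttons refer to them in on_click. First, the function for generating ideas: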

def generate_ideas():

    # Clear any error left over from the previous run.
    st.session_state.text_error = ""

    if st.session_state.ai21_api_key == "":
        st.session_state.text_error = "Missed API key."
        return

    if st.session_state.query == "":
        st.session_state.text_error = "Missed a query."
        return

    # Clear the previous result before generating a new one.
    st.session_state.prompt_generate = ""

    with text_spinner_placeholder:
        with st.spinner("Please wait while we process your query..."):
            prompt = ai21_studio.generate(prompt=st.session_state.query, ai21_api_key=st.session_state.ai21_api_key)

            if prompt == "":
                st.session_state.text_error = "The model returned an empty response. Please modify the prompt and try again."
                logging.info(f"Text Error: {st.session_state.text_error}")
                return

            st.session_state.prompt_generate = prompt

And the function for generating the image:

def imagine():

    # Clear any error left over from the previous run.
    st.session_state.text_error = ""

    if st.session_state.stable_diffusion_api_key == "":
        st.session_state.text_error = "Missed API key."
        return

    if st.session_state.im_query == "":
        st.session_state.text_error = "Missed a query."
        return

    # Clear the previous image before generating a new one.
    st.session_state.imagine = ""
    st.session_state.img_path = ""

    with text_spinner_placeholder:
        with st.spinner("Please wait while we generate your image..."):
            im_path = stable_diffusion.imagine(prompt=st.session_state.im_query, stable_diffusion_api_key=st.session_state.stable_diffusion_api_key)

            # stable_diffusion.imagine() returns "" on any failure, not only when the safety filter is tripped.
            if im_path == "":
                st.session_state.text_error = "The image could not be generated. Your request may have activated the API's safety filters. Please check your API key, modify the prompt, and try again."
                logging.info(f"Text Error: {st.session_state.text_error}")
                return

            st.session_state.img_path = im_path
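
Both callbacks start with the same kind of checks on the API key and the query, so that logic could be factored into one small function. This is only a sketch of that idea; validate_request is a hypothetical helper, and the messages simply mirror the ones used in the callbacks.

```python
def validate_request(api_key: str, query: str) -> str:
    """Return an error message for the first missing input, or "" when both are present."""
    if not api_key.strip():
        return "Missed API key."
    if not query.strip():
        return "Missed a query."
    return ""
```

A callback could then begin by storing the result of validate_request(...) in st.session_state.text_error and returning early whenever it is not empty.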

Perfect! Now, we can run our app and see the result.

streamlit run app.py

Deploy to Streamlit Sharing Cloud

Check out the ElevenLabs Tutorial for detailed instructions on how to deploy your app to Streamlit Sharing Cloud.

Conclusion

We built something cool today! We learned how to use the AI21 Studio API to generate ideas for tweets and the Stable Diffusion API to generate cover images for them. We also saw how to deploy the app to Streamlit Sharing Cloud. Don't forget to join upcoming hackathons on lablab.ai and win cool prizes.

Thank you for following along with this tutorial.

If you have any questions, feel free to reach out to me on LinkedIn or Twitter. I'd love to hear from you!

made with 💜 by abdibrokhim for lablab.ai tutorials.