Build a fullstack storytelling application using MediaPipe, langchain.js and Gemma 2

Connie Leung - Sep 18 - Dev Community

In this blog post, I describe how to build a fullstack storytelling application using the MediaPipe Image Classification Task, Angular, langchain.js, NestJS, and Gemma 2. The Angular application uses the image classification task of the MediaPipe Web SDK to find the top 3 categories of a selected image. It then calls the NestJS API, which asks the Gemma 2 LLM to generate a story about the categories in at most 300 words. Finally, the API returns the story in the HTTP response, and the Angular application displays the result in the browser.

Create the NestJS application

nest new nestjs-ai-sprint-2024-demo

The Nest CLI scaffolds a new application for storytelling.

Set up environment variables

Copy .env.example to .env

PORT=3001
GROQ_API_KEY=<GROQ API KEY>
GROQ_MODEL=gemma2-9b-it
SWAGGER_TITLE=AI Storytelling Application
SWAGGER_DESCRIPTION=This application uses the prompt to ask the LLM to generate a short story.
SWAGGER_VERSION=1.0
SWAGGER_TAG=Gemma 2 9B, Groq, Langchain.js

Navigate to Groq Cloud, https://console.groq.com/, sign up, and create a new API key. Assign the API key to GROQ_API_KEY.

Install the dependencies

npm i --save-exact @langchain/community @langchain/core @langchain/groq @nestjs/config @nestjs/swagger @nestjs/throttler class-transformer class-validator langchain compression

Define the configuration in the application

Create a src/configs folder and add a configuration.ts to it.

export default () => ({
  port: parseInt(process.env.PORT, 10) || 3000,
  groq: {
    apiKey: process.env.GROQ_API_KEY || '',
    model: process.env.GROQ_MODEL || 'gemma2-9b-it',
  },
  swagger: {
    title: process.env.SWAGGER_TITLE || '',
    description: process.env.SWAGGER_DESCRIPTION || '',
    version: process.env.SWAGGER_VERSION || '1.0',
    tag: process.env.SWAGGER_TAG || '',
  },
});
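
The configuration function takes effect only after ConfigModule loads it. Below is a minimal sketch of the wiring in AppModule, assuming the default Nest scaffold (the post does not show this file); StorytellingModule is generated in the next section.

// app.module.ts (a minimal sketch; this file is not shown in the post)
import { Module } from '@nestjs/common';
import { ConfigModule } from '@nestjs/config';
import configuration from './configs/configuration';
import { StorytellingModule } from './storytelling/storytelling.module';

@Module({
  imports: [
    // load configuration.ts globally so ConfigService can resolve the groq and swagger settings
    ConfigModule.forRoot({ isGlobal: true, load: [configuration] }),
    StorytellingModule,
  ],
})
export class AppModule {}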

Create a Storytelling Module

Create a storytelling module that prompts the Gemma 2 model to generate a story about some categories.

nest g mo storytelling
nest g co storytelling/presenters/http/storytelling --flat
nest g s storytelling/application/storytelling --flat

Create constants

// application/constants/groq.constant.ts 

export const GROQ_CHAT_MODEL = 'GROQ_CHAT_MODEL';

This constant is the injection token used to provide and inject the Groq chat model in the NestJS application.

Declare Groq Configuration Type

// application/types/groq-config.type.ts

export type GroqConfig = {
 model: string;
 apiKey: string;
};

GroqConfig is a configuration type that stores the model name and API key of Groq.

Custom Factory Providers

The GroqChatModelProvider custom factory provider registers, under the GROQ_CHAT_MODEL token, a ChatGroq chat model that uses the Gemma 2 model.

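// application/providers/model.provider.ts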
import { ChatGroq } from '@langchain/groq';
import { Inject, Provider } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { GROQ_CHAT_MODEL } from '../constants/groq.constant';
import { GroqConfig } from '../types/groq-config.type';

export function InjectChatModel() {
  return Inject(GROQ_CHAT_MODEL);
}

export const GroqChatModelProvider: Provider<ChatGroq> = {
  provide: GROQ_CHAT_MODEL,
  useFactory: (configService: ConfigService) => {
    const { apiKey, model } = configService.get<GroqConfig>('groq');
    return new ChatGroq({
      apiKey,
      model,
      temperature: 0.7,
      maxTokens: 2048,
      streaming: false,
    });
  },
  inject: [ConfigService],
};

InjectChatModel is a shortcut decorator for injecting the Groq chat model into a service.
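
The provider, service, and controller scaffolded earlier must be registered in the storytelling module. A minimal sketch of the likely wiring (the post does not show the module file):

// storytelling.module.ts (a minimal sketch; this file is not shown in the post)
import { Module } from '@nestjs/common';
import { GroqChatModelProvider } from './application/providers/model.provider';
import { StorytellingService } from './application/storytelling.service';
import { StorytellingController } from './presenters/http/storytelling.controller';

@Module({
  controllers: [StorytellingController],
  // register the factory provider so that GROQ_CHAT_MODEL can be injected into the service
  providers: [GroqChatModelProvider, StorytellingService],
})
export class StorytellingModule {}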

Create the Storytelling Service

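// application/storytelling.service.ts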
import { StringOutputParser } from '@langchain/core/output_parsers';
import { ChatPromptTemplate } from '@langchain/core/prompts';
import { ChatGroq } from '@langchain/groq';
import { Injectable, Logger } from '@nestjs/common';
import { InjectChatModel } from './providers/model.provider';

@Injectable()
export class StorytellingService {
  private readonly logger = new Logger(StorytellingService.name);

  constructor(@InjectChatModel() private llm: ChatGroq) {}

  async ask(inputs: string[]): Promise<string> {
    const categories = inputs.join(',');
    this.logger.log(`categories: ${categories}`);

    const promptTemplate = ChatPromptTemplate.fromMessages([
      [
        'system',
        'You are a professional storyteller with vivid imagination who can tell a story about certain objects, animals, and human beings',
      ],
      [
        'user',
        `Please write a story with the following categories delimited by triple dashes:
        ---{categories}---

        The story should be written in one paragraph, 300 words max.
        Story:
      `,
      ],
    ]);

    const chain = promptTemplate.pipe(this.llm).pipe(new StringOutputParser());
    const response = await chain.invoke({ categories });

    this.logger.log(response);

    return response;
  }
}

The StorytellingService is straightforward. The promptTemplate is built from an array of messages. The system message sets the context: the model is a professional storyteller who can tell a story about different categories. The user message asks the model to write a story about the given categories in one paragraph of at most 300 words. The service injects an instance of ChatGroq, pipes the prompt template to the model and a StringOutputParser, and calls invoke to submit the categories to the chain and receive the story as a string.

Add the Storytelling Controller

// presenters/dtos/ask.dto.ts

import { ApiProperty } from '@nestjs/swagger';
import { IsArray, IsNotEmpty, IsString } from 'class-validator';

export class AskDto {
  @ApiProperty({
    isArray: true,
    type: String,
  })
  @IsArray()
  @IsNotEmpty({ each: true })
  @IsString({ each: true })
  categories: string[];
}
// presenters/http/storytelling.controller.ts

import { Body, Controller, HttpStatus, Post } from '@nestjs/common';
import { AskDto } from '../dtos/ask.dto';
import { StorytellingService } from '~storytelling/application/storytelling.service';

@Controller('storytelling')
export class StorytellingController {
  constructor(private service: StorytellingService) {}

  @Post()
  async ask(@Body() dto: AskDto): Promise<{ story: string }> {
    const story = await this.service.ask(dto.categories);
    return { story };
  }
}

The StorytellingController passes the categories to the service, which submits them to the chain to generate a short story, and sends the response back to the Angular application.
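
For the class-validator rules in AskDto to run, and for the Swagger variables in .env to be used, the bootstrap file needs a global validation pipe and the Swagger setup. The post does not show main.ts; the sketch below is one plausible wiring of the installed packages (the docs path, the whitelist option, and the CORS and compression calls are my assumptions).

// main.ts (a hedged sketch; this file is not shown in the post)
import { ValidationPipe } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { NestFactory } from '@nestjs/core';
import { DocumentBuilder, SwaggerModule } from '@nestjs/swagger';
import compression from 'compression';
import { AppModule } from './app.module';

async function bootstrap() {
  const app = await NestFactory.create(AppModule);

  // allow the Angular application, served from another port, to call the API
  app.enableCors();
  // compress HTTP responses (the post installs the compression package)
  app.use(compression());
  // validate incoming request bodies against the class-validator decorators in AskDto
  app.useGlobalPipes(new ValidationPipe({ whitelist: true }));

  const configService = app.get(ConfigService);
  const { title, description, version, tag } = configService.get('swagger');
  const swaggerConfig = new DocumentBuilder()
    .setTitle(title)
    .setDescription(description)
    .setVersion(version)
    .addTag(tag)
    .build();
  SwaggerModule.setup('docs', app, SwaggerModule.createDocument(app, swaggerConfig));

  await app.listen(configService.get('port'));
}

bootstrap();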

Angular Application

Scaffold an Angular Application

ng new ng-ai-sprint-2024-demo

Upload models to Google Cloud Storage

First, I downloaded a few image classification models from https://www.kaggle.com/models/google/aiy/tfLite/vision-classifier-birds-v1/3 and https://ai.google.dev/edge/mediapipe/solutions/vision/image_classifier/index#models.

Then, I uploaded them to a new Google Cloud Storage (GCS) bucket to keep the bundle size of my project small. Next, I updated the CORS policy of the bucket so that the Angular application can load these files.

// cors.json

[
    {
      "origin": ["http://localhost:4200"],
      "responseHeader": ["Content-Type"],
      "method": ["GET", "HEAD", "PUT", "POST"],
      "maxAgeSeconds": 3600
    }
]
cd ~/google-cloud-sdk 
gcloud storage buckets update gs://<bucket name> --cors-file=cors.json

The gcloud command updates the CORS policy of the GCS bucket.

Load the models during application startup

When the application starts, I use the APP_INITIALIZER token to initialize the models. The models must be available before users make their first classification request.

// assets/config.json

{
    "taskVisionUrl": "https://cdn.jsdelivr.net/npm/@mediapipe/tasks-vision@0.10.15/wasm",
    "modelLocations": [
        { 
            "name": "EfficientNet-Lite0 model",
            "path": <public url>
        },
        {
            "name": "Vision Classifier Food V1",
            "path": <public url>
        }
    ],
    "maxResults": 3,
    "backendUrl": "http://localhost:3001"
}

The config.json file stores the configuration of the MediaPipe SDK and the public URLs of the classification models.

// core/utils/load-classifier.ts

import { FilesetResolver, ImageClassifier } from '@mediapipe/tasks-vision';
import config from '~assets/config.json';
import { ImageClassifierModel } from '../types/image-classifier.type';

async function createImageClassifier(modelAssetPath: string): Promise<ImageClassifier> {
    const vision = await FilesetResolver.forVisionTasks(config.taskVisionUrl);
    return ImageClassifier.createFromOptions(vision, {
      baseOptions: {
        modelAssetPath
      },
      maxResults: config.maxResults,
      runningMode: 'IMAGE',
    });
}

export async function loadClassifiers() {
  const classifierMap: {[key: string]: ImageClassifier } = {};

  const promises = config.modelLocations.map(async (model) => {
    const classifier = await createImageClassifier(model.path);
    return {
      name: model.name,
      classifier,
    } as ImageClassifierModel;
  })

  const classifiers = await Promise.all(promises);
  for (const { name, classifier } of classifiers) {
    classifierMap[name] = classifier;
  }

  return classifierMap;
}

The loadClassifiers function loads the classification models and stores them in a map. When a user selects a model name, the corresponding model is used to classify an image.
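
The ImageClassifierModel type is imported but not listed in the post; a plausible shape, inferred from how loadClassifiers uses it:

// core/types/image-classifier.type.ts (assumed shape, inferred from load-classifier.ts)
import { ImageClassifier } from '@mediapipe/tasks-vision';

export type ImageClassifierModel = {
  name: string;
  classifier: ImageClassifier;
};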

// core/providers/core.provider.ts

import { APP_INITIALIZER, Provider } from '@angular/core';
import { ImageClassificationService } from '~app/image-classification/services/image-classification.service';

export function provideAppInitializer(): Provider {
    return {
        provide: APP_INITIALIZER,
        multi: true,
        useFactory: (service: ImageClassificationService) => () => service.init(),
        deps: [ImageClassificationService]
    } as Provider;
}
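
// app.config.ts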
import { ApplicationConfig, provideExperimentalZonelessChangeDetection } from '@angular/core';
import { provideHttpClient } from '@angular/common/http';
import { provideAppInitializer } from './core/providers/core.provider';

export const appConfig: ApplicationConfig = {
  providers: [
    provideExperimentalZonelessChangeDetection(),
    provideHttpClient(),
    provideAppInitializer(),
  ]
};

In appConfig, the provideAppInitializer function loads the models into memory. Because the initializer factory returns a promise, Angular waits until the models are loaded before bootstrapping the application.

Create the Services

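// image-classification/services/image-classification.service.ts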
import { computed, Injectable, signal } from '@angular/core';
import { ImageClassifier } from '@mediapipe/tasks-vision';
import { loadClassifiers } from '~app/core/utils/load-classifier';
import { ImageClassificationResult } from '../types/image-classification.type';

@Injectable({
  providedIn: 'root'
})
export class ImageClassificationService {
  #classifierMap = signal<{ [key:string]: ImageClassifier }>({});
  modelNames = computed(() => Object.keys(this.#classifierMap()));

  async init() {
    const classifiers = await loadClassifiers();
    this.#classifierMap.set(classifiers);
  }

  classify(modelName: string, source: TexImageSource): ImageClassificationResult {
    const classifier = this.#classifierMap()[modelName];
    if (!classifier) {
      throw new Error(`The model, ${modelName}, does not exist`);
    }

    const results = classifier.classify(source);
    if (results.classifications.length === 0) {
      throw new Error('No result.');
    }

    const categoryScores =  results.classifications[0].categories.map(({ categoryName, displayName, score }) => ({
      categoryName: displayName || categoryName,
      score: (score * 100).toFixed(2),
    }));

    const categories = categoryScores.map((item) => item.categoryName);

    return {
      categoryScores,
      categories,
    };
  }
}

The classify method of the ImageClassificationService accepts a model name and an image source (such as an HTMLImageElement), and asks the corresponding classification model to return the top 3 categories.
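
The CategoryScore and ImageClassificationResult types are not listed in the post; plausible shapes, inferred from the classify method above:

// image-classification/types/image-classification.type.ts (assumed shapes)
export type CategoryScore = {
  categoryName: string;
  // percentage formatted to two decimal places, e.g. '87.50'
  score: string;
};

export type ImageClassificationResult = {
  categoryScores: CategoryScore[];
  categories: string[];
};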

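// image-classification/services/storytelling.service.ts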
import { HttpClient } from '@angular/common/http';
import { inject, Injectable } from '@angular/core';
import { lastValueFrom, map } from 'rxjs';
import config from '~assets/config.json';
import { Story } from '../types/story.type';

@Injectable({
  providedIn: 'root'
})
export class StorytellingService {
  httpClient = inject(HttpClient);

  generateStory(categories: string[]): Promise<string> {
    const storytellingUrl = `${config.backendUrl}/storytelling`;
    return lastValueFrom(this.httpClient.post<Story>(storytellingUrl, {
        categories
      }).pipe(map(({ story }) => story))
    );
  }
}

The generateStory method calls the backend to generate a story about the categories.
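
The Story type mirrors the shape of the API response; a plausible definition (the file is not listed in the post):

// image-classification/types/story.type.ts (assumed shape, matching the { story } response)
export type Story = {
  story: string;
};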

Build the user interface

import { ChangeDetectionStrategy, Component } from '@angular/core';
import { ClassificationContainerComponent } from './image-classification/components';

@Component({
  selector: 'app-root',
  standalone: true,
  imports: [ClassificationContainerComponent],
  template: '<app-classification-container />',
  changeDetection: ChangeDetectionStrategy.OnPush,
})
export class AppComponent {}

The AppComponent has a container component for image classification.

import { ChangeDetectionStrategy, Component, inject, signal } from '@angular/core';
import { ImageClassificationService } from '../services/image-classification.service';
import { ClassificationComponent } from './classification.component';
import { GeneratedStoryComponent } from './generated-story.component';
import { ImageClassificationResult } from '../types/image-classification.type';

@Component({
  selector: 'app-classification-container',
  standalone: true,
  imports: [ClassificationComponent, GeneratedStoryComponent],
  template: `
    <div>
      <h2 class="title">Storytelling by MediaPipe Image Classifier Task and Gemma 2</h2>
      <app-classification [models]="service.modelNames()" class="classification" 
        (results)="results.set($event)" (story)="story.set($event)" />
      <app-generated-story [results]="results()" [story]="story()" /> 
    </div>
  `,
  changeDetection: ChangeDetectionStrategy.OnPush,
})
export class ClassificationContainerComponent {
  results = signal<ImageClassificationResult[]>([]);
  story = signal('No story has been generated.');

  service = inject(ImageClassificationService);
}

This container component composes two child components: ClassificationComponent allows a user to select an image and classify it, while GeneratedStoryComponent displays the categories, the system and user prompts, and the story.

import { ChangeDetectionStrategy, Component, input, output, signal } from '@angular/core';
import { FormsModule } from '@angular/forms';
import { CategoryScore } from '../types/image-classification.type';
import { PreviewImageComponent } from './preview-image.component';

@Component({
  selector: 'app-classification',
  standalone: true,
  imports: [FormsModule, PreviewImageComponent],
  template: `
    <label for="models">Image Classifier Models: </label>
    <select id="models" name="models" [(ngModel)]="selectedModel">
      @for(model of models(); track model) {
        <option [value]="model">{{ model }}</option>
      }
    </select>
    <app-preview-image [model]="selectedModel()" (results)="results.emit($event)"
      (story)="story.emit($event)"
    />
  `,
  changeDetection: ChangeDetectionStrategy.OnPush
})
export class ClassificationComponent {
  models = input.required<string[]>();
  selectedModel = signal('EfficientNet-Lite0 model');

  results = output<CategoryScore[]>();
  story = output<string>();
}

This component populates the model names in a dropdown and renders a PreviewImageComponent that previews the selected image.

import { ChangeDetectionStrategy, Component, computed, ElementRef, input, output, signal, viewChild } from '@angular/core';
import { CategoryScore } from '../types/image-classification.type';
import { ClassificationButtonsComponent } from './classification-buttons.component';

@Component({
  selector: 'app-preview-image',
  standalone: true,
  imports: [ClassificationButtonsComponent],
  template: `
    <div>
      <input type="file" #fileInput style="display: none;" accept=".jpg, .jpeg, .png" (change)="previewImage($event)" />
      <div id="imageContainer"><img #imagePreview /></div>

      <app-classification-buttons [model]="model()" [imageSource]="imageElement()" [hasImage]="hasImage()" 
        (openFileDialog)="openFileDialog()" (results)="results.emit($event)" (story)="story.emit($event)" />
    </div>
  `,
  changeDetection: ChangeDetectionStrategy.OnPush
})
export class PreviewImageComponent {
  model = input.required<string>();
  fileInput = viewChild.required<ElementRef<HTMLInputElement>>('fileInput');
  imagePreview = viewChild.required<ElementRef<HTMLImageElement>>('imagePreview');

  hasImage = signal(false);
  imageElement = computed(() => this.imagePreview().nativeElement);

  results = output<CategoryScore[]>();
  story = output<string>();

  openFileDialog() {
    this.fileInput().nativeElement.click();
  }

  getFirstFile(event: Event) {
    return event.target && 'files' in event.target && event.target.files instanceof FileList && event.target.files.length ?
      event.target.files[0] : null;
  }

  previewImage(event: Event) {
    const reader = new FileReader();
    reader.onload = () => {
      if (reader.result && typeof reader.result === 'string') {
        this.imagePreview().nativeElement.src = reader.result;
        this.hasImage.set(true);
      }
    }

    this.hasImage.set(false);
    const file = this.getFirstFile(event);
    if (file) {
      reader.readAsDataURL(file);
    }
  }
}

A user clicks a button to choose an image from a file dialog and the previewImage method updates the source of the image element.

import { ChangeDetectionStrategy, Component, computed, inject, input, output, signal } from '@angular/core';
import { ImageClassificationService } from '../services/image-classification.service';
import { StorytellingService } from '../services/storytelling.service';
import { CategoryScore } from '../types/image-classification.type';

@Component({
  selector: 'app-classification-buttons',
  standalone: true,
  template: `
    <button (click)="openFileDialog.emit()">Choose an image</button>
    <button (click)="classify()" [disabled]="buttonState().disabled()">{{ buttonState().classifyText() }}</button>
    <button (click)="generateStory()" [disabled]="buttonState().disabled()">{{ buttonState().generateText() }}</button>
  `,
  changeDetection: ChangeDetectionStrategy.OnPush
})
export class ClassificationButtonsComponent {

  model = input.required<string>();
  imageSource = input.required<HTMLImageElement>();
  hasImage = input(false);

  categories = signal<string[]>([]);

  buttonState = computed(() => ({
    classifyText: signal('Classify the image'),
    generateText: signal('Generate a story'),
    disabled: signal(!this.hasImage()),
  }));

  results = output<CategoryScore[]>();
  story = output<string>();
  openFileDialog = output();

  classificationService = inject(ImageClassificationService);
  storytellingService = inject(StorytellingService);

  classify() {
    this.buttonState().disabled.set(true);
    this.buttonState().classifyText.set('Classifying...');
    try {
      const { categoryScores, categories } = this.classificationService.classify(this.model(), this.imageSource());
      this.results.emit(categoryScores);
      this.categories.set(categories);
    } finally {
      // restore the button state even when classification throws
      this.buttonState().classifyText.set('Classify the image');
      this.buttonState().disabled.set(false);
    }
  }

  async generateStory() {
    this.buttonState().disabled.set(true);
    this.buttonState().generateText.set('Generating...');
    try {
      const story = await this.storytellingService.generateStory(this.categories());
      this.story.emit(story);
    } finally {
      // restore the button state even when the HTTP request fails
      this.buttonState().generateText.set('Generate a story');
      this.buttonState().disabled.set(false);
    }
  }
} 

When a user clicks the Classify the image button, the component calls the ImageClassificationService to obtain the categories and emits them to the parent PreviewImageComponent. When a user clicks the Generate a story button, the component calls the StorytellingService to generate a story and emits the content to the parent component.

import { ChangeDetectionStrategy, Component, computed, input } from '@angular/core';
import { CategoryScore } from '../types/image-classification.type';

@Component({
  selector: 'app-generated-story',
  standalone: true,
  imports: [],
  template: `
    <div>
      <h3>Classifications:</h3>
      @for (result of results(); track result.categoryName) {
        <p>{{ result.categoryName }}: {{ result.score + '%' }}</p>
      }
      <h3>Prompt:</h3>
      <p><label>System:</label>{{ systemPrompt() }}</p>
      @let userPromptLines = userPrompt().split('\n');
      @for (prompt of userPromptLines; track $index) {
        <p>
          @if ($index === 0) {
            <label>User:</label>
          }
          {{ prompt }}
        </p>
      }
      <h3>Story:</h3>
      <p>{{ story() }}</p>
    </div>
  `,
   changeDetection: ChangeDetectionStrategy.OnPush
})
export class GeneratedStoryComponent {
  results = input.required<CategoryScore[]>();
  story = input.required<string>();

  categories = computed(() => this.results().map(({ categoryName }) => categoryName).join(', '));

  systemPrompt = computed(() =>
    'You are a professional storyteller with vivid imagination who can tell a story about certain objects, animals, and human beings'
  );

  userPrompt = computed(() =>  
    `Please write a story with the following categories delimited by triple dashes:
    ---${this.categories()}---

    The story should be written in one paragraph, 300 words max.
    Story:
    -----------------------------------------------------------
    `);
}

The GeneratedStoryComponent displays the categories, the system and user prompts, and the story generated by the LLM.

In conclusion, software engineers can create Web AI applications without an AI/ML background.
