Matheus Torquato

Reputation: 1629

Getting Token Usage Metadata from Gemini LLM calls in a LangChain RAG RunnableSequence

I would like to get the token usage of my RAG chain each time it is invoked.

No matter what I do, I can't seem to find the right way to output the total tokens from the Gemini model I'm using.

import vertexai
from langchain_google_vertexai import VertexAI
from vertexai.generative_models import GenerativeModel

vertexai.init(
    project="MY_PROJECT",
    location="MY_LOCATION",
)

question = "What is the meaning of life"

# LangChain wrapper: invoke() returns a plain string
llm = VertexAI(model_name="gemini-1.5-pro-001")
response1 = llm.invoke(question)

# Raw Vertex AI SDK: generate_content() returns a full response object
llm2 = GenerativeModel("gemini-1.5-pro-001")
response2 = llm2.generate_content(question)

response1 above is just a string.

response2 is what I want, i.e. a response object containing usage_metadata, safety_ratings, finish_reason, etc. But I haven't managed to make my RAG chain run using this approach.
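For reference, the token counts on response2 can be read straight off the SDK response (a minimal sketch; the attribute names match the fields reported in the response's usage metadata):

print(response2.usage_metadata.prompt_token_count)
print(response2.usage_metadata.candidates_token_count)
print(response2.usage_metadata.total_token_count)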

My RAG chain is a RunnableSequence (from langchain_core.runnables), which does not accept the class 'vertexai.generative_models.GenerativeModel', so I've also tried using callbacks:

from langchain_google_vertexai import VertexAI
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.outputs import LLMResult


class LoggingHandler(BaseCallbackHandler):
    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        print('On LLM Start: {}'.format(prompts))

    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
        print('On LLM End: {}'.format(response))


callbacks = [LoggingHandler()]
llm = VertexAI(model_name="gemini-1.5-pro-001")
prompt = ChatPromptTemplate.from_template("What is 1 + {number}?")

chain = prompt | llm

chain_with_callbacks = chain.with_config(callbacks=callbacks)
response = chain_with_callbacks.invoke({"number": "2"})

This prints the output below:

On LLM Start: ['Human: What is 1 + 2?']
On LLM End: generations=[[GenerationChunk(text='Human: What is 1 + 2?\nAssistant: 3 \n', generation_info={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}], 'citation_metadata': })]] llm_output=None run=None

I.e. no usage metadata.

Any idea how to have the usage metadata for each RAG chain call?

Upvotes: 0

Views: 525

Answers (1)

Arindam

Reputation: 81

Instead of VertexAI, use ChatVertexAI:

from langchain_google_vertexai import ChatVertexAI

Its response includes the token counts in usage metadata (total_token_count among them). The returned AIMessage looks like this:

AIMessage(content="J'adore programmer. \n", response_metadata={'is_blocked': False, 'safety_ratings': [{'category': 'HARM_CATEGORY_HATE_SPEECH', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_DANGEROUS_CONTENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_HARASSMENT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}, {'category': 'HARM_CATEGORY_SEXUALLY_EXPLICIT', 'probability_label': 'NEGLIGIBLE', 'blocked': False}], 'usage_metadata': {'prompt_token_count': 20, 'candidates_token_count': 7, 'total_token_count': 27}}, id='run-7032733c-d05c-4f0c-a17a-6c575fdd1ae0-0', usage_metadata={'input_tokens': 20, 'output_tokens': 7, 'total_tokens': 27})

Format of input:

llm = ChatVertexAI(model_name="gemini-1.5-pro-001")

messages = [
    (
        "system",
        "You are a helpful assistant that translates English to French. Translate the user sentence.",
    ),
    ("human", "I love programming."),
]
ai_msg = llm.invoke(messages)
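Applied to a chain like the one in the question (a minimal sketch; the model name and prompt are taken from the question), the usage metadata survives as long as the chain ends at the chat model rather than a string output parser:

from langchain_google_vertexai import ChatVertexAI
from langchain_core.prompts import ChatPromptTemplate

llm = ChatVertexAI(model_name="gemini-1.5-pro-001")
prompt = ChatPromptTemplate.from_template("What is 1 + {number}?")

# The last step is the chat model, so the chain returns the AIMessage
# (and with it usage_metadata) instead of a plain string
chain = prompt | llm
ai_msg = chain.invoke({"number": "2"})

print(ai_msg.content)
print(ai_msg.usage_metadata)  # {'input_tokens': ..., 'output_tokens': ..., 'total_tokens': ...}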

Upvotes: 0
