tail_recursion

Reputation: 740

Does Mistral 7B work with LangChain tools?

I am following this tutorial, which is the third search result on Google for 'langchain tools'. I am trying to get Mistral 7B Instruct to use a simple circumference calculator tool, but I keep getting "Could not parse LLM output" errors. I tried setting handle_parsing_errors to True, but it does not help.

Here is the code:

from langchain.llms import LlamaCpp

# quantized Mistral 7B Instruct loaded via llama.cpp; temperature=0 for deterministic output
llm = LlamaCpp(model_path="./mistral_7b_instruct/mistral-7b-instruct-v0.1.Q4_K_M.gguf", verbose=True, n_ctx=4000, temperature=0)

from langchain.tools import BaseTool
from math import pi
from typing import Union
  

class CircumferenceTool(BaseTool):
    name = "Circumference calculator"
    description = "use this tool when you need to calculate a circumference using the radius of a circle"

    def _run(self, radius: Union[int, float]) -> float:
        # circumference = 2 * pi * radius
        return float(radius) * 2.0 * pi

    def _arun(self, radius: Union[int, float]):
        raise NotImplementedError("This tool does not support async")
        
from langchain.chains.conversation.memory import ConversationBufferWindowMemory

# keep the last 5 exchanges of the conversation in memory
conversational_memory = ConversationBufferWindowMemory(
    memory_key='chat_history',
    k=5,
    return_messages=True
)

from langchain.agents import initialize_agent

tools = [CircumferenceTool()]

# initialize agent with tools
agent = initialize_agent(
    agent='chat-conversational-react-description',
    tools=tools,
    llm=llm,
    verbose=True,
    max_iterations=3,
    early_stopping_method='generate',
    memory=conversational_memory,
    handle_parsing_errors=True
)

agent("can you calculate the circumference of a circle that has a radius of 7.81mm")

and the output is this:

> Entering new AgentExecutor chain...
▅

ASSISTANT'S RESPONSE
--------------------
json
{
    "action": "Circumference calculator",
    "action_input": "7.81"
}

Observation: 49.071677249072565
Thought:

llama_print_timings:        load time =     445.58 ms
llama_print_timings:      sample time =       9.25 ms /    53 runs   (    0.17 ms per token,  5729.11 tokens per second)
llama_print_timings: prompt eval time =   25256.61 ms /   562 tokens (   44.94 ms per token,    22.25 tokens per second)
llama_print_timings:        eval time =    2743.39 ms /    52 runs   (   52.76 ms per token,    18.95 tokens per second)
llama_print_timings:       total time =   28164.19 ms
Llama.generate: prefix-match hit
Could not parse LLM output: 
Observation: Invalid or incomplete response
Thought:

llama_print_timings:        load time =     445.58 ms
llama_print_timings:      sample time =       0.17 ms /     1 runs   (    0.17 ms per token,  5780.35 tokens per second)
llama_print_timings: prompt eval time =    8081.89 ms /   172 tokens (   46.99 ms per token,    21.28 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =    8108.44 ms
Llama.generate: prefix-match hit
Could not parse LLM output: 
Observation: Invalid or incomplete response
Thought:

llama_print_timings:        load time =     445.58 ms
llama_print_timings:      sample time =       0.17 ms /     1 runs   (    0.17 ms per token,  5780.35 tokens per second)
llama_print_timings: prompt eval time =    5237.91 ms /   112 tokens (   46.77 ms per token,    21.38 tokens per second)
llama_print_timings:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_print_timings:       total time =    5254.49 ms
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[35], line 43
     31 # initialize agent with tools
     32 agent = initialize_agent(
     33     agent='chat-conversational-react-description',
     34     tools=tools,
   (...)
     40     handle_parsing_errors=True
     41 )
---> 43 agent("can you calculate the circumference of a circle that has a radius of 7.81mm")

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/base.py:310, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    308 except BaseException as e:
    309     run_manager.on_chain_error(e)
--> 310     raise e
    311 run_manager.on_chain_end(outputs)
    312 final_outputs: Dict[str, Any] = self.prep_outputs(
    313     inputs, outputs, return_only_outputs
    314 )

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/base.py:304, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    297 run_manager = callback_manager.on_chain_start(
    298     dumpd(self),
    299     inputs,
    300     name=run_name,
    301 )
    302 try:
    303     outputs = (
--> 304         self._call(inputs, run_manager=run_manager)
    305         if new_arg_supported
    306         else self._call(inputs)
    307     )
    308 except BaseException as e:
    309     run_manager.on_chain_error(e)

File ~/anaconda3/lib/python3.10/site-packages/langchain/agents/agent.py:1190, in AgentExecutor._call(self, inputs, run_manager)
   1188     iterations += 1
   1189     time_elapsed = time.time() - start_time
-> 1190 output = self.agent.return_stopped_response(
   1191     self.early_stopping_method, intermediate_steps, **inputs
   1192 )
   1193 return self._return(output, intermediate_steps, run_manager=run_manager)

File ~/anaconda3/lib/python3.10/site-packages/langchain/agents/agent.py:703, in Agent.return_stopped_response(self, early_stopping_method, intermediate_steps, **kwargs)
    701 new_inputs = {"agent_scratchpad": thoughts, "stop": self._stop}
    702 full_inputs = {**kwargs, **new_inputs}
--> 703 full_output = self.llm_chain.predict(**full_inputs)
    704 # We try to extract a final answer
    705 parsed_output = self.output_parser.parse(full_output)

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/llm.py:298, in LLMChain.predict(self, callbacks, **kwargs)
    283 def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
    284     """Format prompt with kwargs and pass to LLM.
    285 
    286     Args:
   (...)
    296             completion = llm.predict(adjective="funny")
    297     """
--> 298     return self(kwargs, callbacks=callbacks)[self.output_key]

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/base.py:310, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    308 except BaseException as e:
    309     run_manager.on_chain_error(e)
--> 310     raise e
    311 run_manager.on_chain_end(outputs)
    312 final_outputs: Dict[str, Any] = self.prep_outputs(
    313     inputs, outputs, return_only_outputs
    314 )

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/base.py:304, in Chain.__call__(self, inputs, return_only_outputs, callbacks, tags, metadata, run_name, include_run_info)
    297 run_manager = callback_manager.on_chain_start(
    298     dumpd(self),
    299     inputs,
    300     name=run_name,
    301 )
    302 try:
    303     outputs = (
--> 304         self._call(inputs, run_manager=run_manager)
    305         if new_arg_supported
    306         else self._call(inputs)
    307     )
    308 except BaseException as e:
    309     run_manager.on_chain_error(e)

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/llm.py:108, in LLMChain._call(self, inputs, run_manager)
    103 def _call(
    104     self,
    105     inputs: Dict[str, Any],
    106     run_manager: Optional[CallbackManagerForChainRun] = None,
    107 ) -> Dict[str, str]:
--> 108     response = self.generate([inputs], run_manager=run_manager)
    109     return self.create_outputs(response)[0]

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/llm.py:117, in LLMChain.generate(self, input_list, run_manager)
    111 def generate(
    112     self,
    113     input_list: List[Dict[str, Any]],
    114     run_manager: Optional[CallbackManagerForChainRun] = None,
    115 ) -> LLMResult:
    116     """Generate LLM result from inputs."""
--> 117     prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
    118     callbacks = run_manager.get_child() if run_manager else None
    119     if isinstance(self.llm, BaseLanguageModel):

File ~/anaconda3/lib/python3.10/site-packages/langchain/chains/llm.py:179, in LLMChain.prep_prompts(self, input_list, run_manager)
    177 for inputs in input_list:
    178     selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
--> 179     prompt = self.prompt.format_prompt(**selected_inputs)
    180     _colored_text = get_colored_text(prompt.to_string(), "green")
    181     _text = "Prompt after formatting:\n" + _colored_text

File ~/anaconda3/lib/python3.10/site-packages/langchain/prompts/chat.py:339, in BaseChatPromptTemplate.format_prompt(self, **kwargs)
    330 def format_prompt(self, **kwargs: Any) -> PromptValue:
    331     """
    332     Format prompt. Should return a PromptValue.
    333     Args:
   (...)
    337         PromptValue.
    338     """
--> 339     messages = self.format_messages(**kwargs)
    340     return ChatPromptValue(messages=messages)

File ~/anaconda3/lib/python3.10/site-packages/langchain/prompts/chat.py:588, in ChatPromptTemplate.format_messages(self, **kwargs)
    580 elif isinstance(
    581     message_template, (BaseMessagePromptTemplate, BaseChatPromptTemplate)
    582 ):
    583     rel_params = {
    584         k: v
    585         for k, v in kwargs.items()
    586         if k in message_template.input_variables
    587     }
--> 588     message = message_template.format_messages(**rel_params)
    589     result.extend(message)
    590 else:

File ~/anaconda3/lib/python3.10/site-packages/langchain/prompts/chat.py:99, in MessagesPlaceholder.format_messages(self, **kwargs)
     97 value = kwargs[self.variable_name]
     98 if not isinstance(value, list):
---> 99     raise ValueError(
    100         f"variable {self.variable_name} should be a list of base messages, "
    101         f"got {value}"
    102     )
    103 for v in value:
    104     if not isinstance(v, BaseMessage):

ValueError: variable agent_scratchpad should be a list of base messages, got ▅

ASSISTANT'S RESPONSE
--------------------
json
{
    "action": "Circumference calculator",
    "action_input": "7.81"
}

Observation: 49.071677249072565
Thought:Could not parse LLM output: 
Observation: Invalid or incomplete response
Thought:Could not parse LLM output: 
Observation: Invalid or incomplete response
Thought:

I now need to return a final answer based on the previous steps:

EDIT:

I am thinking the best way might be to look at the observations and extract the answer; in my actual application the answer will have a specific format that makes it easy to detect using a regex.
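Something like this sketch, assuming the tool observations always end up in the exception message the way they do in the trace above (the "Observation:" regex here is just a placeholder for my real answer format):

import re

try:
    agent("can you calculate the circumference of a circle that has a radius of 7.81mm")
except ValueError as e:
    # the agent scratchpad, including the tool observation, ends up in the
    # error message, so pull the numeric observation back out with a regex
    match = re.search(r"Observation:\s*([\d.]+)", str(e))
    if match:
        print(match.group(1))  # 49.071677249072565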

Upvotes: 0

Views: 2841

Answers (1)

ZKS

Reputation: 2836

A temporary workaround would be as below:

try:
    response = agent("can you calculate the circumference of a circle that has a radius of 7.81mm")
except Exception as e:
    response = str(e)
    if response.startswith("Could not parse LLM output: `"):
        response = response.removeprefix("Could not parse LLM output: `").removesuffix("`")
        print(response)
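Note that the exact prefix can vary between LangChain versions (in the trace above there is no backtick after "Could not parse LLM output:"), so the startswith check may need adjusting. Alternatively, handle_parsing_errors also accepts a callable instead of a bool, which avoids catching the exception yourself. A sketch, assuming a langchain version whose AgentExecutor passes the OutputParserException to the handler:

from langchain.schema import OutputParserException

def handle_parse_error(error: OutputParserException) -> str:
    # feed the raw, unparseable LLM text back to the agent as the observation
    return str(error).removeprefix("Could not parse LLM output: `").removesuffix("`")

agent = initialize_agent(
    agent='chat-conversational-react-description',
    tools=tools,
    llm=llm,
    verbose=True,
    max_iterations=3,
    early_stopping_method='generate',
    memory=conversational_memory,
    handle_parsing_errors=handle_parse_error,  # callable instead of True
)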

Upvotes: 0
