from __future__ import annotations from typing_extensions import override import logging import re from typing import Any, Union from langgraph.graph import START, StateGraph from typing_extensions import Dict from ..rag.rag import RAG, State from ..llm.llm import LLM from ..config.settings import Settings from ..config.rat_config import RATConfig class RAT(RAG): """ Retrieval-Augmented Thinking (RAT) class that extends the Retrieval-Augmented Generation (RAG) pipeline to incorporate a reasoning step using a dedicated reasoning LLM. This class performs an iterative reflection process to refine the input query based on retrieved context and a reasoning LLM, before generating a final answer. Attributes: reasoning_llm (LLM): The LLM used for reasoning and generating reflections. reflection (int): The number of reasoning iterations to perform for the input query. """ def __init__(self, config: RATConfig) -> None: """ Initializes the RAT pipeline with the required components. Args: embedding_model (EmbeddingsModel): The embeddings model for vectorization. vector_store (VectorStore): The vector store for document retrieval. reasoning_llm (LLM): The LLM used for reasoning and iterative reflection. llm (LLM): The LLM used for generating the final answer. k (int, optional): The number of top documents to retrieve. Defaults to 2. reflection (int, optional): The number of reasoning iterations to perform. Defaults to 1. """ rag_config = config.get_rag_config() super().__init__(rag_config) self.reasoning_llm: LLM = config.reasoning_llm self.reflection: int = config.reflection def set_reflection(self, reflection: int) -> None: """ Set the reflection attribute for this instance. :param reflection: The new reflection value (integer) to assign. :type reflection: int :return: None """ self.reflection = reflection def think(self, input: Dict[str, str]) -> Dict[str, str]: """ Generates iterative reasoning or reflection based on the input question and retrieved documents. This method performs multiple reasoning iterations using the reasoning LLM. Each iteration refines the reflection based on the retrieved context and the reasoning generated by the LLM. Args: input (Dict[str, List[Document]]): A dictionary containing: - 'question': The input question. Returns: input (Dict[str, List[Document]]): A dictionary containing: - 'question': The input question with reflection about the question. """ reflection = "" for _ in range(self.reflection): retrieved_docs = self.retrieve({"question": reflection}) docs_content = "\n\n".join( doc.page_content for doc in retrieved_docs["context"] ) prompt_json = { "question": input["question"], "context": docs_content, "reflection": reflection, } response = "\n" + self.reasoning_llm.generate(prompt_json) think = re.findall(Settings.THINKING_PATTERN, response, re.DOTALL) if not think: logging.warning("No reasoning found in the LLM response.") else: reflection = reflection + "\n" + think[0] question = input["question"] + f"\nReflection about the problem: {reflection}" return {"question": question} def createGraph(self) -> Any: """ Creates and compiles the state graph for the RAG pipeline. Returns: StateGraph: The compiled state graph for managing the RAG process flow. """ graph_builder = StateGraph(State).add_sequence( [self.think, self.retrieve, self.generate] ) graph_builder.add_edge(START, "think") return graph_builder.compile() @override def question_graph(self, question: str) -> str: """ Executes the RAT pipeline for a given question by first generating a reflection and then invoking the state graph to generate a final answer. The reflection step enhances the reasoning process by iteratively refining the context and query using the `think` method before the final answer is generated. Args: question (str): The input question or query. Returns: str: The generated answer from the pipeline, which incorporates both the initial question and the reasoning generated by the `think` method. """ state = {"question": question} response = self.graph.invoke(state) return response["answer"]