Querying databases using LangChain and Ollama
5 September 2024
One of the most well-known issues with LLMs is the phenomenon of “hallucinations”, where the model makes up facts and presents them confidently as true. It’s one of the main challenges of using LLMs in corporations, since improper validation by the user may result in people drawing conclusions based on flawed information (like this famous example, and this one). If it wasn’t obvious already that you shouldn’t blindly trust anything you find on the internet (or in an app, for that matter), even if it comes from fancy models most users hardly understand, it does highlight once more that it’s important to be smart about how you ask questions and to validate the output from LLMs.
One major use case for LLMs is in large corporations with a vast collection of documents, databases, and files related to their business area, where they can be used to create something like “institutional memory”. This is typically done using Retrieval-Augmented Generation (or RAG). Simply said, RAG is a bit of a hack to supply the model with information it did not have access to during training, so it can include that information in the response when asked about it. It also adds source transparency, meaning it can show where it found the answer to the question, which is useful. In more technical terms, RAG is a method to compare a query from a user to a database of vectors representing embeddings, retrieve the most relevant data (i.e. the vectors nearest to the vector representing the query), link to the part of the data it retrieved the information from, and have the model summarize the result again if necessary. Since basically all companies have databases with proprietary and possibly sensitive data, out-of-the-box LLMs will not generate accurate responses about them. In many cases, even the questions asked by analysts can already be sensitive, which means the company needs to use a model running on their own servers or otherwise ensure that the data is not sent to, for example, OpenAI, Microsoft, Meta, or Google. What we’re doing in this blogpost is not RAG, but it is related.
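To make the retrieval step a bit more concrete, here is a tiny, self-contained sketch of the “find the nearest vectors” part, using made-up toy embeddings rather than a real embedding model or vector database (purely illustrative, and not something we’ll use later in this post):
import numpy as np
# Toy "vector database": three documents with made-up embedding vectors
doc_embeddings = np.array([
    [0.9, 0.1, 0.0],  # doc 0: invoicing procedures
    [0.1, 0.8, 0.1],  # doc 1: travel expense policy
    [0.0, 0.2, 0.9],  # doc 2: IT onboarding guide
])
query_embedding = np.array([0.05, 0.85, 0.10])  # embedding of the user's question
# Cosine similarity between the query and every document; highest = most relevant
similarities = doc_embeddings @ query_embedding / (
    np.linalg.norm(doc_embeddings, axis=1) * np.linalg.norm(query_embedding)
)
most_relevant = int(np.argmax(similarities))
print(most_relevant)  # this document would be handed to the LLM as extra context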
In this post I want to explore how one might go about prompt engineering to retrieve more accurate results from a local SQL database, and how to use LangChain to talk to that database and return not just a SQL query, but also run the SQL query and have the model return the answer in text form. I won’t go into embeddings and vector databases this time around, but I might write a separate blogpost about those in the future.
Setup
First we’ll set up our environment. To run the model locally we’ll use Ollama, which we’ll install in the usual way. I’m writing this blogpost on my MacBook Pro (2020 model, 2GHz Quad-Core Intel Core i5 with 16GB RAM; I haven’t gotten the accelerator yet for the Raspberry Pi), and I don’t have much storage space to download large models, so we’ll use a smaller one and hope it does the trick. On Mac, Ollama is just a tool you can install; on Linux, the installation is done using the snippet below:
curl -fsSL https://ollama.com/install.sh | sh
Ollama basically runs a server in the background hosting whatever model you choose to download. You download models with ollama pull <model> in the Terminal (Mac or Linux). You can find all available models here. Obviously, Ollama only includes models with open weights, such as llama 3.1, gemma, phi, and Mistral. For this post I’ll download the 7 billion parameter Mistral model (ollama pull mistral). I’ll also do this experiment entirely in Python, so I’ll install some packages we’ll use, such as duckdb and a variety of libraries from LangChain. Since the LangChain imports can be quite lengthy, I’ll violate best practice and import the relevant packages only in the chunk where they’re needed, so it’s easier to follow which import belongs to what snippet.
Let’s connect to the model first. LangChain lets us connect to any type of model, including hosted ones if we specify an access key. For models running locally using Ollama we can use the ChatOllama() function from langchain_ollama, where we specify the model we have downloaded.
In my case, I currently have the following models locally on my Mac that serve different purposes:
ollama list
NAME              ID              SIZE      MODIFIED
mistral:latest    f974a74358d6    4.1 GB    27 hours ago
For this purpose we’ll use the Mistral model and favour accuracy over creativity by setting the temperature to 0. In a production setting one might want to experiment a bit with this.
from langchain_ollama import ChatOllama
llm = ChatOllama(
model="mistral",
temperature=0,
)
Let’s see if it works by asking it a simple question using the invoke() function.
import pandas as pd
response = llm.invoke("What is the current date?")
print(pd.DataFrame(response, columns=["variable", "value"]))
variable value
0 content The current date is March 20, 2023.
1 additional_kwargs {}
2 response_metadata {'model': 'mistral', 'created_at': '2024-09-05...
3 type ai
4 name None
5 id run-1d59e5c1-32a3-4152-947a-723515913979-0
6 example False
7 tool_calls []
8 invalid_tool_calls []
9 usage_metadata {'input_tokens': 11, 'output_tokens': 16, 'tot...
You’ll see the llm object returns a number of key-value pairs, the most important of which is the content field that contains the written answer. Other items you might be interested in are the duration (in the response_metadata) and the number of tokens (in the usage_metadata). For performance reasons you might want to limit the number of tokens per interaction when deploying these models in production, to limit the strain on the servers and keep costs down.
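For example, here is a minimal sketch of reading those fields and capping the output length; num_predict is Ollama’s option for the maximum number of tokens to generate, which (to the best of my knowledge) is passed through as-is by ChatOllama:
# The written answer and the token counts from the response object above
print(response.content)
print(response.usage_metadata)
# A sketch of capping output length per interaction (num_predict assumed to map
# to Ollama's maximum number of tokens to generate)
capped_llm = ChatOllama(
    model="mistral",
    temperature=0,
    num_predict=128,
)
print(capped_llm.invoke("Describe DuckDB in one sentence.").content)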
Simple prompt engineering
Let’s first begin without LangChain and just do some simple prompt engineering. This requires some preparation. The plan here is to tell the LLM as much as we can about a SQL database and then give it some fairly common limitations, to rein in the creativity of the model a bit without adjusting the temperature. Since this isn’t a database the model has seen before (I created it just for this purpose, with some uncommon column names etc.), we want it to only use the information from the data provided. One way we can do that is to get the schema of the database and include it in the prompt. We’ll ignore token limits for now. We’ll also include “Please think carefully before you answer”, since I’ve been told by others using these models that even with a stricter token limit, it’s usually still worth it. Since this is not RAG, we will not ask for an answer, but instead ask it to provide a SQL query that will give the answer.
The database is a collection of sampled transactions from my bank account. Don’t worry, I’m not dumb enough to actually share any personal information here, so I simulated the dataset so that none of the information in the database is real, yet still realistic (a rough sketch of how one might simulate a similar table is included after the schema below). I could have used a publicly accessible database like flights.db or chinook.db, but since these are so prevalent in the education space, the LLMs already include them in their training data, and we want to explore data and databases completely unseen by the models. Let’s connect to the database and get the schema:
import duckdb
conn = duckdb.connect("data/budget.db")
cursor = conn.cursor()
schema_dict = {}
for tbl in cursor.execute("SHOW TABLES").fetchall():
    schema_dict[tbl[0]] = cursor.execute(f"DESCRIBE {tbl[0]};").fetchall()
schema_dict
{'account': [('id', 'INTEGER', 'NO', 'PRI', "nextval('account_id')", None), ('account_number', 'VARCHAR', 'NO', None, None, None), ('account_type', 'VARCHAR', 'NO', None, None, None), ('owner', 'VARCHAR', 'NO', None, None, None)], 'category': [('id', 'INTEGER', 'NO', 'PRI', "nextval('category_id')", None), ('cat1', 'VARCHAR', 'NO', None, None, None), ('cat2', 'VARCHAR', 'NO', None, None, None), ('cat3', 'VARCHAR', 'YES', None, None, None), ('search', 'VARCHAR', 'NO', None, None, None)], 'transactions': [('id', 'INTEGER', 'NO', 'PRI', "nextval('tnx_id')", None), ('date', 'DATE', 'NO', None, None, None), ('account_type', 'VARCHAR', 'NO', None, None, None), ('amount', 'DOUBLE', 'NO', None, None, None), ('overarching_cat', 'VARCHAR', 'YES', None, None, None), ('category', 'VARCHAR', 'YES', None, None, None), ('retailer', 'VARCHAR', 'YES', None, None, None), ('details', 'VARCHAR', 'YES', None, None, None)]}
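As an aside, since the database itself obviously isn’t shared, here is a rough sketch of how one might simulate a similar transactions table in DuckDB. The column layout follows the schema above, but the file name, retailers, and values are made up for illustration:
import random
from datetime import date, timedelta
sim = duckdb.connect("data/budget_simulated.db")  # hypothetical file name
sim.execute("CREATE SEQUENCE IF NOT EXISTS tnx_id START 1")
sim.execute("""
    CREATE TABLE IF NOT EXISTS transactions (
        id INTEGER DEFAULT nextval('tnx_id') PRIMARY KEY,
        date DATE NOT NULL,
        account_type VARCHAR NOT NULL,
        amount DOUBLE NOT NULL,
        overarching_cat VARCHAR,
        category VARCHAR,
        retailer VARCHAR,
        details VARCHAR
    )
""")
retailers = ["bunnpris", "xxl", "elkjøp", "tise"]  # made-up subset of retailers
for _ in range(500):
    sim.execute(
        "INSERT INTO transactions (date, account_type, amount, retailer) VALUES (?, ?, ?, ?)",
        [
            date(2022, 1, 1) + timedelta(days=random.randint(0, 900)),
            "daily_account",
            round(random.uniform(20, 3000), 2),
            random.choice(retailers),
        ],
    )
sim.close()
With that aside out of the way, back to the actual database.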
Let’s then create a function with our elaborate prompt, into which we’ll insert the database schema, the tables to use, and the question. Since we instruct the model to return only a SQL query and nothing else, we can take the response content directly and run it through the database connection we created earlier.
import re
def simple_prompt_engineering(question, schema, llm, conn):
    direct_prompt = f"""
    Given an input question, create a syntactically correct SQL query to
    run in DuckDB. The database is defined by the following schema:
    {schema}
    Only use the following tables:
    {conn.execute("SHOW TABLES").fetchall()}
    Don't use more columns than strictly necessary. Be careful to not
    query for columns that do not exist. Also, pay attention to which
    column is in which table. Please think carefully before you answer.
    Return only a SQL query and nothing else.
    Question: """
    # The prompt asks for a bare SQL query, so run the response content directly
    response = llm.invoke(direct_prompt + question)
    print(response.content)
    result = str(conn.execute(response.content).fetchall()[0][0])
    print(f"Answer: {result}")
simple_prompt_engineering(
"How many items did I buy at bunnpris?", schema_dict, llm, conn
)
SELECT COUNT(*) FROM transactions
WHERE retailer = 'bunnpris';
Answer: 47
This seems to work! Let’s see if it does for more complex queries too.
simple_prompt_engineering(
"How much did I on average spend at xxl?", schema_dict, llm, conn
)
SELECT AVG(amount) as Average_Spend_at_Xxl
FROM transactions
WHERE retailer = 'xxl';
Answer: 427.0309523809524
We can close the database connection now since we’ll use another connection when we deploy the LangChain implementation.
conn.close()
This worked too, but it’s not a RAG method, it’s just hacky. And these SQL queries are simple enough that analysts would be able to write them without consulting a chatbot. However, this was just an illustration. Let’s move on to something a little more advanced and complex. Let’s whip out LangChain.
LangChain
LangChain is a tool that helps build chatbots, RAG pipelines, and other LLM-based tools. It aids interaction with vector databases, APIs, PDFs, SQL databases, and much more. However, it is not perfect. As far as I know, it’s the first package of its kind, so it suffers a bit from having few other tools to learn from, which makes its widespread use even more impressive. Let’s connect to the database again, but now using the LangChain function.
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("duckdb:///data/budget.db")
Let’s now do the same thing we did earlier, but use the LangChain functionality to provide the schema. It’s a bit different from what we did before, where we just provided a dictionary with the table names, column names, and column types. This time we’ll use the get_table_info() function, which instead provides a SQL definition starting with CREATE TABLE .... Research by Rajkumar et al. showed that this gives the best results on few-shot and zero-shot tasks, compared to providing the schema in other ways (e.g. using SELECT statements). The get_table_info() function provides the following string:
db.get_table_info()[:100]
"\nCREATE TABLE account (\n\tid INTEGER DEFAULT nextval('account_id') NOT NULL, \n\taccount_number VARCHAR"
def with_a_little_langchain(question, verbose=False):
    direct_prompt = f"""
    Given an input question, create a syntactically correct SQL query to
    run in DuckDB. The database is defined by the following schema:
    {db.get_table_info()}
    Only use the following tables:
    {db.get_usable_table_names()}
    Pay attention to use only the column names you can see in the tables
    below. Be careful to not query for columns that do not exist. Also, pay
    attention to which column is in which table. Please carefully think
    before you answer.
    Question:
    """
    response = llm.invoke(direct_prompt + question)
    if verbose:
        print(response)
    # Parse the SQL query out of the model's free-text answer and run it on the db
    print(re.search("(SELECT.*);", response.content.replace("\n", " ")).group(1))
    print(
        db.run(re.search("(SELECT.*);", response.content.replace("\n", " ")).group(1))
    )
with_a_little_langchain("How many items did I buy at bunnpris in 2023?", verbose=True)
content=" To answer the question, we need to find transactions where the retailer is 'bunnpris' and the year is 2023. However, there is no 'year' column in the 'transactions' table. We can extract the year from the 'date' column instead. Here is a SQL query that should work:\n\n```sql\nSELECT COUNT(*)\nFROM transactions\nWHERE retailer = 'bunnpris' AND EXTRACT(YEAR FROM date) = 2023;\n```\n\nThis query will return the number of transactions made at bunnpris in the year 2023." response_metadata={'model': 'mistral', 'created_at': '2024-09-05T20:21:44.484906Z', 'message': {'role': 'assistant', 'content': ''}, 'done_reason': 'stop', 'done': True, 'total_duration': 89934803165, 'load_duration': 30176790, 'prompt_eval_count': 690, 'prompt_eval_duration': 66605094000, 'eval_count': 142, 'eval_duration': 23291640000} id='run-43024891-75b6-444a-a5c6-4b6ea425e3f5-0' usage_metadata={'input_tokens': 690, 'output_tokens': 142, 'total_tokens': 832}
SELECT COUNT(*) FROM transactions WHERE retailer = 'bunnpris' AND EXTRACT(YEAR FROM date) = 2023
[(8,)]
That seems to have worked. Let’s see if that one was just a fluke and try another one.
with_a_little_langchain("How much money did I spend on streaming in 2024?")
SELECT SUM(amount) as total_spent_on_streaming FROM transactions WHERE account_type = 'daily_account' AND YEAR(date) = 2024 AND category = 'streaming'
[(1759.0,)]
This worked better than I had actually expected. It just shows once more how much you can get out of your models by using careful prompt engineering. The response includes a description of how the model got to the SQL query it generated, and when the SQL query is run it gives the correct answer. If this was all you were interested in, this might suffice. However, it would be nice if you could also get the result in readable text. For that we need to build a chain where the result from the SQL query is incorporated in the response.
LangChain with only a SQL query chain
Before we go about incorporating the answer from the query into the response, let’s first take the SQL querying one step further by using the create_sql_query_chain() function directly. This function uses the specified LLM and a database connection to generate a SQL query that should be able to run on the database. It is a step up from our previous implementation in that it is purpose-built to generate SQL queries, so they don’t need to be parsed out of the response text. In essence, it does some prompt engineering and requests a response in a specific format that it can then parse. By default it returns a SQL block only, without comments, wrapped in backticks, as you would write it in a Markdown document. You can see the template it uses by requesting it with the .get_prompts() function.
from langchain.chains import create_sql_query_chain
create_sql_query_chain(llm, db).get_prompts()[0].pretty_print()
You are a DuckDB expert. Given an input question, first create a syntactically correct DuckDB query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most 5 results using the LIMIT clause as per DuckDB. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today".
Use the following format:
Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here
Only use the following tables:
{table_info}
Question: {input}
You’ll see that it comes with variables to be plugged in. The table_info variable will be filled with the schema from the .get_table_info() function we discussed earlier, and the question from the user is plugged into the input variable. We could change the prompt here too, and we’ll get to that in the next step, but for now let’s keep it simple and use the default one. You’ve already seen it above, but the create_sql_query_chain() function takes two mandatory arguments, the model and the database; a custom prompt is an optional argument. So we’ll plug those in, and then we’ll pass the question as a key-value pair to the .invoke() function. As mentioned, even though create_sql_query_chain() only returns SQL queries, we still need to strip the backticks from the response so we can run it on the database. So let’s try it:
def langchain_sql_query(question, llm, db):
    chain = create_sql_query_chain(llm, db)
    response = chain.invoke({"question": question})
    print(re.search("(SELECT.*);", response.replace("\n", " ")).group(1))
    print(db.run(re.search("(SELECT.*);", response.replace("\n", " ")).group(1)))
    return response
_ = langchain_sql_query("How much did I on average spend at bunnpris?", llm, db)
SELECT AVG(amount) as AverageSpend FROM transactions WHERE retailer = 'bunnpris' AND account_type = 'daily_account'
[(75.97978723404255,)]
Cool, let’s try it on another one.
_ = langchain_sql_query(
"How often did I buy something at xxl in the last 12 months?", llm, db
)
SELECT COUNT(*) FROM transactions WHERE retailer = 'xxl' AND date >= DATE_TRUNC('month', today()) - INTERVAL '1 year'
[(8,)]
This solution does a good job with a few lines of code. However, the answer is still returned as a tuple inside a list, which is not very elegant and not the most user-friendly. Also, despite the fact that create_sql_query_chain() returns a SQL query, it still required some parsing before it could be run, and the answer wasn’t presented in a particularly useful format. To fix that we need to bring out the full force of LangChain.
Full chain
So far we’ve taken multiple functions from LangChain and used them individually, but you only unleash the full benefit of LangChain if you, as the name suggests, chain the operations together. In short, the purpose here is to provide a natural language question that is interpreted by the LLM, have LangChain help generate a SQL query by providing the database, run the generated SQL query, and feed the result from the SQL query back into the LLM so it can provide a natural language response incorporating the answer it got from the database.
So we’ll pick it up at the create_sql_query_chain() function again. This time we’ll provide a custom template to wrap around the prompt. This template needs to contain three variables: the question (input), the schema (table_info, from .get_table_info()), and a variable to limit the query to the top N rows (top_k). Why this last one is required I don’t fully understand yet, but we’ll deal with it; we can just write the prompt so that it never actually uses this variable if need be. We’ll also be a bit more strict this time about what we want the function to return. A neat trick here is that we can make the prompt more flexible by using an f-string, so that we can also provide the SQL dialect without needing a separate input variable. That just means that the remaining variables that are provided by create_sql_query_chain() need to be double-wrapped in curly braces. This step is wrapped in the custom get_sql_chain() function, and it is important to note the distinction between the prompt template for the SQL chain and the prompt template for the response that we’ll generate later. Each needs a separate template (unless you want to make life kinda complicated, of course).
Now, in the main function, which I called natural_language_chain(), we’ll first get the SQL chain and then define the prompt template for the response. This prompt needs a couple of variables as well: the user question, the SQL query (which will be generated by the SQL chain), and the response that running the SQL query will produce. Only the user question needs to be provided up front; the others will be filled in by the chain.
Now we get to the part where we actually “chain” together the different parts introduced above. We’ll do this with the | operator, which functions as a pipe. This chain starts with a RunnablePassthrough object to which we provide the SQL chain (including its prompt). As this SQL chain returns a SQL query, we grab this query and pipe it to LangChain’s QuerySQLDataBaseTool(), which will run the query on the database and return the result. Next we provide the response prompt so that it fills in the variables, and pipe it to the LLM so that it incorporates the response from the query into the natural language answer. The final step uses the StrOutputParser() so the output is printed in a nice format (removing \n’s, for example).
We’ll call this chain the same way as we have earlier, using the .invoke() function. Unlike some of the alternatives above, the response from this chain is just a string containing the answer. Let’s try it with a simple question first.
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.output_parsers import StrOutputParser
def get_sql_chain(llm, db, table_info, top_k=10):
    # The {table_info} and {top_k} placeholders (double-braced because of the
    # f-string) are filled in later by create_sql_query_chain when the chain runs.
    template = f"""Given an input question, first create a syntactically
    correct SQL query to run in {db.dialect}, then look at the results of the
    query and return the answer to the input question. You can order the
    results to return the most informative data in the database.
    Unless otherwise specified, do not return more than {{top_k}} rows.
    Never query for all columns from a table. You must query only the
    columns that are needed to answer the question. Wrap each column name
    in double quotes (") to denote them as delimited identifiers.
    Pay attention to use only the column names present in the tables
    below. Be careful to not query for columns that do not exist. Also, pay
    attention to which column is in which table. Query only the columns you
    need to answer the question.
    Please carefully think before you answer.
    Here is the schema for the database:
    {{table_info}}
    Additional info: {{input}}
    Return only the SQL query such that your response could be copied
    verbatim into the SQL terminal.
    """
    prompt = PromptTemplate.from_template(template)
    sql_chain = create_sql_query_chain(llm, db, prompt)
    return sql_chain
def natural_language_chain(question, llm, db):
    table_info = db.get_table_info()
    sql_chain = get_sql_chain(llm, db, table_info=table_info)
    # Prompt template for the final answer; query and response are filled in by the chain
    template = f"""
    You are a data scientist at a company. Based on the table schema provided
    below, the SQL query and the SQL response, provide an answer that is as
    accurate as possible. Please provide a natural language response that can
    be understood by the user. Note that the currency is in Norwegian crowns.
    Please think carefully about your answer.
    SQL Query: {{query}}
    User question: {{question}}
    SQL Response: {{response}}
    """
    prompt = PromptTemplate.from_template(template)
    # Chain: generate the SQL query, run it on the database, then have the LLM
    # phrase the result as a natural language answer
    chain = (
        RunnablePassthrough.assign(query=sql_chain).assign(
            response=itemgetter("query") | QuerySQLDataBaseTool(db=db)
        )
        | prompt
        | llm
        | StrOutputParser()
    )
    response = chain.invoke({"question": question})
    print(response)
    return response
_ = natural_language_chain("How many transactions are there?", llm, db)
There are 3,517 transactions in the database.
This is a lot more user-friendly than the previous options! And it also happens to be correct, even using a little model. This setup can fail at a number of steps; most importantly, it can err on the SQL generation part. The response part is simple enough that when it gets a valid result (though not necessarily a correct one, if the SQL query contains an error) it will return a decent answer. You can build in additional restrictions and checks if need be, for example a rough guard like the sketch below.
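As a minimal illustration (this helper is hypothetical, not part of LangChain or the setup above), one could refuse to run anything that isn’t a plain SELECT statement before handing the generated query to the database:
def is_safe_select(query: str) -> bool:
    """Very rough guard: only allow queries that read data, nothing that modifies it."""
    lowered = query.strip().lower()
    forbidden = ("insert ", "update ", "delete ", "drop ", "alter ", "attach ")
    return lowered.startswith("select") and not any(word in lowered for word in forbidden)
# Usage sketch: validate a generated query before running it on the database
generated = "SELECT COUNT(*) FROM transactions WHERE retailer = 'bunnpris';"
if is_safe_select(generated):
    print(db.run(generated))
else:
    print("Refusing to run a query that modifies the database.")
This is fun, let’s try another one.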
_ = natural_language_chain(
"When was the first time I bought something from tise?", llm, db
)
The first time you bought something from 'tise' was on August 19th, 2022.
And since the hard part is still the SQL query generation, we can ask more complex questions and it will provide an answer incorporating the result.
_ = natural_language_chain("How much did I on average spend at xxl?", llm, db)
On average, you spent approximately 427 Norwegian crowns when shopping at Xxl.
And another one for good measure.
_ = natural_language_chain(
"How often did I buy something at elkjøp for more than 2000 NOK?", llm, db
)
Based on the data you provided, you made approximately 3 transactions at Elkjøp where the amount spent was over 2000 Norwegian crowns.
This was a fun one to write, and I wonder why not a single data scientist on the internet has written about LLMs yet (/s, just in case it wasn’t obvious). As with most of the posts on this website, it also served as a way for me to think through these challenges myself. One last thing: I have already added this badge to the About page on this website, but since this is a blogpost about AI I’ll add it again here. Apart from the parts of this blogpost explicitly generated by AI, the entire blogpost was written by me, a human.