Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
7f203cb
remove tutorials
May 22, 2025
7ba5259
added About tab to app
May 22, 2025
7548615
update text
May 22, 2025
8c18094
renamed chatbot_ready arg
May 22, 2025
ca1a15d
renamed chatbot_ready arg
May 22, 2025
fc2a3ee
added API input requirement
May 22, 2025
9ea37dc
use gpt-4.1 instead of 4o
May 22, 2025
6313891
working sidebar and api validation
May 23, 2025
bdfa20b
message container to ensure message history is grouped above chat bar
May 23, 2025
621969b
all functionality working
May 23, 2025
211a99e
typo
May 23, 2025
c2c5246
typo
May 23, 2025
52a95a1
move text chunks to components
May 23, 2025
a3b53f0
added default chatbots
May 23, 2025
7753726
move key validation to own file
May 23, 2025
6cf5535
removed unused warning
May 23, 2025
a2c438c
remove unused function
May 23, 2025
b1e589d
unit tested
May 23, 2025
003b132
update model
May 23, 2025
91ab347
added key validation tests
May 23, 2025
67d0e96
All literature and StrainRelief 0.3 code
May 23, 2025
6ea1b98
Revert "All literature and StrainRelief 0.3 code"
May 23, 2025
c4c52af
Add vectorstore using Git LFS
May 23, 2025
dce3b56
exclude vectorstore
May 23, 2025
b9825e5
Add vectorstore using Git LFS
May 23, 2025
8c3f71d
Add vectorstore files using Git LFS
May 23, 2025
8289171
Add vectorstore files using Git LFS
May 23, 2025
5a37b39
Merge branch 'strain-relief-chatbot' of https://github.com/erwallace/…
May 23, 2025
c3ad8e6
Add vectorstore files using Git LFS
May 23, 2025
8f26876
Add vectorstore files using Git LFS
May 23, 2025
2fd76f2
use test vectorstore for tests
May 23, 2025
b2c4cb1
Revert "added key validation tests"
May 23, 2025
ce8de0f
revert changes
May 23, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
exclude: ^vectorstore/
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
Expand Down
2 changes: 1 addition & 1 deletion src/paper_query/base_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def main():
parser.add_argument(
"--model",
type=str,
default="gpt-4o",
default="gpt-4.1",
help="Model name to use for the chatbot",
)
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion src/paper_query/code_query_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
parser.add_argument(
"--model",
type=str,
default="gpt-4o",
default="gpt-4.1",
help="Model name to use for the chatbot",
)
parser.add_argument(
Expand Down
4 changes: 3 additions & 1 deletion src/paper_query/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._api_keys import GROQ_API_KEY, HUGGINGFACE_API_KEY, OPENAI_API_KEY
from ._paths import PERSIST_DIRECTORY, assets_dir, data_dir, project_dir, src_dir, test_dir
from ._strings import RAG_DOC_ID
from ._strings import RAG_DOC_ID, STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL

__all__ = [
"OPENAI_API_KEY",
Expand All @@ -13,4 +13,6 @@
"data_dir",
"assets_dir",
"RAG_DOC_ID",
"STREAMLIT_CHEAP_MODEL",
"STREAMLIT_EXPENSIVE_MODEL",
]
3 changes: 3 additions & 0 deletions src/paper_query/constants/_strings.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
RAG_DOC_ID = "file"

STREAMLIT_CHEAP_MODEL = "GPT-4.1-nano"
STREAMLIT_EXPENSIVE_MODEL = "GPT-4.1"
2 changes: 1 addition & 1 deletion src/paper_query/hybrid_query_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
parser.add_argument(
"--model",
type=str,
default="gpt-4o",
default="gpt-4.1",
help="Model name to use for the chatbot",
)
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion src/paper_query/paper_query_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def main():
parser.add_argument(
"--model",
type=str,
default="gpt-4o",
default="gpt-4.1",
help="Model name to use for the chatbot",
)
parser.add_argument(
Expand Down
2 changes: 1 addition & 1 deletion src/paper_query/paper_query_plus_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def main():
parser.add_argument(
"--model",
type=str,
default="gpt-4o",
default="gpt-4.1",
help="Model name to use for the chatbot",
)
parser.add_argument(
Expand Down
33 changes: 21 additions & 12 deletions src/paper_query/ui/components/chat_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,33 @@ def display_chat_interface() -> None:
if "messages" not in st.session_state:
st.session_state.messages = []

for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
message_container = st.container()

if "chatbot_confirmed" in st.session_state and st.session_state.chatbot_confirmed:
# Display all past messages in the message container
with message_container:
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])

# Create the input at the bottom
if "chatbot_ready" in st.session_state and st.session_state.chatbot_ready:
if user_input := st.chat_input("What is your question?", key="user_input"):
st.chat_message("user").markdown(user_input)
# Add user message to UI
with message_container:
st.chat_message("user").markdown(user_input)
st.session_state.messages.append({"role": "user", "content": user_input})

with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""
# Add assistant response to UI
with message_container:
with st.chat_message("assistant"):
message_placeholder = st.empty()
full_response = ""

for response_chunk in st.session_state.chatbot.stream_response(user_input):
full_response += response_chunk
message_placeholder.markdown(full_response)
for response_chunk in st.session_state.chatbot.stream_response(user_input):
full_response += response_chunk
message_placeholder.markdown(full_response)

message_placeholder.markdown(full_response)
message_placeholder.markdown(full_response)

st.session_state.messages.append({"role": "assistant", "content": full_response})
else:
Expand Down
2 changes: 1 addition & 1 deletion src/paper_query/ui/components/sidebar_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_class_params(cls) -> list[str]:
]


def model_name_input(name: str = "gpt-4o") -> str:
def model_name_input(name: str = "gpt-4.1") -> str:
"""Get the model name from the sidebar."""
return st.sidebar.text_input("Model Name", value=name, key="model_name_input")

Expand Down
55 changes: 55 additions & 0 deletions src/paper_query/ui/components/text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from paper_query.constants import STREAMLIT_EXPENSIVE_MODEL

# Markdown body of the "About" tab. This is an f-string, so the model name is
# interpolated once, when the module is imported.
ABOUT = f"""
**StrainRelief is a tool for calculating ligand strain energy with quantum mechanical
accuracy**.

##### What is ligand strain energy?
Ligand strain energy is the energy difference between the bound and unbound conformations
of a ligand. It's an important component in structure-based small molecule drug design.

##### How does StrainRelief work?
StrainRelief uses a MACE Neural Network Potential (NNP) trained on a large database of
Density Functional Theory (DFT) calculations to estimate ligand strain of neutral molecules
with quantum accuracy.

##### About this chatbot
This chatbot is built using a hybrid retrieval and cached augmented generation (RAG/CAG)
approach:

1. The full StrainRelief [paper](https://arxiv.org/abs/2503.13352) is loaded and cached
in the context window for all queries
2. Reference papers cited in StrainRelief are embedded and available for retrieval
3. The StrainRelief code [repository](https://github.com/prescient-design/StrainRelief)
is embedded and available for retrieval

The chatbot currently has a naive modular framework. When you ask a question, the
system:
- Retrieves relevant information from the references and code
- Combines this with the full paper context
- Uses the LLM to generate a response based on all available information

The chatbot uses the following components:
- **LLM**: {STREAMLIT_EXPENSIVE_MODEL} from OpenAI for generating responses
- **Embedding**: OpenAI embeddings for vector search
- **Vector Database**: ChromaDB for storing and retrieving embedded documents

Feel free to ask about the StrainRelief methodology, implementation details, or
how to use the tool for drug discovery applications.
"""

# Paper abstract rendered with st.markdown in the chat tab while the session has no
# "messages" entry yet; the whole text is wrapped in a `:gray[...]` markdown directive.
ABSTRACT = """
:gray[**Abstract**: Ligand strain energy, the energy difference between the
bound and unbound conformations of a ligand, is an important component of
structure-based small molecule drug design. A large majority of observed
ligands in protein-small molecule co-crystal structures bind in low-strain
conformations, making strain energy a useful filter for structure-based drug
design. In this work we present a tool for calculating ligand strain with a
high accuracy. StrainRelief uses a MACE Neural Network Potential (NNP),
trained on a large database of Density Functional Theory (DFT) calculations
to estimate ligand strain of neutral molecules with quantum accuracy. We show
that this tool estimates strain energy differences relative to DFT to within
1.4 kcal/mol, more accurately than alternative NNPs. These results highlight
the utility of NNPs in drug discovery, and provide a useful tool for drug
discovery teams.]
"""
25 changes: 25 additions & 0 deletions src/paper_query/ui/components/validate_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import streamlit as st
from loguru import logger
from openai import OpenAI

from paper_query.constants import STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL


def validate_openai_api_key(api_key: str) -> None:
    """Validate the OpenAI API key and update the session state accordingly.

    The key is checked by asking the OpenAI client to list models. The outcome is
    recorded in two session-state entries:

    * ``model_name`` - ``STREAMLIT_EXPENSIVE_MODEL`` after a successful check,
      ``STREAMLIT_CHEAP_MODEL`` after a failed check or when no key is supplied.
    * ``last_validated_key`` - the key that last passed validation, or ``None``.

    A key equal to ``last_validated_key`` is not checked again, so repeated calls
    with an already-accepted key leave the session untouched.

    Args:
        api_key: Key to validate. A falsy value (empty string or ``None``) means
            "no key" and resets the session to the cheap model.
    """
    if not api_key:
        # Reset to cheap model if key is cleared
        st.session_state.model_name = STREAMLIT_CHEAP_MODEL
        st.session_state.last_validated_key = None
        return

    # `.get` instead of attribute access: attribute access raises AttributeError
    # when `last_validated_key` has never been stored in this session.
    if api_key == st.session_state.get("last_validated_key"):
        return

    try:
        client = OpenAI(api_key=api_key)
        client.models.list()
    except Exception as e:
        # Deliberately broad: any failure to list models is treated as an unusable key.
        logger.error(f"API key validation failed: {e}")
        st.sidebar.error("Invalid API key. Please check your OpenAI API key.")
        st.session_state.model_name = STREAMLIT_CHEAP_MODEL
        st.session_state.last_validated_key = None  # Reset if validation fails
    else:
        logger.debug("API key validation successful.")
        st.session_state.model_name = STREAMLIT_EXPENSIVE_MODEL
        st.session_state.last_validated_key = api_key
2 changes: 1 addition & 1 deletion src/paper_query/ui/custom_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def streamlit_chatbot():
chatbot_args = get_chatbot_params(selected_chatbot_class)

if st.sidebar.button("Confirm Chatbot", key="confirm_chatbot_button"):
st.session_state.chatbot_confirmed = True
st.session_state.chatbot_ready = True
st.session_state.chatbot = selected_chatbot_class(**chatbot_args)
st.sidebar.success(f"{selected_label} is ready!")
st.title(f"{selected_label} Chatbot")
Expand Down
88 changes: 50 additions & 38 deletions src/paper_query/ui/strain_relief_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,59 +4,71 @@
from loguru import logger

from paper_query.chatbots import HybridQueryChatbot
from paper_query.constants import assets_dir
from paper_query.constants import STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL, assets_dir
from paper_query.ui.components.chat_interface import display_chat_interface
from paper_query.ui.components.text import ABOUT, ABSTRACT
from paper_query.ui.components.validate_key import validate_openai_api_key

# Configure logger to use DEBUG level
logger.remove()
logger.add(sys.stderr, level="DEBUG")


def initialize_session_state():
    """Initialize session state variables.

    Each entry is only written when it is missing, so values set earlier in the
    session are preserved.
    """
    if "chatbot_ready" not in st.session_state:
        st.session_state.chatbot_ready = True

    if "chatbot" not in st.session_state:
        st.session_state.chatbot = None

    if "model_name" not in st.session_state:
        st.session_state.model_name = STREAMLIT_CHEAP_MODEL

    # validate_openai_api_key reads this entry by attribute; without a default the
    # first non-empty API key would raise AttributeError.
    if "last_validated_key" not in st.session_state:
        st.session_state.last_validated_key = None


def strain_relief_chatbot():
"""Chatbot for the StrainRelief paper."""
st.session_state.chatbot_confirmed = True
if "chatbot" not in st.session_state:
st.session_state.chatbot = HybridQueryChatbot(
model_name="gpt-4o",
model_provider="openai",
paper_path=str(assets_dir / "strainrelief_preprint.pdf"),
references_dir=str(assets_dir / "references"),
)
initialize_session_state()

st.title("The StrainRelief Chatbot")
chat_tab, about_tab = st.tabs(["Chat", "About"])

st.markdown(
"This retrieval augmented generation (RAG) chatbot is designed to answer questions about "
"the StrainRelief. The chatbot has access to the [paper](https://arxiv.org/abs/2503.13352),"
" all references, and the code "
"[repository](https://github.com/prescient-design/StrainRelief)."
st.sidebar.title("API Configuration")
# Enter API key in sidebar
openai_api_key = st.sidebar.text_input(
"OpenAI API Key",
type="password",
help="If you don't have an API key, you can get one from [OpenAI](https://platform.openai.com/api-keys).",
key="api_input",
)
if "messages" not in st.session_state:
st.markdown(
":gray[**Abstract**: Ligand strain energy, the energy difference between the bound and "
"unbound conformations of a ligand, is an important component of structure-based small "
"molecule drug design. A large majority of observed ligands in protein-small molecule "
"co-crystal structures bind in low-strain conformations, making strain energy a useful "
"filter for structure-based drug design. In this work we present a tool for "
"calculating ligand strain with a high accuracy. StrainRelief uses a MACE Neural "
"Network Potential (NNP), trained on a large database of Density Functional Theory "
"(DFT) calculations to estimate ligand strain of neutral molecules with quantum "
"accuracy. We show that this tool estimates strain energy differences relative to DFT "
"to within 1.4 kcal/mol, more accurately than alternative NNPs. These results "
"highlight the utility of NNPs in drug discovery, and provide a useful tool for drug "
"discovery teams.]"
)

display_chat_interface()

validate_openai_api_key(openai_api_key)
# Display current model
st.sidebar.markdown(f"Using **{st.session_state.model_name}** model.")

if __name__ == "__main__":
if sys.platform != "linux": # Skip for GitHub actions
# Get API keys from Streamlit secrets
from paper_query import constants
st.session_state.chatbot = HybridQueryChatbot(
model_name=st.session_state.model_name.lower(),
model_provider="openai",
paper_path=str(assets_dir / "strainrelief_preprint.pdf"),
references_dir=str(assets_dir / "references"),
)

with chat_tab:
if "messages" not in st.session_state:
st.markdown(ABSTRACT)

# Show info message only when using nano model
if st.session_state.model_name == STREAMLIT_CHEAP_MODEL:
st.info(
f"You are currently using {STREAMLIT_CHEAP_MODEL}. Add a valid OpenAI API key "
f"to access the more powerful {STREAMLIT_EXPENSIVE_MODEL} model."
)

constants.OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
constants.GROQ_API_KEY = st.secrets["GROQ_API_KEY"]
constants.HUGGINGFACE_API_KEY = st.secrets["HUGGINGFACE_API_KEY"]
display_chat_interface()

with about_tab:
st.markdown(ABOUT)


if __name__ == "__main__":
strain_relief_chatbot()
2 changes: 1 addition & 1 deletion test/data/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_pypdf_loader_w_images(test_assets_dir):
"""Test the pypdf_loader_w_images function."""
path = test_assets_dir / "example_pdf.pdf"
# TODO: change to free model
doc = pypdf_loader_w_images(path, "gpt-4o-mini", "openai")
doc = pypdf_loader_w_images(path, "gpt-4.1-nano", "openai")
assert isinstance(doc, Document)


Expand Down
8 changes: 4 additions & 4 deletions test/ui/components/test_sidebar_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@


def test_model_name_input():
assert model_name_input() == "gpt-4o"
assert model_name_input("gpt-4o-mini") == "gpt-4o-mini"
assert model_name_input() == "gpt-4.1"
assert model_name_input("gpt-4.1-nano") == "gpt-4.1-nano"


def test_model_provider_input():
Expand All @@ -38,7 +38,7 @@ def test_code_dir_input():


def test_get_param():
assert get_param("model_name") == "gpt-4o"
assert get_param("model_name") == "gpt-4.1"
assert get_param("model_provider") == "openai"
assert get_param("paper_path") == str(assets_dir / "strainrelief_preprint.pdf")
assert get_param("references_dir") == str(assets_dir / "references")
Expand All @@ -48,4 +48,4 @@ def test_get_param():


def test_get_chatbot_params():
assert get_chatbot_params(BaseChatbot) == {"model_name": "gpt-4o", "model_provider": "openai"}
assert get_chatbot_params(BaseChatbot) == {"model_name": "gpt-4.1", "model_provider": "openai"}
22 changes: 22 additions & 0 deletions test/ui/components/test_validate_key.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest
import streamlit as st
from paper_query.constants import OPENAI_API_KEY, STREAMLIT_CHEAP_MODEL, STREAMLIT_EXPENSIVE_MODEL
from paper_query.ui.components.validate_key import validate_openai_api_key


@pytest.mark.app
@pytest.mark.parametrize(
    ("api_key", "last_key", "model_name"),
    [
        pytest.param(OPENAI_API_KEY, OPENAI_API_KEY, STREAMLIT_EXPENSIVE_MODEL),
        pytest.param(None, None, STREAMLIT_CHEAP_MODEL),
        pytest.param("invalid_key", None, STREAMLIT_CHEAP_MODEL),
    ],
)
def test_validate_openai_api_key_correct(api_key, last_key, model_name):
    """Check the session state left behind for a real, a missing and an invalid key."""
    # Seed the remembered key with a value that none of the parametrized keys equals,
    # so a non-empty key is never skipped as "already validated".
    st.session_state.last_validated_key = True

    validate_openai_api_key(api_key)

    session = st.session_state
    assert session.last_validated_key == last_key
    assert session.model_name == model_name
2 changes: 1 addition & 1 deletion test/ui/test_custom_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_confirm_chatbot(app):
@pytest.mark.app
def test_model_selection(app):
"""Test model selection text input."""
assert app.sidebar.text_input("model_name_input").value == "gpt-4o"
assert app.sidebar.text_input("model_name_input").value == "gpt-4.1"
app.sidebar.text_input("model_name_input").set_value(MODEL_NAME)
assert app.sidebar.text_input("model_name_input").value == MODEL_NAME

Expand Down
Loading