From 7be8d77e25117359a0c4d60976c7c89fcc54a435 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sat, 16 Dec 2023 12:07:44 -0800 Subject: [PATCH 1/8] import .txt file --- setup.cfg | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.cfg b/setup.cfg index 5603bde..c204778 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,5 +62,10 @@ streamlit = memory_profiler streamlit-profiler +[options.package_data] +amadeusgpt = interface.txt + [bdist_wheel] universal=1 + + From 9f65b6c38dc469b4f5fbac89b7a0243827b4a738 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sat, 16 Dec 2023 12:52:38 -0800 Subject: [PATCH 2/8] adding app and app_utils to package --- amadeusgpt/app.py | 328 +++++++++++++++ amadeusgpt/app_utils.py | 909 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1237 insertions(+) create mode 100644 amadeusgpt/app.py create mode 100644 amadeusgpt/app_utils.py diff --git a/amadeusgpt/app.py b/amadeusgpt/app.py new file mode 100644 index 0000000..6818a4f --- /dev/null +++ b/amadeusgpt/app.py @@ -0,0 +1,328 @@ +import os +import streamlit as st +import traceback +from collections import defaultdict +import uuid +from amadeusgpt.logger import AmadeusLogger +from datetime import datetime +import requests +from amadeusgpt import app_utils + + +def main(): + #import app_utils + st.title("Your Streamlit App") + + + def fetch_user_headers(): + """Fetch user and email info from HTTP headers. + + Output of this function is identical to querying + https://amadeusgpt.kinematik.ai/oauth2/userinfo, but + works from within the streamlit app. + """ + # TODO(stes): This could change without warning n future streamlit + # versions. So I'll leave the import here in case sth should go + # wrong in the future + from streamlit.web.server.websocket_headers import _get_websocket_headers + + headers = _get_websocket_headers() + AmadeusLogger.debug(f"Received Headers: {headers}") + return dict( + email=headers.get("X-Forwarded-Email", "no_email_in_header"), + user=headers.get("X-Forwarded-User", "no_user_in_header"), + ) + + + def fetch_user_info(): + url = "https://amadeusgpt.kinematik.ai/oauth2/userinfo" + try: + return fetch_user_headers() + # TODO(stes): Lets be on the safe side for now. 
+ except Exception as e: + AmadeusLogger.info(f"Error: {e}") + return None + + + if "streamlit_app" in os.environ: + if "session_id" not in st.session_state: + session_id = str(uuid.uuid4()) + st.session_state["session_id"] = session_id + user_info = fetch_user_info() + if user_info: + st.session_state["username"] = user_info.get("user", "fake_username") + st.session_state["email"] = user_info.get("email", "fake_email") + else: + AmadeusLogger.info("Getting None from the endpoint") + st.session_state["username"] = "no_username" + st.session_state["email"] = "no_email" + + AmadeusLogger.debug("A new user logs in ") + + if f"database" not in st.session_state: + st.session_state[f"database"] = defaultdict(dict) + + from amadeusgpt.utils import validate_openai_api_key + import time + from streamlit_profiler import Profiler + + # TITLE PANEL + st.set_page_config(layout="wide") + app_utils.load_css("static/styles/style.css") + + + assert "streamlit_app" in os.environ + + ###### Initialize ###### + if "amadeus" not in st.session_state: + st.session_state["amadeus"] = app_utils.summon_the_beast()[0] + if "log_folder" not in st.session_state: + st.session_state["log_folder"] = app_utils.summon_the_beast()[1] + if "chatbot" not in st.session_state: + st.session_state["chatbot"] = [] + if "user" not in st.session_state: + st.session_state["user"] = [] + if "user_input" not in st.session_state: + st.session_state["user_input"] = "" + if "uploaded_files" not in st.session_state: + st.session_state["uploaded_files"] = [] + if "uploaded_video_file" not in st.session_state: + st.session_state["uploaded_video_file"] = None + if "uploaded_keypoint_file" not in st.session_state: + st.session_state["uploaded_keypoint_file"] = None + + if "example" not in st.session_state: + st.session_state["example"] = "" + if "chat_history" not in st.session_state: + st.session_state["chat_history"] = "" + if "previous_roi" not in st.session_state: + st.session_state["previous_roi"] = {} + if "roi_exist" not in st.session_state: + st.session_state["roi_exist"] = False + if "exist_valid_openai_api_key" not in st.session_state: + if "OPENAI_API_KEY" in os.environ: + st.session_state["exist_valid_openai_api_key"] = True + else: + st.session_state["exist_valid_openai_api_key"] = False + if "enable_explainer" not in st.session_state: + st.session_state["enable_explainer"] = False + + if "enable_SAM" not in st.session_state: + st.session_state["enable_SAM"] = False + + example_to_page = {} + + + def valid_api_key(): + if "OPENAI_API_KEY" in os.environ: + api_token = os.environ["OPENAI_API_KEY"] + else: + api_token = st.session_state["openAI_token"] + check_valid = validate_openai_api_key(api_token) + + if check_valid: + st.session_state["exist_valid_openai_api_key"] = True + st.session_state["OPENAI_API_KEY"] = api_token + st.success("OpenAI API Key Validated!") + else: + st.error("Invalid OpenAI API Key") + + + def welcome_page(text): + with st.sidebar as sb: + if st.session_state["exist_valid_openai_api_key"] is not True: + api_token = st.sidebar.text_input( + "Your openAI API token", + "place your token here", + key="openAI_token", + on_change=valid_api_key, + ) + + model_selection = st.sidebar.selectbox( + "Select a GPT-4 model", + ("gpt-4", "gpt-4-1106-preview"), + ) + st.session_state["gpt_model"] = model_selection + + enable_explainer = st.sidebar.selectbox( + "Do you want to use our LLM Explainer Module? 
This outputs a written description of the query results, but can be slow.", + ("No", "Yes"), + ) + st.session_state["enable_explainer"] = enable_explainer + + enable_SAM = st.sidebar.selectbox( + "Do you want to use Segment Anything on your own data? This can be slow and requires you to download the model weights.", + ("No", "Yes"), + ) + st.session_state["enable_SAM"] = enable_SAM + + + + # remove this for now + # st.caption(f"git hash: {app_utils.get_git_hash()}") + + st.image( + os.path.join(os.getcwd(), "static/images/amadeusgpt_logo.png"), + caption=None, + width=None, + use_column_width=None, + clamp=False, + channels="RGB", + output_format="auto", + ) + + st.markdown( + "##### 🪄 We turn natural language descriptions of behaviors into machine-executable code" + ) + + small_head = "#" * 6 + small_font = "" + + st.markdown("### 👥 Instructions") + + st.markdown( + f"{small_font} - We use LLMs to bridge natural language and behavior analysis code. For more details, check out our NeurIPS 2023 paper '[AmadeusGPT: a natural language interface for interactive animal behavioral analysis' by Shaokai Ye, Jessy Lauer, Mu Zhou, Alexander Mathis \& Mackenzie W. Mathis](https://github.com/AdaptiveMotorControlLab/AmadeusGPT)." + ) + st.markdown( + f"{small_font} - 🤗 Please note that depending on openAI, the runtimes can vary - you can see the app is `🏃RUNNING` in the top right when you run demos or ask new queries.\n" + ) + st.markdown( + f"{small_font} - Please give us feedback if the output is correct 👍, or needs improvement 👎. This is an ` academic research project` demo, so expect a few bumps please, but we are actively working to make it better 💕.\n" + ) + st.markdown( + f"{small_font} - ⬅️ To get started, watch the quick video below ⬇️, and then select a demo from the drop-down menu. 🔮 We recommend to refresh the browser/app between demos." + ) + + st.markdown( + f"{small_font} - To create an OpenAI API key please see: https://platform.openai.com/overview.\n" + ) + + st.markdown("### How AmadeusGPT🎻 works") + + st.markdown( + f"{small_font} - To capture animal-environment states, AmadeusGPT🎻 leverages state-of-the-art pretrained models, such as SuperAnimals for animal pose estimation and Segment-Anything (SAM) for object segmentation. The platform enables spatio-temporal reasoning to parse the outputs of computer vision models into quantitative behavior analysis. Additionally, AmadeusGPT🎻 simplifies the integration of arbitrary behavioral modules, making it easier to combine tools for task-specific models and interface with machine code." + ) + st.markdown( + f"{small_font} - We built core modules that interface with several integrations, plus built a dual-memory system to augment chatGPT, thereby allowing longer reasoning." + ) + st.markdown( + f"{small_font} - This demo serves to highlight a hosted user-experience, but does not include all the features yet..." + ) + st.markdown(f"{small_font} - Watch the video below to see how to use the App.") + + st.video("static/demo_withvoice.mp4") + + st.markdown("### ⚠️ Disclaimers") + + st.markdown( + f"{small_font} Refer to https://streamlit.io/privacy-policy for the privacy policy for your personal information.\n" + f"{small_font} Please note that to improve AmadeusGPT🎻 we log your queries and the generated code on our demos." 
+            f"{small_font} Note, we do *not* log your openAI API key under any circumstances, and we rely on Streamlit Cloud for privately securing your connections.\n"
+            f"{small_font} If you have security concerns about the API key, we suggest that you reset your API key after you finish using our app.\n"
+        )
+
+        st.markdown("### 💻 The underlying core computer vision models explained")
+        st.markdown(
+            f"{small_font} We use pretrained computer vision models to capture the state of the animal and the environment. We hope this can reduce the entry barrier to behavior analysis.\n"
+            f"{small_font} We can therefore ask questions about animal behaviors that are composed of the animal's state, animal-animal interactions, or animal-environment interactions.\n"
+        )
+        st.markdown(
+            f"{small_font} DeepLabCut-SuperAnimal models, see https://arxiv.org/abs/2203.07436"
+        )
+        st.markdown(
+            f"{small_font} MetaAI Segment-Anything models, see https://arxiv.org/abs/2304.02643"
+        )
+
+        st.markdown("### FAQ")
+        st.markdown(f"{small_font} Q: What can AmadeusGPT🎻 do?")
+        st.markdown(
+            f"{small_font} - A: We provide a natural language interface for analyzing video-based behavioral data. \n"
+            f"{small_font} We expect the user to describe a behavior before asking about it.\n"
+            f"{small_font} In general, one can define behaviors related to the movement of an animal (see the EPM example), animal-animal interactions (see the MABe example), and\n"
+            f"{small_font} animal-environment interactions (see the MausHaus example)."
+        )
+        st.markdown(f"{small_font} Q: Can I run my own videos?")
+        st.markdown(
+            f"{small_font} - A: Not yet - due to limited compute resources we disabled on-demand pose estimation and object segmentation, so we cannot take new videos at this time. For the best experience, we pre-computed the pose and segmentation for the example videos we provide. However, running DeepLabCut, SAM, and other computer vision models is possible with AmadeusGPT🎻, so stay tuned!"
+        )
+        st.markdown(
+            f"{small_font} Q: In the demos you use the term 'unit' - what is the unit being used?"
+        )
+        st.markdown(
+            f"{small_font} - A: Pixels for distance and pixels per frame for speed and velocity, since we do not have real-world distance calibration."
+        )
+        st.markdown(
+            f"{small_font} Q: How can I draw an ROI and use the ROI to define a behavior?"
+        )
+        st.markdown(f"{small_font} - A: Check the video on the EPM tab!")
+        st.markdown(f"{small_font} Q: How can I ask AmadeusGPT🎻 to plot something?")
+        st.markdown(f"{small_font} - A: Check the demo video and the prompts in the examples.")
+        st.markdown(
+            f"{small_font} Q: Why did AmadeusGPT🎻 produce errors or give me unexpected answers to my questions?"
+        )
+        st.markdown(
+            f"{small_font} - A: Most likely you are asking for something that is beyond the current capability of AmadeusGPT🎻, or\n"
+            "you are asking questions in an unexpected way. In either case, we appreciate it if you can provide feedback \n"
+            "to us in our GitHub repo so we can improve our system (and note that we log your queries and will use them to improve AmadeusGPT🎻)."
+        )
+
+        st.markdown(f"{small_font} Q: Does it work with mice only?")
+        st.markdown(
+            f"{small_font} - A: No, AmadeusGPT🎻 can work with a range of animals, as long as poses are extracted and behaviors can be defined from those poses. We will add examples of other animals in the future."
+ ) + st.markdown(f"{small_font} Q: How do I know I can trust AmadeusGPT🎻's answers?") + st.markdown( + f"{small_font} - A: For people who are comfortable with reading Python code, reading the code can help validate the answer. We welcome the community to check our APIs. Otherwise, try visualize your questions by asking \n" + f"{small_font} AmadeusGPT🎻 to plot the related data and use the visualization as a cross validation. We are also developing new features\n" + f"{small_font} to help you gain more confidence on the results and how the results are obtained." + ) + st.markdown( + f"{small_font} Q: Why the page is blocked for a long time and there is no response?" + ) + st.markdown( + f"{small_font} - A: There might be a high traffic for either ChatGPT API or the Streamlit server. Refresh the page and retry or come back later." + ) + + + if st.session_state["exist_valid_openai_api_key"]: + example_list = ["Welcome", "Custom", "EPM", "MausHaus", "MABe", "Horse"] + else: + example_list = ["Welcome"] + + for key in example_list: + if key == "Welcome": + example_to_page[key] = welcome_page + else: + example_to_page[key] = app_utils.render_page_by_example + + with st.sidebar as sb: + example_bar = st.sidebar.selectbox( + "Select an example dataset", example_to_page.keys() + ) + + try: + if "enable_profiler" in os.environ: + with Profiler(): + example_to_page[example_bar](example_bar) + else: + example_to_page[example_bar](example_bar) + + except Exception as e: + print(traceback.format_exc()) + if "streamlit_cloud" in os.environ: + if "session_id" in st.session_state: + AmadeusLogger.store_chats("errors", str(e) + "\n" + traceback.format_exc()) + AmadeusLogger.debug(traceback.format_exc()) + + # with st.sidebar as sb: + # if "chat_history" in st.session_state and 'creation_time' in st.session_state: + # if example_bar != "Welcome": + # st.download_button( + # label="Download current chat", + # data=st.session_state["chat_history"], + # file_name=f"conversations_{st.session_state['creation_time']}.csv", + # mime="text/csv", + # key="chat_download", + # ) +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/amadeusgpt/app_utils.py b/amadeusgpt/app_utils.py new file mode 100644 index 0000000..1ec6ffb --- /dev/null +++ b/amadeusgpt/app_utils.py @@ -0,0 +1,909 @@ +import base64 +from io import BytesIO + +import matplotlib.pyplot as plt +import requests +import streamlit as st +from PIL import Image +from collections import defaultdict + +plt.style.use("dark_background") +import glob +import io +import math +import os +import pickle +import tempfile +from datetime import datetime +import copy +import cv2 +import numpy as np +import pandas as pd +import subprocess +from numpy import nan +from PIL import Image +from streamlit_drawable_canvas import st_canvas +from amadeusgpt.logger import AmadeusLogger +from amadeusgpt.implementation import AnimalBehaviorAnalysis, Object, Scene +from amadeusgpt.main import AMADEUS +import gc +from memory_profiler import profile as memory_profiler +import matplotlib +import json +import base64 +import io + +LOG_DIR = os.path.join(os.path.expanduser("~"), "Amadeus_logs") +VIDEO_EXTS = "mp4", "avi", "mov" +user_profile_path = "static/images/cat.png" +bot_profile_path = "static/images/chatbot.png" + + +def get_git_hash(): + import importlib.util + + basedir = os.path.split(importlib.util.find_spec("amadeusgpt").origin)[0] + git_hash = "" + try: + git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=basedir) + git_hash = 
git_hash.decode("utf-8").rstrip("\n") + except subprocess.CalledProcessError: # Not installed from git + pass + return git_hash + + +def load_profile_image(image_path): + if image_path.startswith("http"): + response = requests.get(image_path) + img = Image.open(BytesIO(response.content)) + else: + img = Image.open(image_path) + return img + + +USER_PROFILE = load_profile_image(user_profile_path) +BOT_PROFILE = load_profile_image(bot_profile_path) + + +def conditional_memory_profile(func): + if os.environ.get("enable_profiler"): + return memory_profiler(func) + return func + + +def feedback_onclick(whether_like, user_msg, bot_msg): + feedback_type = "like" if whether_like else "dislike" + AmadeusLogger.log_feedback(feedback_type, (user_msg, bot_msg)) + + +class BaseMessage: + def __init__(self, json_entry=None): + if json_entry: + self.data = self.parse_from_json_entry(json_entry) + else: + self.data = {} + + def __getitem__(self, key): + return self.data[key] + + def parse_from_json_entry(self, json_entry): + return json_entry + + def __str__(self): + return str(self.data) + + def format_caption(self, caption): + temp = caption.split('\n') + temp = [f'
  • {content}' for content in temp]
+        ret = '\n' + ''.join(temp).rstrip() + '\n'
+        return ret
+
+    def render(self):
+        raise NotImplementedError("Must implement this")
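# Editor's note: an illustrative sketch, not part of the patch — roughly what
# format_caption is meant to produce for a hypothetical two-line plot caption:
#
#   caption = "mean speed per animal\nunits are pixels per frame"
#   BaseMessage().format_caption(caption)
#   # -> '\n  • mean speed per animal  • units are pixels per frame\n'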
+
+class HumanMessage(BaseMessage):
+    def __init__(self, query=None, json_entry=None):
+        if json_entry:
+            super().__init__(json_entry=json_entry)
+        else:
+            self.data = {}
+            self.data["role"] = "human"
+            if query:
+                self.data["query"] = query
+
+    def render(self):
+        if len(self.data) > 0:
+            for render_key, render_value in self.data.items():
+                if render_key == "query":
+                    st.markdown(f'{render_value}', unsafe_allow_html=True)
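# Editor's note: a usage sketch with made-up values, not part of the patch.
# A HumanMessage is either built from a live chat-box query or rehydrated
# from a logged JSON entry:
#
#   msg = HumanMessage(query="How much time does the mouse spend in ROI0?")
#   msg = HumanMessage(json_entry={"role": "human", "query": "Plot the speed."})
#   msg.render()  # writes the query into the Streamlit page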
+class AIMessage(BaseMessage):
+    def __init__(self, amadeus_answer=None, json_entry=None):
+        if json_entry:
+            super().__init__(json_entry=json_entry)
+        else:
+            self.data = {}
+            self.data["role"] = "ai"
+            if amadeus_answer:
+                self.data.update(amadeus_answer.asdict())
+
+    def render(self):
+        """
+        We use the getter for better encapsulation.
+        Overall structure of what gets rendered:
+
+              | chain of thoughts
+        error | str_answer | nd_array
+              | helper_code
+              | code
+              | figure
+              |        | figure_explanation
+              | figure
+              |        | figure_explanation
+              | overall_explanation
+
+        --------
+        """
+        render_keys = ['error_function_code', 'error_message', 'chain_of_thoughts',
+                       'plots', 'str_answer', 'ndarray', 'summary']
+        # for render_key, render_value in self.data.items():
+        if len(self.data) > 0:
+            for render_key in render_keys:
+                if render_key not in self.data:
+                    continue
+                render_value = self.data[render_key]
+                if render_value is None:
+                    # skip empty fields
+                    continue
+                if render_key == "str_answer":
+                    if render_value != "":
+                        st.markdown(f"After executing the code, we get: {render_value}\n")
+                elif render_key == 'error_message':
+                    st.markdown(f"The error says: {render_value}\n")
+                elif render_key == 'error_function_code':
+                    if 'task_program' in render_value and "```python" not in render_value:
+                        st.code(render_value, language='python')
+                    else:
+                        st.markdown(f'{render_value}', unsafe_allow_html=True)
+                    st.markdown("When executing the code above, an error occurs:")
+                elif render_key == "chain_of_thoughts" or render_key == "summary":
+                    # there should be a better matching rule than this
+                    if 'task_program' in render_value and "```python" not in render_value and render_key == 'chain_of_thoughts':
+                        st.code(render_value, language='python')
+                    else:
+                        st.markdown(f'{render_value}', unsafe_allow_html=True)
+                elif render_key == "ndarray":
+                    for content_array in render_value:
+                        content_array = content_array.squeeze()
+                        # no point in showing an array that is very large
+                        hint_message = "Here is the output:"
+                        st.markdown(f'{hint_message}', unsafe_allow_html=True)
+                        if isinstance(content_array, str):
+                            st.markdown(f'{content_array}', unsafe_allow_html=True)
+                        else:
+                            if len(content_array.shape) == 2:
+                                df = pd.DataFrame(content_array)
+                                st.dataframe(df, use_container_width=True)
+                            else:
+                                raise ValueError("returned array cannot be a non-2D array")
+                elif render_key == "plots":
+                    # are there better ways in streamlit now to support plot display?
+                    for fig_obj in render_value:
+                        caption = self.format_caption(fig_obj["plot_caption"])
+                        if isinstance(fig_obj["figure"], str):
+                            img_obj = Image.open(fig_obj["figure"])
+                            st.image(img_obj, width=600)
+                        elif isinstance(fig_obj["figure"], matplotlib.figure.Figure):
+                            # avoid using pyplot
+                            filename = save_figure_to_tempfile(fig_obj["figure"])
+                            st.image(filename, width=600)
+                        st.markdown(f'{caption}', unsafe_allow_html=True)
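# Editor's note: an illustrative sketch, not part of the patch — a minimal
# entry that exercises the branches above. The field names follow render_keys;
# the values are hypothetical.
#
#   entry = {
#       "role": "ai",
#       "chain_of_thoughts": "First compute per-frame speed, then threshold it.",
#       "str_answer": "The mouse crosses ROI0 14 times.",
#       "ndarray": [np.zeros((4, 2))],  # 2-D arrays render as DataFrames
#       "plots": [],  # a list of {"figure": ..., "plot_caption": ...} dicts
#   }
#   AIMessage(json_entry=entry).render()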
+
+
+class Messages:
+    """
+    The data class used by the front end (i.e., Streamlit).
+
+    methods
+    -------
+    parse_from_json:
+        parse a saved conversation from example.json for the demo prompts
+    render:
+        render the messages
+
+    attributes
+    ----------
+    {'plot': List[str],
+        a list of file paths to image files that streamlit can use to render plots
+
+     'code': List[str],
+        a list of str that contains functions. In the future there might be more
+        functions than the task program functions, so this may become a dictionary.
+
+     'ndarray': List[np.ndarray],
+        some sort of rendering for ndarray might be useful
+
+     'text': List[str],
+        the text needs to be further decomposed into different kinds of text,
+        such as the main text, the plot documentation, and the error; each might
+        have a different color and position
+    }
+
+    parameters
+    ----------
+    data: a dictionary
+    """
+
+    def __init__(self):
+        self.raw_dict = None
+        self.messages = []
+
+    def parse_from_json(self, path):
+        with open(path, "r") as f:
+            json_obj = json.load(f)
+        for json_entry in json_obj:
+            if "query" in json_entry:
+                self.append(HumanMessage(json_entry=json_entry))
+            else:
+                self.append(AIMessage(json_entry=json_entry))
+
+    def render(self):
+        """
+        Make sure these media types match what is in the amadeus answer class.
+        """
+        for message in self.messages:
+            message.render()
+
+    def append(self, e):
+        self.messages.append(e)
+
+    def insert(self, ind, e):
+        self.messages.insert(ind, e)
+
+    def __len__(self):
+        return len(self.messages)
+
+    def __iter__(self):
+        return iter(self.messages)
+
+    def __getitem__(self, ind):
+        return self.messages[ind]
+
+    def __setitem__(self, ind, value):
+        self.messages[ind] = value
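# Editor's note: a round-trip sketch, not part of the patch. The demo pages
# rebuild a pre-recorded chat and render it; the path follows the per-example
# layout used in render_messages further down:
#
#   msgs = Messages()
#   msgs.parse_from_json("examples/EPM/example.json")
#   msgs.append(HumanMessage(query="Where is the mouse now?"))
#   msgs.render()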
+
+
+@st.cache_data(persist=False)
+def summon_the_beast():
+    # Get the current date and time
+    now = datetime.now()
+    timestamp = now.strftime("%Y%m%d_%H%M%S")
+    # Create the folder name with the timestamp
+    log_folder = os.path.join(LOG_DIR, timestamp)
+    # os.makedirs(log_folder, exist_ok=True)
+    return AMADEUS, log_folder, timestamp
+
+
+def ask_amadeus(question):
+    # chat_iteration supports some magic commands on top of plain queries
+    answer = AMADEUS.chat_iteration(question)
+    # Log the memory footprint of the current process
+    AmadeusLogger.log_process_memory(log_position="ask_amadeus")
+    return answer
+
+
+def load_css(css_file):
+    with open(css_file, "r") as f:
+        css = f.read()
+    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
+
+
+# Caching display_roi would make the ROI stick to the display of the
+# initial state, so it is deliberately not cached.
+def display_roi(example):
+    roi_objects = AnimalBehaviorAnalysis.get_roi_objects()
+    frame = Scene.get_scene_frame()
+    colormap = plt.cm.get_cmap("rainbow", len(roi_objects))
+
+    for i, (k, v) in enumerate(roi_objects.items()):
+        name = k
+        vertices = v.Path.vertices
+        pts = np.array(vertices, np.int32)
+        pts = pts.reshape((-1, 1, 2))
+        color = colormap(i)[:3]
+        color = tuple(int(c * 255) for c in color[::-1])
+        cv2.polylines(frame, [pts], isClosed=True, color=color, thickness=5)
+        text = name
+        font = cv2.FONT_HERSHEY_SIMPLEX
+        text_position = (pts[0, 0, 0], pts[0, 0, 1])
+        cv2.putText(frame, text, text_position, font, 1, color, 2, cv2.LINE_AA)
+    with st.sidebar:
+        st.caption("ROIs in the scene")
+        st.image(frame)
+
+
+def update_roi(result_json, ratios):
+    w_ratio, h_ratio = ratios
+    objects = pd.json_normalize(result_json["objects"])
+    for col in objects.select_dtypes(include=["object"]).columns:
+        objects[col] = objects[col].astype("str")
+    roi_objects = {}
+    objects = 
objects.to_dict(orient="dict") + if "path" in objects: + paths = objects["path"] + count = 0 + for path_id in paths: + temp = eval(paths[path_id]) + paths[path_id] = temp + canvas_path = paths[path_id] + if not isinstance(canvas_path, list): + continue + points = [[p[1], p[2]] for p in canvas_path if len(p) == 3] + points = np.array(points) + points[:, 0] = points[:, 0] * w_ratio + points[:, 1] = points[:, 1] * h_ratio + _object = Object(f"ROI{count}", canvas_path=points) + roi_objects[f"ROI{count}"] = _object + + count += 1 + AnimalBehaviorAnalysis.set_roi_objects(roi_objects) + + AmadeusLogger.debug("User just drawed roi") + + +def finish_drawing(canvas_result, ratio): + update_roi(canvas_result.json_data, ratio) + +def place_st_canvas(key, scene_image): + + width, height = scene_image.size + # we always resize the canvas to its default values and keep the ratio + + w_ratio = width / 600 + h_ratio = height / 400 + + with st.sidebar: + st.caption( + "Left click to draw a polygon. Right click to confirm the drawing. Refresh the page if you need new ROIs or if the ROI canvas does not display" + ) + canvas_result = st_canvas( + # initial_drawing=st.session_state["previous_roi"], + fill_color="rgba(255, 165, 0, 0.9)", + stroke_width=3, + background_image=scene_image, + # update_streamlit = realtime_update, + width=600, + height=400, + drawing_mode="polygon", + key=f"{key}_canvas", + ) + + if ( + canvas_result.json_data is not None + and "path" in canvas_result.json_data + and len(canvas_result.json_data["path"]) > 0 + ): + pass + # st.session_state["previous_roi"] = canvas_result.json_data + if canvas_result.json_data is not None: + update_roi(canvas_result.json_data, (w_ratio, h_ratio)) + + if AnimalBehaviorAnalysis.roi_objects_exist(): + display_roi(key) + + if key == "EPM" and not AnimalBehaviorAnalysis.roi_objects_exist(): + with open("examples/EPM/roi_objects.pickle", "rb") as f: + roi_objects = pickle.load(f) + AnimalBehaviorAnalysis.set_roi_objects(roi_objects) + display_roi(key) + + +def chat_box_submit(): + if "user_input" in st.session_state: + AmadeusLogger.store_chats("user_query", st.session_state["user_input"]) + query = st.session_state["user_input"] + amadeus_answer = ask_amadeus(query) + + user_message = HumanMessage(query=query) + amadeus_message = AIMessage(amadeus_answer=amadeus_answer) + + st.session_state["messages"].append(user_message) + st.session_state["messages"].append(amadeus_message) + AmadeusLogger.debug("Submitted a query") + + +def check_uploaded_files(): + ## if upload files -> check if same and existing, + # check if multiple h5 -> replace / warning + if st.session_state["uploaded_files"]: + filenames = [f.name for f in st.session_state["uploaded_files"]] + folder_path = os.path.join(st.session_state["log_folder"], "uploaded_files") + if not os.path.exists(folder_path): + os.makedirs(folder_path) + files = st.session_state["uploaded_files"] + count_h5 = sum([int(file.name.endswith(".h5")) for file in files]) + # Remove the existing h5 file if there is a new one + if count_h5 > 1: + st.error("Oooops, you can only upload one *.h5 file! 
:ghost:") + for file in files: + if file.name.endswith(".h5"): + + with tempfile.NamedTemporaryFile( + dir=folder_path, suffix=".h5", delete=False + ) as temp: + temp.write(file.getbuffer()) + st.session_state['uploaded_keypoint_file'] = temp.name + AnimalBehaviorAnalysis.set_keypoint_file_path(temp.name) + if any(file.name.endswith(ext) for ext in VIDEO_EXTS): + with tempfile.NamedTemporaryFile( + dir=folder_path, suffix=".mp4", delete=False + ) as temp: + temp.write(file.getbuffer()) + AnimalBehaviorAnalysis.set_video_file_path(temp.name) + st.session_state["uploaded_video_file"] = temp.name + + +def set_up_sam(): + # check whether SAM model is there, if no, just return + static_root = "static" + if os.path.exists(os.path.join(static_root, "sam_vit_b_01ec64.pth")): + model_path = os.path.join(static_root, "sam_vit_b_01ec64.pth") + model_type = "vit_b" + elif os.path.exists(os.path.join(static_root, "sam_vit_l_0b3195.pth")): + model_path = os.path.join(static_root, "sam_vit_l_0b3195.pth") + model_type = "vit_l" + elif os.path.exists(os.path.join(static_root, "sam_vit_h_4b8939.pth")): + model_path = os.path.join(static_root, "sam_vit_h_4b8939.pth") + model_type = "vit_h" + else: + # on streamlit cloud, we do not even put those checkpoints + model_path = None + model_type = None + + if "log_folder" in st.session_state: + AnimalBehaviorAnalysis.set_sam_info( + ckpt_path=model_path, + model_type=model_type, + pickle_path=os.path.join( + st.session_state["log_folder"], "sam_object.pickle" + ), + ) + return model_path is not None + + +def init_files2amadeus(file, log_folder): + folder_path = os.path.join(log_folder, "uploaded_files") + if not os.path.exists(folder_path): + os.makedirs(folder_path) + + if "h5" in file: + AnimalBehaviorAnalysis.set_keypoint_file_path(file) + if os.path.splitext(file)[1][1:] in VIDEO_EXTS: + AnimalBehaviorAnalysis.set_video_file_path(file) + + +def rerun_prompt(query, ind): + messages = st.session_state["messages"] + amadeus_answer = ask_amadeus(query) + amadeus_message = AIMessage(amadeus_answer=amadeus_answer) + + if ind != len(messages) - 1 and messages[ind + 1]["role"] == "ai": + messages[ind + 1] = amadeus_message + else: + messages.insert(ind + 1, amadeus_message) + + +def render_messages(): + example = st.session_state["example"] + messages = st.session_state["messages"] + + if len(messages) == 0 and example!="Custom": + example_history_json = os.path.join(f"examples/{example}/example.json") + messages.parse_from_json(example_history_json) + + for ind, msg in enumerate(messages): + _role = msg["role"] + with st.chat_message( + _role, avatar=USER_PROFILE if _role == "human" else BOT_PROFILE + ): + # parse and render message which is a list of dictionary + if _role == "human": + msg.render() + disabled = not st.session_state["exist_valid_openai_api_key"] + button_name = "Generate Response" + st.button( + button_name, + key=f"{example}_user_{ind}", + on_click=rerun_prompt, + kwargs={"query": msg["query"], "ind": ind}, + disabled=disabled, + ) + else: + print ('debug msg') + print (msg) + msg.render() + + st.session_state["messages"] = messages + + disabled = not st.session_state["exist_valid_openai_api_key"] + st.chat_input( + "Ask me new questions here ...", + key="user_input", + on_submit=chat_box_submit, + disabled=disabled, + ) + # # Convert the saved conversations to a DataFrame + # df = pd.DataFrame(conversation_history) + # df.index.name = "Index" + # csv = df.to_csv().encode("utf-8") + # ## auto-save the conversation to logs + + # csv_path = 
os.path.join(st.session_state["log_folder"], "conversation.csv") + # df.to_csv(csv_path) + csv = None + return csv + + +def update_df_data(new_item, index_to_update, df, csv_file): + flat_items = [item[0] for item in new_item] + df.iloc[index_to_update] = None + df.loc[index_to_update, "Index"] = index_to_update + for item in flat_items: + key = item["type"] + value = item["content"] + if key in df.columns: + df.loc[index_to_update, key] = value + df.to_csv(csv_file, index=False) + + return csv_file, df + + +@st.cache_data(persist="disk") +def get_scene_image(example): + if AnimalBehaviorAnalysis.get_video_file_path() is not None: + scene_image = Scene.get_scene_frame() + if scene_image is not None: + scene_image = Image.fromarray(scene_image) + buffered = io.BytesIO() + scene_image.save(buffered, format="JPEG") + img_str = base64.b64encode(buffered.getvalue()).decode() + return img_str + elif example!='Custom': + video_file = glob.glob(os.path.join("examples", example, "*.mp4"))[0] + keypoint_file = glob.glob(os.path.join("examples", example, "*.h5"))[0] + AnimalBehaviorAnalysis.set_keypoint_file_path(keypoint_file) + AnimalBehaviorAnalysis.set_video_file_path(video_file) + return get_scene_image(example) + + +@st.cache_data(persist="disk") +def get_sam_image(example): + if AnimalBehaviorAnalysis.get_video_file_path(): + seg_objects = AnimalBehaviorAnalysis.get_seg_objects() + frame = Scene.get_scene_frame() + # number text on objects + mask_frame = AnimalBehaviorAnalysis.show_seg(seg_objects) + mask_frame = (mask_frame * 255).astype(np.uint8) + frame = (frame).astype(np.uint8) + image1 = Image.fromarray(frame, "RGB") + image1 = image1.convert("RGBA") + image2 = Image.fromarray(mask_frame, mode="RGBA") + sam_image = Image.blend(image1, image2, alpha=0.5) + sam_image = np.array(sam_image) + for obj_name, obj in seg_objects.items(): + x, y = obj.center + cv2.putText( + sam_image, + obj_name, + (int(x), int(y)), + cv2.FONT_HERSHEY_SIMPLEX, + 1, + (0, 0, 255), + 1, + ) + return sam_image + else: + return None + + +@conditional_memory_profile +def render_page_by_example(example): + st.image( + os.path.join(os.getcwd(), "static/images/amadeusgpt_logo.png"), + caption=None, + width=None, + use_column_width=None, + clamp=False, + channels="RGB", + output_format="auto", + ) + # st.markdown("# Welcome to AmadeusGPT🎻") + + if example == 'Custom': + st.markdown( + "Provide your own video and keypoint file (in pairs)" + ) + uploaded_files = st.file_uploader( + "Choose data or video files to upload", + ["h5", *VIDEO_EXTS], + accept_multiple_files=True, + ) + st.session_state['uploaded_files'] = uploaded_files + check_uploaded_files() + + ###### USER INPUT PANEL ###### + # get user input once getting the uploaded files + disabled = True if len(st.session_state["uploaded_files"])==0 else False + if disabled: + st.warning("Please upload a file before entering text.") + + + if example == "EPM": + st.markdown( + "Elevated plus maze (EPM) is a widely used behavioral test. The mouse is put on an elevated platform with two open arms (without walls) and two closed arms (with walls). \ + In this example we used a video from https://www.nature.com/articles/s41386-020-0776-y." + ) + st.markdown( + "- ⬅️ On the left you can see the video data auto-tracked with DeepLabCut and keypoint names (below). You can also draw ROIs to ask questions to AmadeusGPT🎻 about the ROIs. You can drag the divider between the panels to increase the video/image size." 
+ ) + st.markdown( + "- We suggest you start by clicking 'Generate Response' to our demo queries." + ) + st.markdown( + "- Ask additional questions in the chatbox at the bottom of the page." + ) + st.markdown( + "- Here are some example queries you might consider: 'The <|open arm|> is the ROI0. How much time does the mouse spend in the open arm?' (NOTE here you can re-draw an ROI0 if you want. Be sure to click 'finish drawing') | 'Define head_dips as a behavior where the mouse's mouse_center and neck are in ROI0 which is open arm while head_midpoint is outside ROI1 which is the cross-shape area. When does head_dips happen and what is the number of bouts for head_dips?' " + ) + st.markdown("- ⬇️🎥 Watch this short clip on how to draw the ROI(s)🤗") + st.video("static/customEPMprompt_short.mp4") + + if example == "MABe": + st.markdown( + "MABe Mouse Triplets is part of a behavior benchmark presented in Sun et al 2022 https://arxiv.org/abs/2207.10553. In the videos, three mice exhibit multiple social behaviors including chasing. \ + In this example, we take one video where chasing happens between mice." + ) + st.markdown( + "- ⬅️ On the left you can see the video data and keypoint names, which could be useful for your queries. You can drag the divider between the panels to increase the video/image size." + ) + st.markdown( + "- We suggest you start by clicking 'Generate Response' to our demo queries." + ) + st.markdown( + "- Ask additional questions in the chatbox at the bottom of the page." + ) + + if example == "MausHaus": + st.markdown( + "MausHaus is a dataset that records a freely moving mouse within a rich environment with objects. More details can be found https://arxiv.org/pdf/2203.07436.pdf." + ) + st.markdown( + "- ⬅️ On the left you can see the video data auto-tracked with DeepLabCut, segmented image with SAM, and the keypoint guide, which could be useful for your queries. You can drag the divider between the panels to increase the video/image size." + ) + st.markdown( + "- We suggest you start by clicking 'Generate Response' to our demo queries." + ) + st.markdown( + "- Ask additional questions in the chatbox at the bottom of the page." + ) + st.markdown( + "- Here are some example queries you might consider: 'Give me events where the animal overlaps with the treadmill, which is object 5' | 'Define <|drinking|> as a behavior where the animal's nose is over object 28, which is a waterbasin. The minimum time window for this behavior should be 20 frames. When is the animal drinking?'" + ) + + if example == "Horse": + st.markdown( + "This horse video is part of a benchmark by Mathis et al 2021 https://arxiv.org/abs/1909.11229." 
+ ) + + AnimalBehaviorAnalysis.set_cache_objects(True) + if example == "EPM" or example == 'Custom': + # in EPM and Custom, we allow people add more objects + AnimalBehaviorAnalysis.set_cache_objects(False) + + if st.session_state["example"] != example: + st.session_state["messages"] = Messages() + AmadeusLogger.debug("The user switched dataset") + + st.session_state["example"] = example + + st.session_state["log_folder"] = f"examples/{example}" + + video_file = None + scene_image = None + scene_image_str = None + if example =='Custom': + if st.session_state['uploaded_video_file']: + video_file = st.session_state['uploaded_video_file'] + scene_image_str = get_scene_image(example) + else: + video_file = glob.glob(os.path.join("examples", example, "*.mp4"))[0] + keypoint_file = glob.glob(os.path.join("examples", example, "*.h5"))[0] + AnimalBehaviorAnalysis.set_keypoint_file_path(keypoint_file) + AnimalBehaviorAnalysis.set_video_file_path(video_file) + # get the corresponding scene image for display + scene_image_str = get_scene_image(example) + + if scene_image_str is not None: + img_data = base64.b64decode(scene_image_str) + image_stream = io.BytesIO(img_data) + image_stream.seek(0) + scene_image = Image.open(image_stream) + + + col1, col2, col3 = st.columns([2, 1, 1]) + + sam_image = None + sam_success = set_up_sam() + if example == "MausHaus" or st.session_state['enable_SAM'] == "Yes": + if sam_success: + sam_image = get_sam_image(example) + else: + st.error("Cannot find SAM checkpoints. Skipping SAM") + + with st.sidebar as sb: + if example == "MABe": + st.caption("Raw video from MABe") + elif example == "Horse": + st.caption("Raw video from Horse-30") + else: + st.caption("DeepLabCut-SuperAnimal tracked video") + if video_file: + st.video(video_file) + # we only show objects for MausHaus for demo + if sam_image is not None: + st.caption("SAM segmentation results") + st.image(sam_image, channels="RGBA") + + if ( + st.session_state["example"] == "EPM" + or st.session_state["example"] == "MausHaus" + and scene_image is not None + ): + place_st_canvas(example, scene_image) + + if st.session_state["example"] == 'Custom' and scene_image: + place_st_canvas(example, scene_image) + + + if example == "EPM" or example == "MausHaus": + # will read the keypoints from h5 file to avoid hard coding + with st.sidebar: + st.image("static/images/supertopview.png") + with st.sidebar: + st.write("Keypoints:") + st.write(AnimalBehaviorAnalysis.get_bodypart_names()) + + render_messages() + + AmadeusLogger.log_process_memory(log_position=f"after_display_chats_{example}") + gc.collect() + AmadeusLogger.log_process_memory(log_position=f"after_garbage_collection_{example}") + + +def get_history_chat(chat_time): + csv_file = glob.glob(os.path.join(LOG_DIR, chat_time, "*.csv"))[0] + df = pd.read_csv(csv_file) + return df + + +def get_example_history_chat(example): + if example == "": + return None, None + csv_files = glob.glob(os.path.join("examples", example, "example.csv")) + if len(csv_files) > 0: + csv_file = csv_files[0] + df = pd.read_csv(csv_file) + return csv_file, df + else: + return None, None + + +def save_figure_to_tempfile(fig): + # save the figure + folder_path = os.path.join(st.session_state["log_folder"], "tmp_imgs") + if not os.path.exists(folder_path): + os.makedirs(folder_path) + # Generate a unique temporary filename in the specified folder + temp_file = tempfile.NamedTemporaryFile( + dir=folder_path, suffix=".png", delete=False + ) + filename = temp_file.name + temp_file.close() + fig.savefig( 
+        filename,
+        format="png",
+        bbox_inches="tight",
+        pad_inches=0.0,
+        dpi=400,
+        transparent=True,
+    )
+    return filename
+
+
+def make_plot_pretty4dark_mode(fig, ax):
+    fig = plt.gcf()
+    fig.set_facecolor("none")
+    ax = plt.gca()
+    ax.set_facecolor("none")
+    # Set axes and legend colors to white or other light colors
+    ax.spines["bottom"].set_color("white")
+    ax.spines["top"].set_color("white")
+    ax.spines["right"].set_color("white")
+    ax.spines["left"].set_color("white")
+
+    ax.xaxis.label.set_color("white")
+    ax.yaxis.label.set_color("white")
+    ax.title.set_color("white")
+    ax.tick_params(axis="x", colors="white")
+    ax.tick_params(axis="y", colors="white")
+    legend = plt.legend()
+    for text in legend.get_texts():
+        text.set_color("white")
+
+    return fig, ax
+
+
+def display_image(temp_file):
+    full_image = Image.open(temp_file)
+    st.image(full_image)
+
+
+def display_temp_text(text_content):
+    # Convert the text content to base64
+    text_bytes = text_content.encode("utf-8")
+    text_base64 = base64.b64encode(text_bytes).decode()
+    # Display the link to the text file
+    st.markdown(
+        f'<a href="data:text/plain;base64,{text_base64}">Check error.</a>',
+        unsafe_allow_html=True,
+    )
+
+
+def style_button_row(clicked_button_ix, n_buttons):
+    def get_button_indices(button_ix):
+        return {"nth_child": button_ix, "nth_last_child": n_buttons - button_ix + 1}
+
+    clicked_style = """
+    div[data-testid*="stHorizontalBlock"] > div:nth-child(%(nth_child)s):nth-last-child(%(nth_last_child)s) button {
+        border-color: rgb(255, 75, 75);
+        color: rgb(255, 75, 75);
+        box-shadow: rgba(255, 75, 75, 0.5) 0px 0px 0px 0.2rem;
+        outline: currentcolor none medium;
+    }
+    """
+    unclicked_style = """
+    div[data-testid*="stHorizontalBlock"] > div:nth-child(%(nth_child)s):nth-last-child(%(nth_last_child)s) button {
+        pointer-events: none;
+        cursor: not-allowed;
+        opacity: 0.65;
+        filter: alpha(opacity=65);
+        -webkit-box-shadow: none;
+        box-shadow: none;
+    }
+    """
+    style = ""
+    for ix in range(n_buttons):
+        ix += 1
+        if ix == clicked_button_ix:
+            style += clicked_style % get_button_indices(ix)
+        else:
+            style += unclicked_style % get_button_indices(ix)
+    st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)

From c8f9fec38966022a29e016db37874e5e96f4a88a Mon Sep 17 00:00:00 2001
From: Mackenzie Mathis
Date: Sat, 16 Dec 2023 12:53:03 -0800
Subject: [PATCH 3/8] Delete app.py - moved

---
 app.py | 323 ---------------------------------------------------------
 1 file changed, 323 deletions(-)
 delete mode 100644 app.py

diff --git a/app.py b/app.py
deleted file mode 100644
index 3a65368..0000000
--- a/app.py
+++ /dev/null
@@ -1,323 +0,0 @@
-import os
-import streamlit as st
-import traceback
-from collections import defaultdict
-import uuid
-from amadeusgpt.logger import AmadeusLogger
-from datetime import datetime
-import requests
-
-
-def fetch_user_headers():
-    """Fetch user and email info from HTTP headers.
-
-    Output of this function is identical to querying
-    https://amadeusgpt.kinematik.ai/oauth2/userinfo, but
-    works from within the streamlit app.
-    """
-    # TODO(stes): This could change without warning in future streamlit
-    # versions. 
So I'll leave the import here in case sth should go - # wrong in the future - from streamlit.web.server.websocket_headers import _get_websocket_headers - - headers = _get_websocket_headers() - AmadeusLogger.debug(f"Received Headers: {headers}") - return dict( - email=headers.get("X-Forwarded-Email", "no_email_in_header"), - user=headers.get("X-Forwarded-User", "no_user_in_header"), - ) - - -def fetch_user_info(): - url = "https://amadeusgpt.kinematik.ai/oauth2/userinfo" - try: - return fetch_user_headers() - # TODO(stes): Lets be on the safe side for now. - except Exception as e: - AmadeusLogger.info(f"Error: {e}") - return None - - -if "streamlit_app" in os.environ: - if "session_id" not in st.session_state: - session_id = str(uuid.uuid4()) - st.session_state["session_id"] = session_id - user_info = fetch_user_info() - if user_info: - st.session_state["username"] = user_info.get("user", "fake_username") - st.session_state["email"] = user_info.get("email", "fake_email") - else: - AmadeusLogger.info("Getting None from the endpoint") - st.session_state["username"] = "no_username" - st.session_state["email"] = "no_email" - - AmadeusLogger.debug("A new user logs in ") - - if f"database" not in st.session_state: - st.session_state[f"database"] = defaultdict(dict) - -import app_utils - - -from amadeusgpt.utils import validate_openai_api_key -import time -from streamlit_profiler import Profiler - -# TITLE PANEL -st.set_page_config(layout="wide") -app_utils.load_css("static/styles/style.css") - - -assert "streamlit_app" in os.environ - -###### Initialize ###### -if "amadeus" not in st.session_state: - st.session_state["amadeus"] = app_utils.summon_the_beast()[0] -if "log_folder" not in st.session_state: - st.session_state["log_folder"] = app_utils.summon_the_beast()[1] -if "chatbot" not in st.session_state: - st.session_state["chatbot"] = [] -if "user" not in st.session_state: - st.session_state["user"] = [] -if "user_input" not in st.session_state: - st.session_state["user_input"] = "" -if "uploaded_files" not in st.session_state: - st.session_state["uploaded_files"] = [] -if "uploaded_video_file" not in st.session_state: - st.session_state["uploaded_video_file"] = None -if "uploaded_keypoint_file" not in st.session_state: - st.session_state["uploaded_keypoint_file"] = None - -if "example" not in st.session_state: - st.session_state["example"] = "" -if "chat_history" not in st.session_state: - st.session_state["chat_history"] = "" -if "previous_roi" not in st.session_state: - st.session_state["previous_roi"] = {} -if "roi_exist" not in st.session_state: - st.session_state["roi_exist"] = False -if "exist_valid_openai_api_key" not in st.session_state: - if "OPENAI_API_KEY" in os.environ: - st.session_state["exist_valid_openai_api_key"] = True - else: - st.session_state["exist_valid_openai_api_key"] = False -if "enable_explainer" not in st.session_state: - st.session_state["enable_explainer"] = False - -if "enable_SAM" not in st.session_state: - st.session_state["enable_SAM"] = False - -example_to_page = {} - - -def valid_api_key(): - if "OPENAI_API_KEY" in os.environ: - api_token = os.environ["OPENAI_API_KEY"] - else: - api_token = st.session_state["openAI_token"] - check_valid = validate_openai_api_key(api_token) - - if check_valid: - st.session_state["exist_valid_openai_api_key"] = True - st.session_state["OPENAI_API_KEY"] = api_token - st.success("OpenAI API Key Validated!") - else: - st.error("Invalid OpenAI API Key") - - -def welcome_page(text): - with st.sidebar as sb: - if 
st.session_state["exist_valid_openai_api_key"] is not True: - api_token = st.sidebar.text_input( - "Your openAI API token", - "place your token here", - key="openAI_token", - on_change=valid_api_key, - ) - - model_selection = st.sidebar.selectbox( - "Select a GPT-4 model", - ("gpt-4", "gpt-4-1106-preview"), - ) - st.session_state["gpt_model"] = model_selection - - enable_explainer = st.sidebar.selectbox( - "Do you want to use our LLM Explainer Module? This outputs a written description of the query results, but can be slow.", - ("No", "Yes"), - ) - st.session_state["enable_explainer"] = enable_explainer - - enable_SAM = st.sidebar.selectbox( - "Do you want to use Segment Anything on your own data? This can be slow and requires you to download the model weights.", - ("No", "Yes"), - ) - st.session_state["enable_SAM"] = enable_SAM - - - - # remove this for now - # st.caption(f"git hash: {app_utils.get_git_hash()}") - - st.image( - os.path.join(os.getcwd(), "static/images/amadeusgpt_logo.png"), - caption=None, - width=None, - use_column_width=None, - clamp=False, - channels="RGB", - output_format="auto", - ) - - st.markdown( - "##### 🪄 We turn natural language descriptions of behaviors into machine-executable code" - ) - - small_head = "#" * 6 - small_font = "" - - st.markdown("### 👥 Instructions") - - st.markdown( - f"{small_font} - We use LLMs to bridge natural language and behavior analysis code. For more details, check out our NeurIPS 2023 paper '[AmadeusGPT: a natural language interface for interactive animal behavioral analysis' by Shaokai Ye, Jessy Lauer, Mu Zhou, Alexander Mathis \& Mackenzie W. Mathis](https://github.com/AdaptiveMotorControlLab/AmadeusGPT)." - ) - st.markdown( - f"{small_font} - 🤗 Please note that depending on openAI, the runtimes can vary - you can see the app is `🏃RUNNING` in the top right when you run demos or ask new queries.\n" - ) - st.markdown( - f"{small_font} - Please give us feedback if the output is correct 👍, or needs improvement 👎. This is an ` academic research project` demo, so expect a few bumps please, but we are actively working to make it better 💕.\n" - ) - st.markdown( - f"{small_font} - ⬅️ To get started, watch the quick video below ⬇️, and then select a demo from the drop-down menu. 🔮 We recommend to refresh the browser/app between demos." - ) - - st.markdown( - f"{small_font} - To create an OpenAI API key please see: https://platform.openai.com/overview.\n" - ) - - st.markdown("### How AmadeusGPT🎻 works") - - st.markdown( - f"{small_font} - To capture animal-environment states, AmadeusGPT🎻 leverages state-of-the-art pretrained models, such as SuperAnimals for animal pose estimation and Segment-Anything (SAM) for object segmentation. The platform enables spatio-temporal reasoning to parse the outputs of computer vision models into quantitative behavior analysis. Additionally, AmadeusGPT🎻 simplifies the integration of arbitrary behavioral modules, making it easier to combine tools for task-specific models and interface with machine code." - ) - st.markdown( - f"{small_font} - We built core modules that interface with several integrations, plus built a dual-memory system to augment chatGPT, thereby allowing longer reasoning." - ) - st.markdown( - f"{small_font} - This demo serves to highlight a hosted user-experience, but does not include all the features yet..." 
- ) - st.markdown(f"{small_font} - Watch the video below to see how to use the App.") - - st.video("static/demo_withvoice.mp4") - - st.markdown("### ⚠️ Disclaimers") - - st.markdown( - f"{small_font} Refer to https://streamlit.io/privacy-policy for the privacy policy for your personal information.\n" - f"{small_font} Please note that to improve AmadeusGPT🎻 we log your queries and the generated code on our demos." - f"{small_font} Note, we do *not* log your openAI API key under any circumstances and we rely on streamlit cloud for privately securing your connections.\n" - f"{small_font} If you have security concerns over the API key, we suggest that you re-set your API key after you finish using our app.\n" - ) - - st.markdown("### 💻 The underlying core computer vision models explained") - st.markdown( - f"{small_font} We use pretrained computer vision models to capture the state of the animal and the environment. We hope this can reduce the entry barrier to behavior analysis.\n" - f"{small_font} Therefore, we can ask questions about animals' behaviors that are composed by animal's state, animal-animal interactions or animal-environment interactions.\n" - ) - st.markdown( - f"{small_font} DeepLabCut-SuperAnimal models, see https://arxiv.org/abs/2203.07436" - ) - st.markdown( - f"{small_font} MetaAI Segment-Anything models, see https://arxiv.org/abs/2304.02643" - ) - - st.markdown("### FAQ") - st.markdown(f"{small_font} Q: What can be done by AmadeusGPT🎻?") - st.markdown( - f"{small_font} - A: We provide a natural language interface to analyze video-based behavioral data. \n" - f"{small_font} We expect the user to describe a behavior before asking about the behaviors.\n" - f"{small_font} In general, one can define behaviors related to the movement of an animal (see EPM example), animal to animal interactions (see the MABe example) and\n" - f"{small_font} animal-environment interaction (check MausHaus example)." - ) - st.markdown(f"{small_font} Q: Can I run my own videos?") - st.markdown( - f"{small_font} - A: Not yet - due to limited compute resources we disabled on-demand pose estimation and object segmentation thus we cannot take new videos at this time. For your best experience, we pre-compute the pose and segmentation for example videos we provided. However, running DeepLabCut, SAM and other computer vision models is possible with AmadeusGPT🎻 so stay tuned!" - ) - st.markdown( - f"{small_font} Q: in the demos you use the term 'unit' - What is the unit being used?" - ) - st.markdown( - f"{small_font} - A: Pixels for distance and pixel per frame for speed and velocity given we don't have real-world values in distance" - ) - st.markdown( - f"{small_font} Q: How can I draw ROI and use the ROI to define a behavior?" - ) - st.markdown(f"{small_font} - A: Check the video on the EPM tab!") - st.markdown(f"{small_font} Q: How can I ask AmadeusGPT🎻 to plot something?") - st.markdown(f"{small_font} - A: Check the demo video and prompts in the examples") - st.markdown( - f"{small_font} Q: Why did AmadeusGPT🎻 produce errors or give me unexpected answers to my questions?" - ) - st.markdown( - f"{small_font} - A: Most likely that you are asking for something that is beyond the current capability of AmadeusGPT🎻 or\n" - "you are asking questions in a way that is unexpected. In either cases, we appreciate it if you can provide feedback \n" - "to us in our GitHub repo so we can improve our system (and note we log your queries and will use this to improve AmadeusGPT🎻)." 
- ) - - st.markdown(f"{small_font} Q: Does it work with mice only?") - st.markdown( - f"{small_font} - A: No, AmadeusGPT🎻 can work with a range of animals as long as poses are extracted and behaviors can be defined with those poses. We will add examples of other animals in the future." - ) - st.markdown(f"{small_font} Q: How do I know I can trust AmadeusGPT🎻's answers?") - st.markdown( - f"{small_font} - A: For people who are comfortable with reading Python code, reading the code can help validate the answer. We welcome the community to check our APIs. Otherwise, try visualize your questions by asking \n" - f"{small_font} AmadeusGPT🎻 to plot the related data and use the visualization as a cross validation. We are also developing new features\n" - f"{small_font} to help you gain more confidence on the results and how the results are obtained." - ) - st.markdown( - f"{small_font} Q: Why the page is blocked for a long time and there is no response?" - ) - st.markdown( - f"{small_font} - A: There might be a high traffic for either ChatGPT API or the Streamlit server. Refresh the page and retry or come back later." - ) - - -if st.session_state["exist_valid_openai_api_key"]: - example_list = ["Welcome", "Custom", "EPM", "MausHaus", "MABe", "Horse"] -else: - example_list = ["Welcome"] - -for key in example_list: - if key == "Welcome": - example_to_page[key] = welcome_page - else: - example_to_page[key] = app_utils.render_page_by_example - -with st.sidebar as sb: - example_bar = st.sidebar.selectbox( - "Select an example dataset", example_to_page.keys() - ) - -try: - if "enable_profiler" in os.environ: - with Profiler(): - example_to_page[example_bar](example_bar) - else: - example_to_page[example_bar](example_bar) - -except Exception as e: - print(traceback.format_exc()) - if "streamlit_cloud" in os.environ: - if "session_id" in st.session_state: - AmadeusLogger.store_chats("errors", str(e) + "\n" + traceback.format_exc()) - AmadeusLogger.debug(traceback.format_exc()) - -# with st.sidebar as sb: -# if "chat_history" in st.session_state and 'creation_time' in st.session_state: -# if example_bar != "Welcome": -# st.download_button( -# label="Download current chat", -# data=st.session_state["chat_history"], -# file_name=f"conversations_{st.session_state['creation_time']}.csv", -# mime="text/csv", -# key="chat_download", -# ) From d08d482e917fb27008917bd621423f01c8e83dd2 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sat, 16 Dec 2023 12:53:18 -0800 Subject: [PATCH 4/8] Delete app_utils.py - moved --- app_utils.py | 909 --------------------------------------------------- 1 file changed, 909 deletions(-) delete mode 100644 app_utils.py diff --git a/app_utils.py b/app_utils.py deleted file mode 100644 index 1ec6ffb..0000000 --- a/app_utils.py +++ /dev/null @@ -1,909 +0,0 @@ -import base64 -from io import BytesIO - -import matplotlib.pyplot as plt -import requests -import streamlit as st -from PIL import Image -from collections import defaultdict - -plt.style.use("dark_background") -import glob -import io -import math -import os -import pickle -import tempfile -from datetime import datetime -import copy -import cv2 -import numpy as np -import pandas as pd -import subprocess -from numpy import nan -from PIL import Image -from streamlit_drawable_canvas import st_canvas -from amadeusgpt.logger import AmadeusLogger -from amadeusgpt.implementation import AnimalBehaviorAnalysis, Object, Scene -from amadeusgpt.main import AMADEUS -import gc -from memory_profiler import profile as memory_profiler -import 
-import json
-import base64
-import io
-
-LOG_DIR = os.path.join(os.path.expanduser("~"), "Amadeus_logs")
-VIDEO_EXTS = "mp4", "avi", "mov"
-user_profile_path = "static/images/cat.png"
-bot_profile_path = "static/images/chatbot.png"
-
-
-def get_git_hash():
-    import importlib.util
-
-    basedir = os.path.split(importlib.util.find_spec("amadeusgpt").origin)[0]
-    git_hash = ""
-    try:
-        git_hash = subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=basedir)
-        git_hash = git_hash.decode("utf-8").rstrip("\n")
-    except subprocess.CalledProcessError:  # Not installed from git
-        pass
-    return git_hash
-
-
-def load_profile_image(image_path):
-    if image_path.startswith("http"):
-        response = requests.get(image_path)
-        img = Image.open(BytesIO(response.content))
-    else:
-        img = Image.open(image_path)
-    return img
-
-
-USER_PROFILE = load_profile_image(user_profile_path)
-BOT_PROFILE = load_profile_image(bot_profile_path)
-
-
-def conditional_memory_profile(func):
-    if os.environ.get("enable_profiler"):
-        return memory_profiler(func)
-    return func
-
-
-def feedback_onclick(whether_like, user_msg, bot_msg):
-    feedback_type = "like" if whether_like else "dislike"
-    AmadeusLogger.log_feedback(feedback_type, (user_msg, bot_msg))
-
-
-class BaseMessage:
-    def __init__(self, json_entry=None):
-        if json_entry:
-            self.data = self.parse_from_json_entry(json_entry)
-        else:
-            self.data = {}
-
-    def __getitem__(self, key):
-        return self.data[key]
-
-    def parse_from_json_entry(self, json_entry):
-        return json_entry
-
-    def __str__(self):
-        return str(self.data)
-
-    def format_caption(self, caption):
-        temp = caption.split('\n')
-        temp = [f'<li>{content}</li>' for content in temp]
-        ret = '<ul>\n' + ''.join(temp).rstrip() + '\n</ul>'
-        return ret
-
-    def render(self):
-        raise NotImplementedError("Must implement this")
-
-
-class HumanMessage(BaseMessage):
-    def __init__(self, query=None, json_entry=None):
-        if json_entry:
-            super().__init__(json_entry=json_entry)
-        else:
-            self.data = {}
-            self.data["role"] = "human"
-            if query:
-                self.data["query"] = query
-
-    def render(self):
-        if len(self.data) > 0:
-            for render_key, render_value in self.data.items():
-                if render_key == "query":
-                    st.markdown(
-                        f'<div>{render_value}</div>', unsafe_allow_html=True
-                    )
-
-
-class AIMessage(BaseMessage):
-    def __init__(self, amadeus_answer=None, json_entry=None):
-        if json_entry:
-            super().__init__(json_entry=json_entry)
-        else:
-            self.data = {}
-            self.data["role"] = "ai"
-            if amadeus_answer:
-                self.data.update(amadeus_answer.asdict())
-
-    def render(self):
-        """
-        We use the getter for better encapsulation.
-        Overall structure of what is to be rendered:
-
-        | chain of thoughts
-        error | str_answer | nd_array
-        | helper_code
-        | code
-        | figure
-        | | figure_explanation
-        | figure
-        | | figure_explanation
-        | overall_explanation
-
-        --------
-        """
-
-        render_keys = ['error_function_code', 'error_message', 'chain_of_thoughts', 'plots', 'str_answer', 'ndarray', 'summary']
-        # for render_key, render_value in self.data.items():
-        if len(self.data) > 0:
-            for render_key in render_keys:
-                if render_key not in self.data:
-                    continue
-                render_value = self.data[render_key]
-                if render_value is None:
-                    # skip empty field
-                    continue
-                if render_key == "str_answer":
-                    if render_value != "":
-                        st.markdown(f"After executing the code, we get: {render_value}\n")
-
-                elif render_key == 'error_message':
-                    st.markdown(f"The error says: {render_value}\n")
-                elif render_key == 'error_function_code':
-                    if 'task_program' in render_value and "```python" not in render_value:
-                        st.code(render_value, language='python')
-                    else:
-                        st.markdown(
-                            f'<div>{render_value}</div>', unsafe_allow_html=True
-                        )
-                    st.markdown(f"When executing the code above, an error occurs:")
-                elif render_key == "chain_of_thoughts" or render_key == "summary":
-                    # there should be a better matching than this
-                    if 'task_program' in render_value and "```python" not in render_value and render_key == 'chain_of_thoughts':
-                        st.code(render_value, language='python')
-                    else:
-                        st.markdown(
-                            f'<div>{render_value}</div>', unsafe_allow_html=True
-                        )
-                elif render_key == "ndarray":
-                    for content_array in render_value:
-                        content_array = content_array.squeeze()
-                        # no point in showing an array that's large
-
-                        hint_message = "Here is the output:"
-                        st.markdown(
-                            f'<div>{hint_message}</div>',
-                            unsafe_allow_html=True,
-                        )
-                        if isinstance(content_array, str):
-                            st.markdown(
-                                f'<div>{content_array}</div>',
-                                unsafe_allow_html=True,
-                            )
-                        else:
-                            if len(content_array.shape) == 2:
-                                df = pd.DataFrame(content_array)
-                                st.dataframe(df, use_container_width=True)
-                            else:
-                                raise ValueError("returned array cannot be a non-2D array.")
-                elif render_key == "plots":
-                    # are there better ways in streamlit now to support plot display?
-                    for fig_obj in render_value:
-                        caption = self.format_caption(fig_obj["plot_caption"])
-                        if isinstance(fig_obj["figure"], str):
-                            img_obj = Image.open(fig_obj["figure"])
-                            st.image(img_obj, width=600)
-                        elif isinstance(fig_obj["figure"], matplotlib.figure.Figure):
-                            # avoid using pyplot
-                            filename = save_figure_to_tempfile(fig_obj["figure"])
-                            st.image(filename, width=600)
-                        st.markdown(
-                            f'<div>{caption}</div>',
-                            unsafe_allow_html=True,
-                        )
-
-
-class Messages:
-    """
-    The data class that is used in the front end (i.e., streamlit)
-
-    methods
-    -------
-    parse_from_json:
-        parse from example.json for demo prompts
-    render:
-        render messages
-
-    attributes
-    ----------
-    {'plot': List[str],
-     a list of file paths to image files that can be used by streamlit to render plots
-
-    'code': List[str],
-     a list of str that contains functions. Though in the future,
-     there might be more functions than the task program functions;
-     maybe it should be a dictionary
-
-    'ndarray': List[nd.array],
-     Some sort of rendering for ndarray might be useful
-
-    'text': List[str]
-     The text needs to be further decomposed into different
-     kinds of text, such as the main text, the plot documentation, and the error;
-     each might have different colors and different positions
-    }
-
-    parameters
-    ----------
-    data: a dictionary
-    """
-
-    def __init__(self):
-        self.raw_dict = None
-        self.messages = []
-
-    def parse_from_json(self, path):
-        with open(path, "r") as f:
-            json_obj = json.load(f)
-        for json_entry in json_obj:
-            if "query" in json_entry:
-                self.append(HumanMessage(json_entry=json_entry))
-            else:
-                self.append(AIMessage(json_entry=json_entry))
-
-    def render(self):
-        """
-        make sure those media types match what is in the amadeus answer class
-        """
-        for message in self.messages:
-            message.render()
-
-    def append(self, e):
-        self.messages.append(e)
-
-    def insert(self, ind, e):
-        self.messages.insert(ind, e)
-
-    def __len__(self):
-        return len(self.messages)
-
-    def __iter__(self):
-        return iter(self.messages)
-
-    def __getitem__(self, ind):
-        return self.messages[ind]
-
-    def __setitem__(self, ind, value):
-        self.messages[ind] = value
-
-
-@st.cache_data(persist=False)
-def summon_the_beast():
-    # Get the current date and time
-    now = datetime.now()
-    timestamp = now.strftime("%Y%m%d_%H%M%S")
-    # Create the folder name with the timestamp
-    log_folder = os.path.join(LOG_DIR, timestamp)
-    # os.makedirs(log_folder, exist_ok=True)
-    return AMADEUS, log_folder, timestamp
-
-
-def ask_amadeus(question):
-    answer = AMADEUS.chat_iteration(
-        question
-    )  # use chat_iteration to support some magic commands
-    # Get the current process
-    AmadeusLogger.log_process_memory(log_position="ask_amadeus")
-    return answer
-
-
-def load_css(css_file):
-    with open(css_file, "r") as f:
-        css = f.read()
-    st.markdown(f"<style>{css}</style>", unsafe_allow_html=True)
-
-
-# caching display_roi would make the ROI stick to
-# the display of the initial state
-def display_roi(example):
-    roi_objects = AnimalBehaviorAnalysis.get_roi_objects()
-    frame = Scene.get_scene_frame()
-    colormap = plt.cm.get_cmap("rainbow", len(roi_objects))
-
-    for i, (k, v) in enumerate(roi_objects.items()):
-        name = k
-        vertices = v.Path.vertices
-        pts = np.array(vertices, np.int32)
-        pts = pts.reshape((-1, 1, 2))
-        color = colormap(i)[:3]
-        color = tuple(int(c * 255) for c in color[::-1])
-        cv2.polylines(frame, [pts], isClosed=True, color=color, thickness=5)
-        text = name
-        font = cv2.FONT_HERSHEY_SIMPLEX
-        text_position = (pts[0, 0, 0], pts[0, 0, 1])
-        cv2.putText(frame, text, text_position, font, 1, color, 2, cv2.LINE_AA)
-    with st.sidebar:
-        st.caption("ROIs in the scene")
-        st.image(frame)
-
-
-def update_roi(result_json, ratios):
-    w_ratio, h_ratio = ratios
-    objects = pd.json_normalize(result_json["objects"])
-    for col in objects.select_dtypes(include=["object"]).columns:
-        objects[col] = objects[col].astype("str")
-    roi_objects = {}
-    objects = 
objects.to_dict(orient="dict") - if "path" in objects: - paths = objects["path"] - count = 0 - for path_id in paths: - temp = eval(paths[path_id]) - paths[path_id] = temp - canvas_path = paths[path_id] - if not isinstance(canvas_path, list): - continue - points = [[p[1], p[2]] for p in canvas_path if len(p) == 3] - points = np.array(points) - points[:, 0] = points[:, 0] * w_ratio - points[:, 1] = points[:, 1] * h_ratio - _object = Object(f"ROI{count}", canvas_path=points) - roi_objects[f"ROI{count}"] = _object - - count += 1 - AnimalBehaviorAnalysis.set_roi_objects(roi_objects) - - AmadeusLogger.debug("User just drawed roi") - - -def finish_drawing(canvas_result, ratio): - update_roi(canvas_result.json_data, ratio) - -def place_st_canvas(key, scene_image): - - width, height = scene_image.size - # we always resize the canvas to its default values and keep the ratio - - w_ratio = width / 600 - h_ratio = height / 400 - - with st.sidebar: - st.caption( - "Left click to draw a polygon. Right click to confirm the drawing. Refresh the page if you need new ROIs or if the ROI canvas does not display" - ) - canvas_result = st_canvas( - # initial_drawing=st.session_state["previous_roi"], - fill_color="rgba(255, 165, 0, 0.9)", - stroke_width=3, - background_image=scene_image, - # update_streamlit = realtime_update, - width=600, - height=400, - drawing_mode="polygon", - key=f"{key}_canvas", - ) - - if ( - canvas_result.json_data is not None - and "path" in canvas_result.json_data - and len(canvas_result.json_data["path"]) > 0 - ): - pass - # st.session_state["previous_roi"] = canvas_result.json_data - if canvas_result.json_data is not None: - update_roi(canvas_result.json_data, (w_ratio, h_ratio)) - - if AnimalBehaviorAnalysis.roi_objects_exist(): - display_roi(key) - - if key == "EPM" and not AnimalBehaviorAnalysis.roi_objects_exist(): - with open("examples/EPM/roi_objects.pickle", "rb") as f: - roi_objects = pickle.load(f) - AnimalBehaviorAnalysis.set_roi_objects(roi_objects) - display_roi(key) - - -def chat_box_submit(): - if "user_input" in st.session_state: - AmadeusLogger.store_chats("user_query", st.session_state["user_input"]) - query = st.session_state["user_input"] - amadeus_answer = ask_amadeus(query) - - user_message = HumanMessage(query=query) - amadeus_message = AIMessage(amadeus_answer=amadeus_answer) - - st.session_state["messages"].append(user_message) - st.session_state["messages"].append(amadeus_message) - AmadeusLogger.debug("Submitted a query") - - -def check_uploaded_files(): - ## if upload files -> check if same and existing, - # check if multiple h5 -> replace / warning - if st.session_state["uploaded_files"]: - filenames = [f.name for f in st.session_state["uploaded_files"]] - folder_path = os.path.join(st.session_state["log_folder"], "uploaded_files") - if not os.path.exists(folder_path): - os.makedirs(folder_path) - files = st.session_state["uploaded_files"] - count_h5 = sum([int(file.name.endswith(".h5")) for file in files]) - # Remove the existing h5 file if there is a new one - if count_h5 > 1: - st.error("Oooops, you can only upload one *.h5 file! 
:ghost:") - for file in files: - if file.name.endswith(".h5"): - - with tempfile.NamedTemporaryFile( - dir=folder_path, suffix=".h5", delete=False - ) as temp: - temp.write(file.getbuffer()) - st.session_state['uploaded_keypoint_file'] = temp.name - AnimalBehaviorAnalysis.set_keypoint_file_path(temp.name) - if any(file.name.endswith(ext) for ext in VIDEO_EXTS): - with tempfile.NamedTemporaryFile( - dir=folder_path, suffix=".mp4", delete=False - ) as temp: - temp.write(file.getbuffer()) - AnimalBehaviorAnalysis.set_video_file_path(temp.name) - st.session_state["uploaded_video_file"] = temp.name - - -def set_up_sam(): - # check whether SAM model is there, if no, just return - static_root = "static" - if os.path.exists(os.path.join(static_root, "sam_vit_b_01ec64.pth")): - model_path = os.path.join(static_root, "sam_vit_b_01ec64.pth") - model_type = "vit_b" - elif os.path.exists(os.path.join(static_root, "sam_vit_l_0b3195.pth")): - model_path = os.path.join(static_root, "sam_vit_l_0b3195.pth") - model_type = "vit_l" - elif os.path.exists(os.path.join(static_root, "sam_vit_h_4b8939.pth")): - model_path = os.path.join(static_root, "sam_vit_h_4b8939.pth") - model_type = "vit_h" - else: - # on streamlit cloud, we do not even put those checkpoints - model_path = None - model_type = None - - if "log_folder" in st.session_state: - AnimalBehaviorAnalysis.set_sam_info( - ckpt_path=model_path, - model_type=model_type, - pickle_path=os.path.join( - st.session_state["log_folder"], "sam_object.pickle" - ), - ) - return model_path is not None - - -def init_files2amadeus(file, log_folder): - folder_path = os.path.join(log_folder, "uploaded_files") - if not os.path.exists(folder_path): - os.makedirs(folder_path) - - if "h5" in file: - AnimalBehaviorAnalysis.set_keypoint_file_path(file) - if os.path.splitext(file)[1][1:] in VIDEO_EXTS: - AnimalBehaviorAnalysis.set_video_file_path(file) - - -def rerun_prompt(query, ind): - messages = st.session_state["messages"] - amadeus_answer = ask_amadeus(query) - amadeus_message = AIMessage(amadeus_answer=amadeus_answer) - - if ind != len(messages) - 1 and messages[ind + 1]["role"] == "ai": - messages[ind + 1] = amadeus_message - else: - messages.insert(ind + 1, amadeus_message) - - -def render_messages(): - example = st.session_state["example"] - messages = st.session_state["messages"] - - if len(messages) == 0 and example!="Custom": - example_history_json = os.path.join(f"examples/{example}/example.json") - messages.parse_from_json(example_history_json) - - for ind, msg in enumerate(messages): - _role = msg["role"] - with st.chat_message( - _role, avatar=USER_PROFILE if _role == "human" else BOT_PROFILE - ): - # parse and render message which is a list of dictionary - if _role == "human": - msg.render() - disabled = not st.session_state["exist_valid_openai_api_key"] - button_name = "Generate Response" - st.button( - button_name, - key=f"{example}_user_{ind}", - on_click=rerun_prompt, - kwargs={"query": msg["query"], "ind": ind}, - disabled=disabled, - ) - else: - print ('debug msg') - print (msg) - msg.render() - - st.session_state["messages"] = messages - - disabled = not st.session_state["exist_valid_openai_api_key"] - st.chat_input( - "Ask me new questions here ...", - key="user_input", - on_submit=chat_box_submit, - disabled=disabled, - ) - # # Convert the saved conversations to a DataFrame - # df = pd.DataFrame(conversation_history) - # df.index.name = "Index" - # csv = df.to_csv().encode("utf-8") - # ## auto-save the conversation to logs - - # csv_path = 
os.path.join(st.session_state["log_folder"], "conversation.csv") - # df.to_csv(csv_path) - csv = None - return csv - - -def update_df_data(new_item, index_to_update, df, csv_file): - flat_items = [item[0] for item in new_item] - df.iloc[index_to_update] = None - df.loc[index_to_update, "Index"] = index_to_update - for item in flat_items: - key = item["type"] - value = item["content"] - if key in df.columns: - df.loc[index_to_update, key] = value - df.to_csv(csv_file, index=False) - - return csv_file, df - - -@st.cache_data(persist="disk") -def get_scene_image(example): - if AnimalBehaviorAnalysis.get_video_file_path() is not None: - scene_image = Scene.get_scene_frame() - if scene_image is not None: - scene_image = Image.fromarray(scene_image) - buffered = io.BytesIO() - scene_image.save(buffered, format="JPEG") - img_str = base64.b64encode(buffered.getvalue()).decode() - return img_str - elif example!='Custom': - video_file = glob.glob(os.path.join("examples", example, "*.mp4"))[0] - keypoint_file = glob.glob(os.path.join("examples", example, "*.h5"))[0] - AnimalBehaviorAnalysis.set_keypoint_file_path(keypoint_file) - AnimalBehaviorAnalysis.set_video_file_path(video_file) - return get_scene_image(example) - - -@st.cache_data(persist="disk") -def get_sam_image(example): - if AnimalBehaviorAnalysis.get_video_file_path(): - seg_objects = AnimalBehaviorAnalysis.get_seg_objects() - frame = Scene.get_scene_frame() - # number text on objects - mask_frame = AnimalBehaviorAnalysis.show_seg(seg_objects) - mask_frame = (mask_frame * 255).astype(np.uint8) - frame = (frame).astype(np.uint8) - image1 = Image.fromarray(frame, "RGB") - image1 = image1.convert("RGBA") - image2 = Image.fromarray(mask_frame, mode="RGBA") - sam_image = Image.blend(image1, image2, alpha=0.5) - sam_image = np.array(sam_image) - for obj_name, obj in seg_objects.items(): - x, y = obj.center - cv2.putText( - sam_image, - obj_name, - (int(x), int(y)), - cv2.FONT_HERSHEY_SIMPLEX, - 1, - (0, 0, 255), - 1, - ) - return sam_image - else: - return None - - -@conditional_memory_profile -def render_page_by_example(example): - st.image( - os.path.join(os.getcwd(), "static/images/amadeusgpt_logo.png"), - caption=None, - width=None, - use_column_width=None, - clamp=False, - channels="RGB", - output_format="auto", - ) - # st.markdown("# Welcome to AmadeusGPT🎻") - - if example == 'Custom': - st.markdown( - "Provide your own video and keypoint file (in pairs)" - ) - uploaded_files = st.file_uploader( - "Choose data or video files to upload", - ["h5", *VIDEO_EXTS], - accept_multiple_files=True, - ) - st.session_state['uploaded_files'] = uploaded_files - check_uploaded_files() - - ###### USER INPUT PANEL ###### - # get user input once getting the uploaded files - disabled = True if len(st.session_state["uploaded_files"])==0 else False - if disabled: - st.warning("Please upload a file before entering text.") - - - if example == "EPM": - st.markdown( - "Elevated plus maze (EPM) is a widely used behavioral test. The mouse is put on an elevated platform with two open arms (without walls) and two closed arms (with walls). \ - In this example we used a video from https://www.nature.com/articles/s41386-020-0776-y." - ) - st.markdown( - "- ⬅️ On the left you can see the video data auto-tracked with DeepLabCut and keypoint names (below). You can also draw ROIs to ask questions to AmadeusGPT🎻 about the ROIs. You can drag the divider between the panels to increase the video/image size." 
- ) - st.markdown( - "- We suggest you start by clicking 'Generate Response' to our demo queries." - ) - st.markdown( - "- Ask additional questions in the chatbox at the bottom of the page." - ) - st.markdown( - "- Here are some example queries you might consider: 'The <|open arm|> is the ROI0. How much time does the mouse spend in the open arm?' (NOTE here you can re-draw an ROI0 if you want. Be sure to click 'finish drawing') | 'Define head_dips as a behavior where the mouse's mouse_center and neck are in ROI0 which is open arm while head_midpoint is outside ROI1 which is the cross-shape area. When does head_dips happen and what is the number of bouts for head_dips?' " - ) - st.markdown("- ⬇️🎥 Watch this short clip on how to draw the ROI(s)🤗") - st.video("static/customEPMprompt_short.mp4") - - if example == "MABe": - st.markdown( - "MABe Mouse Triplets is part of a behavior benchmark presented in Sun et al 2022 https://arxiv.org/abs/2207.10553. In the videos, three mice exhibit multiple social behaviors including chasing. \ - In this example, we take one video where chasing happens between mice." - ) - st.markdown( - "- ⬅️ On the left you can see the video data and keypoint names, which could be useful for your queries. You can drag the divider between the panels to increase the video/image size." - ) - st.markdown( - "- We suggest you start by clicking 'Generate Response' to our demo queries." - ) - st.markdown( - "- Ask additional questions in the chatbox at the bottom of the page." - ) - - if example == "MausHaus": - st.markdown( - "MausHaus is a dataset that records a freely moving mouse within a rich environment with objects. More details can be found https://arxiv.org/pdf/2203.07436.pdf." - ) - st.markdown( - "- ⬅️ On the left you can see the video data auto-tracked with DeepLabCut, segmented image with SAM, and the keypoint guide, which could be useful for your queries. You can drag the divider between the panels to increase the video/image size." - ) - st.markdown( - "- We suggest you start by clicking 'Generate Response' to our demo queries." - ) - st.markdown( - "- Ask additional questions in the chatbox at the bottom of the page." - ) - st.markdown( - "- Here are some example queries you might consider: 'Give me events where the animal overlaps with the treadmill, which is object 5' | 'Define <|drinking|> as a behavior where the animal's nose is over object 28, which is a waterbasin. The minimum time window for this behavior should be 20 frames. When is the animal drinking?'" - ) - - if example == "Horse": - st.markdown( - "This horse video is part of a benchmark by Mathis et al 2021 https://arxiv.org/abs/1909.11229." 
- ) - - AnimalBehaviorAnalysis.set_cache_objects(True) - if example == "EPM" or example == 'Custom': - # in EPM and Custom, we allow people add more objects - AnimalBehaviorAnalysis.set_cache_objects(False) - - if st.session_state["example"] != example: - st.session_state["messages"] = Messages() - AmadeusLogger.debug("The user switched dataset") - - st.session_state["example"] = example - - st.session_state["log_folder"] = f"examples/{example}" - - video_file = None - scene_image = None - scene_image_str = None - if example =='Custom': - if st.session_state['uploaded_video_file']: - video_file = st.session_state['uploaded_video_file'] - scene_image_str = get_scene_image(example) - else: - video_file = glob.glob(os.path.join("examples", example, "*.mp4"))[0] - keypoint_file = glob.glob(os.path.join("examples", example, "*.h5"))[0] - AnimalBehaviorAnalysis.set_keypoint_file_path(keypoint_file) - AnimalBehaviorAnalysis.set_video_file_path(video_file) - # get the corresponding scene image for display - scene_image_str = get_scene_image(example) - - if scene_image_str is not None: - img_data = base64.b64decode(scene_image_str) - image_stream = io.BytesIO(img_data) - image_stream.seek(0) - scene_image = Image.open(image_stream) - - - col1, col2, col3 = st.columns([2, 1, 1]) - - sam_image = None - sam_success = set_up_sam() - if example == "MausHaus" or st.session_state['enable_SAM'] == "Yes": - if sam_success: - sam_image = get_sam_image(example) - else: - st.error("Cannot find SAM checkpoints. Skipping SAM") - - with st.sidebar as sb: - if example == "MABe": - st.caption("Raw video from MABe") - elif example == "Horse": - st.caption("Raw video from Horse-30") - else: - st.caption("DeepLabCut-SuperAnimal tracked video") - if video_file: - st.video(video_file) - # we only show objects for MausHaus for demo - if sam_image is not None: - st.caption("SAM segmentation results") - st.image(sam_image, channels="RGBA") - - if ( - st.session_state["example"] == "EPM" - or st.session_state["example"] == "MausHaus" - and scene_image is not None - ): - place_st_canvas(example, scene_image) - - if st.session_state["example"] == 'Custom' and scene_image: - place_st_canvas(example, scene_image) - - - if example == "EPM" or example == "MausHaus": - # will read the keypoints from h5 file to avoid hard coding - with st.sidebar: - st.image("static/images/supertopview.png") - with st.sidebar: - st.write("Keypoints:") - st.write(AnimalBehaviorAnalysis.get_bodypart_names()) - - render_messages() - - AmadeusLogger.log_process_memory(log_position=f"after_display_chats_{example}") - gc.collect() - AmadeusLogger.log_process_memory(log_position=f"after_garbage_collection_{example}") - - -def get_history_chat(chat_time): - csv_file = glob.glob(os.path.join(LOG_DIR, chat_time, "*.csv"))[0] - df = pd.read_csv(csv_file) - return df - - -def get_example_history_chat(example): - if example == "": - return None, None - csv_files = glob.glob(os.path.join("examples", example, "example.csv")) - if len(csv_files) > 0: - csv_file = csv_files[0] - df = pd.read_csv(csv_file) - return csv_file, df - else: - return None, None - - -def save_figure_to_tempfile(fig): - # save the figure - folder_path = os.path.join(st.session_state["log_folder"], "tmp_imgs") - if not os.path.exists(folder_path): - os.makedirs(folder_path) - # Generate a unique temporary filename in the specified folder - temp_file = tempfile.NamedTemporaryFile( - dir=folder_path, suffix=".png", delete=False - ) - filename = temp_file.name - temp_file.close() - fig.savefig( 
-        filename,
-        format="png",
-        bbox_inches="tight",
-        pad_inches=0.0,
-        dpi=400,
-        transparent=True,
-    )
-    return filename
-
-
-def make_plot_pretty4dark_mode(fig, ax):
-    fig = plt.gcf()
-    fig.set_facecolor("none")
-    ax = plt.gca()
-    ax.set_facecolor("none")
-    # Set axes and legend colors to white or other light colors
-    ax.spines["bottom"].set_color("white")
-    ax.spines["top"].set_color("white")
-    ax.spines["right"].set_color("white")
-    ax.spines["left"].set_color("white")
-
-    ax.xaxis.label.set_color("white")
-    ax.yaxis.label.set_color("white")
-    ax.title.set_color("white")
-    ax.tick_params(axis="x", colors="white")
-    ax.tick_params(axis="y", colors="white")
-    legend = plt.legend()
-    for text in legend.get_texts():
-        text.set_color("white")
-
-    return fig, ax
-
-
-def display_image(temp_file):
-    full_image = Image.open(temp_file)
-    st.image(full_image)
-
-
-def display_temp_text(text_content):
-    # Convert the text content to base64
-    text_bytes = text_content.encode("utf-8")
-    text_base64 = base64.b64encode(text_bytes).decode()
-    # Display the link to the text file
-    st.markdown(
-        f'<a href="data:text/plain;base64,{text_base64}">Check error.</a>',
-        unsafe_allow_html=True,
-    )
-
-
-def style_button_row(clicked_button_ix, n_buttons):
-    def get_button_indices(button_ix):
-        return {"nth_child": button_ix, "nth_last_child": n_buttons - button_ix + 1}
-
-    clicked_style = """
-    div[data-testid*="stHorizontalBlock"] > div:nth-child(%(nth_child)s):nth-last-child(%(nth_last_child)s) button {
-        border-color: rgb(255, 75, 75);
-        color: rgb(255, 75, 75);
-        box-shadow: rgba(255, 75, 75, 0.5) 0px 0px 0px 0.2rem;
-        outline: currentcolor none medium;
-    }
-    """
-    unclicked_style = """
-    div[data-testid*="stHorizontalBlock"] > div:nth-child(%(nth_child)s):nth-last-child(%(nth_last_child)s) button {
-        pointer-events: none;
-        cursor: not-allowed;
-        opacity: 0.65;
-        filter: alpha(opacity=65);
-        -webkit-box-shadow: none;
-        box-shadow: none;
-    }
-    """
-    style = ""
-    for ix in range(n_buttons):
-        ix += 1
-        if ix == clicked_button_ix:
-            style += clicked_style % get_button_indices(ix)
-        else:
-            style += unclicked_style % get_button_indices(ix)
-    st.markdown(f"<style>{style}</style>", unsafe_allow_html=True)

From 7a8d7103c8f7bba9774ca51ff45567ddbef607d6 Mon Sep 17 00:00:00 2001
From: Mackenzie Mathis
Date: Sat, 16 Dec 2023 12:54:02 -0800
Subject: [PATCH 5/8] include app and app_utils

---
 setup.cfg | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/setup.cfg b/setup.cfg
index c204778..662a1d2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -63,7 +63,12 @@ streamlit =
     streamlit-profiler

 [options.package_data]
-amadeusgpt = interface.txt
+amadeusgpt = interface.txt, app.py, app_utils.py
+
+
+[options.entry_points]
+console_scripts =
+    launch_amadeusGPT = amadeusgpt.app:main

 [bdist_wheel]
 universal=1

From 3a403a5f97a628d5da0df92b3498657243d042af Mon Sep 17 00:00:00 2001
From: Mackenzie Mathis
Date: Sat, 16 Dec 2023 12:57:58 -0800
Subject: [PATCH 6/8] Update app.py - title fix

---
 amadeusgpt/app.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/amadeusgpt/app.py b/amadeusgpt/app.py
index 6818a4f..cf6e380 100644
--- a/amadeusgpt/app.py
+++ b/amadeusgpt/app.py
@@ -11,7 +11,7 @@

 def main():
     #import app_utils
-    st.title("Your Streamlit App")
+    st.title("AmadeusGPT")

@@ -325,4 +325,4 @@ def welcome_page(text):
 #             key="chat_download",
 #         )
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()

From 2c18712ea78e8b8f703586590fade710bd465c5b Mon Sep 17 00:00:00 2001
From: MacKenzie Mathis
Date: Sat, 16 Dec 2023 13:30:31
-0800 Subject: [PATCH 7/8] app edits --- amadeusgpt/app.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/amadeusgpt/app.py b/amadeusgpt/app.py index cf6e380..e86dba1 100644 --- a/amadeusgpt/app.py +++ b/amadeusgpt/app.py @@ -8,11 +8,23 @@ import requests from amadeusgpt import app_utils +st._is_running_with_streamlit = True +os.environ["streamlit_app"] = "True" +assert "streamlit_app" in os.environ, "The 'streamlit_app' environment variable is not set." + + def main(): - #import app_utils st.title("AmadeusGPT") + from amadeusgpt.utils import validate_openai_api_key + import time + from streamlit_profiler import Profiler + + # Initialize session state variables if not present + if "exist_valid_openai_api_key" not in st.session_state: + st.session_state["exist_valid_openai_api_key"] = False + def fetch_user_headers(): """Fetch user and email info from HTTP headers. @@ -62,9 +74,6 @@ def fetch_user_info(): if f"database" not in st.session_state: st.session_state[f"database"] = defaultdict(dict) - from amadeusgpt.utils import validate_openai_api_key - import time - from streamlit_profiler import Profiler # TITLE PANEL st.set_page_config(layout="wide") @@ -314,15 +323,6 @@ def welcome_page(text): AmadeusLogger.store_chats("errors", str(e) + "\n" + traceback.format_exc()) AmadeusLogger.debug(traceback.format_exc()) - # with st.sidebar as sb: - # if "chat_history" in st.session_state and 'creation_time' in st.session_state: - # if example_bar != "Welcome": - # st.download_button( - # label="Download current chat", - # data=st.session_state["chat_history"], - # file_name=f"conversations_{st.session_state['creation_time']}.csv", - # mime="text/csv", - # key="chat_download", - # ) + if __name__ == "__main__": main() From 0f54c12e38440e0f4116197bb02428e40f6af3ea Mon Sep 17 00:00:00 2001 From: MacKenzie Mathis Date: Sat, 16 Dec 2023 14:36:16 -0800 Subject: [PATCH 8/8] closer --- amadeusgpt/app.py | 65 +++++------------------------------------------ 1 file changed, 6 insertions(+), 59 deletions(-) diff --git a/amadeusgpt/app.py b/amadeusgpt/app.py index e86dba1..34ebbce 100644 --- a/amadeusgpt/app.py +++ b/amadeusgpt/app.py @@ -7,81 +7,28 @@ from datetime import datetime import requests from amadeusgpt import app_utils +from amadeusgpt.utils import validate_openai_api_key st._is_running_with_streamlit = True os.environ["streamlit_app"] = "True" -assert "streamlit_app" in os.environ, "The 'streamlit_app' environment variable is not set." - +assert "streamlit_app" in os.environ, "The 'streamlit_app' environment variable is not set!" +# Initialize session state variables if not present +if "exist_valid_openai_api_key" not in st.session_state: + st.session_state["exist_valid_openai_api_key"] = False def main(): + subprocess.run(["./launch_app"], check=True) st.title("AmadeusGPT") - from amadeusgpt.utils import validate_openai_api_key import time from streamlit_profiler import Profiler - # Initialize session state variables if not present - if "exist_valid_openai_api_key" not in st.session_state: - st.session_state["exist_valid_openai_api_key"] = False - - - def fetch_user_headers(): - """Fetch user and email info from HTTP headers. - - Output of this function is identical to querying - https://amadeusgpt.kinematik.ai/oauth2/userinfo, but - works from within the streamlit app. - """ - # TODO(stes): This could change without warning n future streamlit - # versions. 
So I'll leave the import here in case sth should go - # wrong in the future - from streamlit.web.server.websocket_headers import _get_websocket_headers - - headers = _get_websocket_headers() - AmadeusLogger.debug(f"Received Headers: {headers}") - return dict( - email=headers.get("X-Forwarded-Email", "no_email_in_header"), - user=headers.get("X-Forwarded-User", "no_user_in_header"), - ) - - - def fetch_user_info(): - url = "https://amadeusgpt.kinematik.ai/oauth2/userinfo" - try: - return fetch_user_headers() - # TODO(stes): Lets be on the safe side for now. - except Exception as e: - AmadeusLogger.info(f"Error: {e}") - return None - - - if "streamlit_app" in os.environ: - if "session_id" not in st.session_state: - session_id = str(uuid.uuid4()) - st.session_state["session_id"] = session_id - user_info = fetch_user_info() - if user_info: - st.session_state["username"] = user_info.get("user", "fake_username") - st.session_state["email"] = user_info.get("email", "fake_email") - else: - AmadeusLogger.info("Getting None from the endpoint") - st.session_state["username"] = "no_username" - st.session_state["email"] = "no_email" - - AmadeusLogger.debug("A new user logs in ") - - if f"database" not in st.session_state: - st.session_state[f"database"] = defaultdict(dict) - - # TITLE PANEL st.set_page_config(layout="wide") app_utils.load_css("static/styles/style.css") - assert "streamlit_app" in os.environ - ###### Initialize ###### if "amadeus" not in st.session_state: st.session_state["amadeus"] = app_utils.summon_the_beast()[0]
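
Note on launching: PATCH 5/8 wires the launch_amadeusGPT console script to amadeusgpt.app:main, while PATCH 8/8 has main() delegate to a ./launch_app helper that is not included in this series. A Streamlit page script only renders when executed via "streamlit run", so the helper presumably re-execs the packaged app.py under Streamlit. Below is a minimal sketch of what such a wrapper could look like; the ./launch_app name comes from PATCH 8/8, but the use of "python -m streamlit run" and the path resolution are assumptions, not code from these patches.

    #!/usr/bin/env python
    # Hypothetical stand-in for ./launch_app; not part of this patch series.
    import os
    import subprocess
    import sys

    import amadeusgpt


    def launch():
        # Locate the app.py shipped inside the installed package
        # (declared under [options.package_data] in PATCH 5/8).
        app_path = os.path.join(os.path.dirname(amadeusgpt.__file__), "app.py")
        # "streamlit run" starts the server and executes app.py as the page
        # script; calling main() in a bare Python process would render nothing.
        subprocess.run(
            [sys.executable, "-m", "streamlit", "run", app_path],
            check=True,  # mirrors the check=True used in PATCH 8/8
        )


    if __name__ == "__main__":
        launch()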