Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export streamlit_app=True
app:

streamlit run amadeusgpt/app.py --server.fileWatcherType none
streamlit run amadeusgpt/app.py --server.fileWatcherType none --server.maxUploadSize 1000
2 changes: 0 additions & 2 deletions amadeusgpt/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,6 @@ def update_df_data(new_item, index_to_update, df, csv_file):
return csv_file, df


@st.cache_data(persist="disk")
def get_scene_image(example):
if AnimalBehaviorAnalysis.get_video_file_path() is not None:
scene_image = Scene.get_scene_frame()
Expand All @@ -599,7 +598,6 @@ def get_scene_image(example):
return get_scene_image(example)


@st.cache_data(persist="disk")
def get_sam_image(example):
if AnimalBehaviorAnalysis.get_video_file_path():
seg_objects = AnimalBehaviorAnalysis.get_seg_objects()
Expand Down
38 changes: 30 additions & 8 deletions amadeusgpt/implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
scene_frame_number = 0



class Database:
"""
A singleton that stores all data. Should be easy to integrate with a Nonsql database
Expand Down Expand Up @@ -1218,9 +1219,19 @@ def get_objects(self, video_file_path):
else:
return self.pickledata




class AnimalBehaviorAnalysis:
"""
This class holds methods and objects that are useful for analyzing animal behavior.
It no longer holds the states of objects directly. Instead, it references to the Database
singleton object. This is to make the class more stateless and easier to use in a web app.
"""

# to be deprecated
task_programs = {}
# to be deprecated
task_program_results = {}
# if a function has a parameter, it assumes the result_buffer has it
# special dataset flags set to be False
Expand Down Expand Up @@ -1251,7 +1262,7 @@ def result_buffer(cls):
@classmethod
def release_cache_objects(cls):
"""
For web app, switching from one example to the other requires a release of cached objects
        For the web app, switching from one example to another requires a release of cached objects
"""
if Database.exist(cls.__name__, "animal_objects"):
Database.delete(cls.__name__, "animal_objects")
Expand Down Expand Up @@ -1914,21 +1925,32 @@ def reject_outlier_keypoints(cls, keypoints, threshold_in_stds=2):
return temp

@classmethod
def ast_fillna_2d(cls, arr):
def ast_fillna_2d(cls, arr: np.ndarray) -> np.ndarray:
"""
Fills NaN values in a 4D keypoints array using linear interpolation.

Parameters:
arr (np.ndarray): A 4D numpy array of shape (n_frames, n_individuals, n_kpts, n_dims).

Returns:
np.ndarray: The 4D array with NaN values filled.
"""
n_frames, n_individuals, n_kpts, n_dims = arr.shape
arr_reshaped = arr.reshape(n_frames, -1)
x = np.arange(n_frames)
for i in range(arr_reshaped.shape[1]):
valid_mask = ~np.isnan(arr_reshaped[:, i])
if np.all(valid_mask):
continue
arr_reshaped[:, i] = np.interp(
x, x[valid_mask], arr_reshaped[valid_mask, i]
)
# Reshape the array back to 4D
arr = arr_reshaped.reshape(n_frames, n_individuals, n_kpts, n_dims)
elif np.any(valid_mask):
# Perform interpolation when there are some valid points
arr_reshaped[:, i] = np.interp(x, x[valid_mask], arr_reshaped[valid_mask, i])
else:
# Handle the case where all values are NaN
# Replace with a default value or another suitable handling
arr_reshaped[:, i].fill(0) # Example: filling with 0

return arr
return arr_reshaped.reshape(n_frames, n_individuals, n_kpts, n_dims)

@classmethod
@timer_decorator
Expand Down
42 changes: 28 additions & 14 deletions amadeusgpt/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,11 @@ class AMADEUS:
code_generator_brain.enforce_prompt = ""
usage = 0
behavior_modules_in_context = True
# to save the behavior module strings for context window use
# load the integration modules to context
smart_loading = True
# number of topk integration modules to load
load_module_top_k = 3
module_threshold = 0.7
context_window_dict = {}
plot = False
use_rephraser = True
Expand All @@ -124,6 +126,7 @@ def release_cache_objects(cls):

@classmethod
def load_module_smartly(cls, user_input):
# TODO: need to improve the module matching by vector database
sorted_query_results = match_module(user_input)
if len(sorted_query_results) == 0:
return None
Expand All @@ -134,7 +137,7 @@ def load_module_smartly(cls, user_input):
query_module = query_result[0]
query_score = query_result[1][0][0]

if query_score > 0.7:
if query_score > cls.module_threshold:
modules.append(query_module)
# parse the query result by loading active loading
module_path = os.sep.join(query_module.split(os.sep)[-2:]).replace(
Expand All @@ -158,7 +161,7 @@ def magic_command(cls, user_input):
AmadeusLogger.info(result.stdout.decode("utf-8"))

@classmethod
def save_state(cls):
def save_state(cls, output_path = 'soul.pickle'):
# save the class attributes of all classes that are under state_list.
def get_class_variables(_class):
return {
Expand All @@ -171,15 +174,14 @@ def get_class_variables(_class):

state = {k.__name__: get_class_variables(k) for k in cls.state_list}

output_filename = "soul.pickle"
with open(output_filename, "wb") as f:
with open(output_path, "wb") as f:
pickle.dump(state, f)
AmadeusLogger.info(f"memory saved to {output_filename}")
AmadeusLogger.info(f"memory saved to {output_path}")

@classmethod
def load_state(cls):
def load_state(cls, ckpt_path = 'soul.pickle'):
# load the class variables into 3 class
memory_filename = "soul.pickle"
memory_filename = ckpt_path
AmadeusLogger.info(f"loading memory from {memory_filename}")
with open(memory_filename, "rb") as f:
state = pickle.load(f)
Expand Down Expand Up @@ -296,6 +298,7 @@ def chat(
cls.interface_str, cls.behavior_modules_str
)
cls.code_generator_brain.update_history("user", rephrased_user_msg)

response = cls.code_generator_brain.connect_gpt(
cls.code_generator_brain.context_window, max_tokens=700, functions=functions
)
Expand All @@ -307,10 +310,12 @@ def chat(
thought_process,
) = cls.code_generator_brain.parse_openai_response(response)

# write down the task program for offline processing
with open("temp_for_debug.json", "w") as f:
out = {'function_code': function_code,
'query': rephrased_user_msg}
json.dump(out, f, indent=4)

# handle_function_codes gives the answer with function outputs
amadeus_answer = cls.core_loop(
rephrased_user_msg, text, function_code, thought_process
Expand All @@ -321,16 +326,19 @@ def chat(
original_user_msg, amadeus_answer.function_code, code_output
)

        # if there is an error or the function code is empty, we want to make sure we prevent ChatGPT from learning to output nothing via few-shot learning
# is this used anymore?

# Could be used for in context feedback learning. Costly
if amadeus_answer.has_error:
cls.code_generator_brain.context_window[-1][
"content"
] += "\n While executing the code above, there was error so it is not correct answer\n"

elif amadeus_answer.has_error:
cls.code_generator_brain.context_window.pop()
cls.code_generator_brain.history.pop()

        # if there is an error or the function code is empty, we want to make sure we prevent ChatGPT from learning to output nothing via few-shot learning
#elif amadeus_answer.has_error:
# cls.code_generator_brain.context_window.pop()
# cls.code_generator_brain.history.pop()

else:
# needs to manage memory of Amadeus for context window management and state restore etc.
# we have it remember user's original question instead of the rephrased one for better
Expand All @@ -351,10 +359,13 @@ def execute_python_function(
exec(function_code, globals())
if "task_program" not in globals():
return None

# TODO: to serialize and support different function arguments
func_sigs = inspect.signature(task_program)
if not func_sigs.parameters:
result = task_program()
else:
# TODO: We don't do this anymore. But in the future, Is passing result buffer from each function sustainable?
result_buffer = AnimalBehaviorAnalysis.result_buffer
AmadeusLogger.info(f"result_buffer: {result_buffer}")
if isinstance(result_buffer, tuple):
Expand All @@ -368,9 +379,11 @@ def execute_python_function(
@classmethod
def contribute(cls, program_name):
"""
Deprecated
Takes the program from the task program registry and write it into contribution folder
TODO: split the task program into implementation and api
"""
"""

AmadeusLogger.info(f"contributing {program_name}")
task_program = AnimalBehaviorAnalysis.task_programs[program_name]
# removing add_symbol or add_task_program line
Expand All @@ -390,6 +403,7 @@ def update_behavior_modules_str(cls):
Called during loading behavior modules from disk or when task program is updated
"""
modules_str = []
# context_window_dict is where integration modules are stored in current AMADEUS class
for name, task_program in cls.context_window_dict.items():
modules_str.append(task_program)
modules_str = modules_str[-cls.load_module_top_k :]
Expand Down