diff --git a/source/data labeler/main.py b/source/data labeler/labeler.py similarity index 99% rename from source/data labeler/main.py rename to source/data labeler/labeler.py index 7b2d15e4f..92858f8d5 100644 --- a/source/data labeler/main.py +++ b/source/data labeler/labeler.py @@ -21,7 +21,6 @@ gesture_vector = [] - def main(): """Main driver method which initilizes all children and starts pygame render pipeline""" @@ -34,7 +33,7 @@ def main(): hands_surface.set_colorkey((0, 0, 0)) myRenderHands = RenderHands(hands_surface, 3) - filename = "wave.csv" + filename = "test.csv" myReader = Reader(filename) gesture_list = [ diff --git a/source/gesture recognition/GetHands.py b/source/data recorder/RecordHands.py similarity index 59% rename from source/gesture recognition/GetHands.py rename to source/data recorder/RecordHands.py index e55cdbeff..a56f08458 100644 --- a/source/gesture recognition/GetHands.py +++ b/source/data recorder/RecordHands.py @@ -1,16 +1,18 @@ # https://github.com/nrsyed/computer-vision/blob/master/multithread/VideoShow.py from threading import Thread -import cv2 import mediapipe as mp import time -import numpy as np import math -from FeedForward import NeuralNet import traceback -from Console import GestureConsole +from Webcam import Webcam -class GetHands: +import os +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) + +class RecordHands: """ Class that continuously gets frames and extracts hand data with a dedicated thread and Mediapipe @@ -22,14 +24,11 @@ def __init__( surface=None, show_window=False, confidence=0.5, - webcam_id=0, model_path="hand_landmarker.task", - control_mouse=None, write_csv=None, gesture_list=None, gesture_confidence=0.50, flags=None, - keyboard=None, ): """Builds a Mediapipe hand model and a PyTorch gesture recognition model @@ -55,30 +54,23 @@ def __init__( self.render_hands = render_hands self.confidence = confidence self.stopped = False - self.last_origin = [(0, 0)] - self.control_mouse = control_mouse self.write_csv = write_csv self.gesture_vector = flags["gesture_vector"] self.gesture_list = gesture_list self.gesture_confidence = gesture_confidence self.flags = flags self.sensitinity = 0.05 - self.keyboard = keyboard - self.console = GestureConsole() + self.last_origin = [(0, 0)] + self.camera = Webcam() - self.gesture_model = NeuralNet("SimpleModel.pth") - # OpenCV setup - self.stream = cv2.VideoCapture(webcam_id) - # motion JPG format - self.stream.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc("M", "J", "P", "G")) - (self.grabbed, self.frame) = self.stream.read() - self.frame = cv2.flip(self.frame, 1) + (self.grabbed, self.frame) = self.camera.read() + self.last_timestamp = mp.Timestamp.from_seconds(time.time()).value self.timer1 = 0 self.timer2 = 0 - self.build_model(flags["number_of_hands"]) + self.build_model(1) def build_model(self, hands_num): """Takes in option parameters for the Mediapipe hands model @@ -104,84 +96,6 @@ def build_model(self, hands_num): # build hands model self.hands_detector = self.HandLandmarker.create_from_options(self.options) - def gesture_input(self, result, velocity): - """Converts Mediapipe landmarks and a velocity into a format usable by the gesture recognition model - - Args: - result (Mediapipe.hands.result): The result object returned by Mediapipe - velocity ([(float, float)]): An array of tuples containing the velocity of hands - - Returns: - array: An array of length 65 - """ - model_inputs = [] - - for index, hand in enumerate(result.hand_world_landmarks): - model_inputs.append([]) - for point in hand: - model_inputs[index].append(point.x) - model_inputs[index].append(point.y) - model_inputs[index].append(point.z) - if velocity != []: - model_inputs[index].append(velocity[index][0]) - model_inputs[index].append(velocity[index][1]) - - out = [] - for input in model_inputs: - out.append(np.array([input], dtype="float32")) - - return out - - def find_velocity_and_location(self, result): - """Given a Mediapipe result object, calculates the velocity and origin of hands. - - Args: - result (Mediapipe.hands.result): Direct output object from Mediapipe hands model - - Returns: - (origins, velocity): A tuple containing an array of tuples representing hand origins, and an array of tuples containing hand velocitys - """ - - normalized_origin_offset = [] - hands_location_on_screen = [] - velocity = [] - - for hand in result.hand_world_landmarks: - # take middle finger knuckle - normalized_origin_offset.append(hand[9]) - - for index, hand in enumerate(result.hand_landmarks): - originX = hand[9].x - normalized_origin_offset[index].x - originY = hand[9].y - normalized_origin_offset[index].y - originZ = hand[9].z - normalized_origin_offset[index].z - hands_location_on_screen.append((originX, originY, originZ)) - velocityX = self.last_origin[index][0] - hands_location_on_screen[index][0] - velocityY = self.last_origin[index][1] - hands_location_on_screen[index][1] - velocity.append((velocityX, velocityY)) - self.last_origin = hands_location_on_screen - - return hands_location_on_screen, velocity - - def move_mouse(self, hands_location_on_screen, mouse_button_text): - """Wrapper method to control the mouse - - Args: - hands_location_on_screen (origins): The origins result from find_velocity_and_location() - mouse_button_text (str): Type of click - """ - if callable(self.control_mouse): - if hands_location_on_screen != []: - # (0,0) is the top left corner - self.control_mouse( - hands_location_on_screen[0][0], - hands_location_on_screen[0][1], - mouse_button_text, - ) - - def reset_gesture_vector(self): - for i in range(len(self.gesture_vector) - 1): - self.gesture_vector[i] = "0" - def results_callback( self, result: mp.tasks.vision.HandLandmarkerResult, @@ -190,47 +104,7 @@ def results_callback( ): # this try catch block is for debuggin. this code runs in a different thread and doesn't automatically raise its own exceptions try: - hands_location_on_screen, velocity = self.find_velocity_and_location(result) - - self.reset_gesture_vector() - - if self.flags["run_model_flag"]: - model_input = self.gesture_input(result, velocity) - - table = [] - - for index, hand in enumerate(model_input): - - row = [] - - row.append(str(index)) - - self.reset_gesture_vector() - confidence, gesture = self.gesture_model.get_gesture(hand) - - self.gesture_vector[gesture[0]] = "1" - - row.append(str(f"{confidence[0]:.3f}")) - row.append(self.gesture_list[gesture[0]]) - - if index == 0: - if gesture[0] == 0: - self.keyboard.press("space") - if gesture[0] == 1: - self.keyboard.press("none") - if gesture[0] == 2: - self.keyboard.press("toggle") - - table.append(row) - - self.console.generate_table(table) - - mouse_button_text = "" - if self.flags["move_mouse_flag"] and hands_location_on_screen != []: - hand = result.hand_world_landmarks[0] - if self.is_clicking(hand[8], hand[4]): - mouse_button_text = "left" - self.move_mouse(hands_location_on_screen, mouse_button_text) + location, velocity = self.find_velocity_and_location(result) # write to CSV # flag for writing is saved in the last index of this vector @@ -249,25 +123,14 @@ def results_callback( output_image, (total_delay, hands_delay), self.surface, - self.flags["render_hands_mode"], - hands_location_on_screen, + location, velocity, - mouse_button_text, ) except Exception as e: traceback.print_exc() quit() - def is_clicking(self, tip1, tip2): - distance = math.sqrt( - (tip1.x - tip2.x) ** 2 + (tip1.y - tip2.y) ** 2 + (tip1.z - tip2.z) ** 2 - ) - if distance < self.sensitinity: - return True - else: - return False - def start(self): Thread(target=self.run, args=()).start() return self @@ -285,17 +148,13 @@ def run(self): """Continuously grabs new frames from the webcam and uses Mediapipe to detect hands""" while not self.stopped: if not self.grabbed: - self.stream.release() - cv2.destroyAllWindows() + self.camera.stop() self.stop() else: - (self.grabbed, self.frame) = self.stream.read() - self.frame = cv2.flip(self.frame, 1) + (self.grabbed, self.frame) = self.camera.read() # Detect hand landmarks self.detect_hands(self.frame) - if self.show_window: - self.show() def detect_hands(self, frame): """Wrapper function for Mediapipe's hand detector in livestream mode @@ -309,12 +168,35 @@ def detect_hands(self, frame): ) self.timer1 = mp.Timestamp.from_seconds(time.time()).value - def show(self): - """Displays another window with the raw webcam stream""" - cv2.imshow("Video", self.frame) - if cv2.waitKey(1) == ord("q"): - self.stopped = True - cv2.destroyAllWindows() - def stop(self): self.stopped = True + + def find_velocity_and_location(self, result): + """Given a Mediapipe result object, calculates the velocity and origin of hands. + + Args: + result (Mediapipe.hands.result): Direct output object from Mediapipe hands model + + Returns: + (origins, velocity): A tuple containing an array of tuples representing hand origins, and an array of tuples containing hand velocitys + """ + + normalized_origin_offset = [] + hands_location_on_screen = [] + velocity = [] + + for hand in result.hand_world_landmarks: + # take middle finger knuckle + normalized_origin_offset.append(hand[9]) + + for index, hand in enumerate(result.hand_landmarks): + originX = hand[9].x - normalized_origin_offset[index].x + originY = hand[9].y - normalized_origin_offset[index].y + originZ = hand[9].z - normalized_origin_offset[index].z + hands_location_on_screen.append((originX, originY, originZ)) + velocityX = self.last_origin[index][0] - hands_location_on_screen[index][0] + velocityY = self.last_origin[index][1] - hands_location_on_screen[index][1] + velocity.append((velocityX, velocityY)) + self.last_origin = hands_location_on_screen + + return hands_location_on_screen, velocity diff --git a/source/gesture recognition/RenderHands.py b/source/data recorder/RenderHands.py similarity index 79% rename from source/gesture recognition/RenderHands.py rename to source/data recorder/RenderHands.py index e566dc62a..1953d9313 100644 --- a/source/gesture recognition/RenderHands.py +++ b/source/data recorder/RenderHands.py @@ -1,6 +1,5 @@ import pygame - class RenderHands: """Given the Mediapipe hands output data, renders the hands in a normilzed view or camera perspective view on a pygame surface""" @@ -16,7 +15,7 @@ def __init__(self, surface, hand_scale): self.font = pygame.font.Font("freesansbold.ttf", 30) self.last_velocity = [(0.5, 0.5)] - def connections(self, landmarks, mode): + def connections(self, landmarks, ): """Renders lines between hand joints Args: @@ -31,20 +30,13 @@ def connections(self, landmarks, mode): for index, hand in enumerate(landmarks): xy.append([]) for point in hand: - if mode: - xy[index].append( - ( - (point.x * self.hand_scale + 0.5) * w, - (point.y * self.hand_scale + 0.5) * h, - ) - ) - else: - xy[index].append( - ( - point.x * w, - point.y * h, - ) + xy[index].append( + ( + (point.x * self.hand_scale + 0.5) * w, + (point.y * self.hand_scale + 0.5) * h, ) + ) + for hand in range(len(xy)): # thumb @@ -85,7 +77,7 @@ def render_line(self, start, end): pygame.draw.line(self.surface, (255, 255, 255), start, end, 5) def render_hands( - self, result, output_image, delay_ms, surface, mode, origins, velocity, pinch + self, result, output_image, delay_ms, surface, origins, velocity ): """ Renders the hands and other associated data from Mediapipe onto a pygame surface. @@ -103,28 +95,21 @@ def render_hands( surface.fill((0, 0, 0)) # Render hand landmarks # print(delay_ms) - if pinch != "": - text = self.font.render(pinch, False, (255, 255, 255)) - surface.blit(text, (0, 90)) - w, h = surface.get_size() if result.handedness != []: - if mode: - hand_points = result.hand_world_landmarks - pygame.draw.circle(surface, (255, 0, 255), (0.5 * w, 0.5 * h), 5) - pygame.draw.line( - self.surface, - (255, 255, 0), - ((velocity[0][0] + 0.5) * w, (velocity[0][1] + 0.5) * h), - ((0.5) * w, (0.5) * h), - 3, - ) - self.last_velocity = velocity - - else: - hand_points = result.hand_landmarks - self.connections(hand_points, mode) + hand_points = result.hand_world_landmarks + pygame.draw.circle(surface, (255, 0, 255), (0.5 * w, 0.5 * h), 5) + pygame.draw.line( + self.surface, + (255, 255, 0), + ((velocity[0][0] + 0.5) * w, (velocity[0][1] + 0.5) * h), + ((0.5) * w, (0.5) * h), + 3, + ) + self.last_velocity = velocity + + self.connections(hand_points) if hand_points: # define colors for different hands @@ -146,11 +131,10 @@ def render_hands( landmark.y, surface, delay_ms, - mode, ) hand_color += 1 - def __render_hands_pygame(self, color, x, y, surface, delay_ms, mode): + def __render_hands_pygame(self, color, x, y, surface, delay_ms): """Renders a single landmark of a hand in pygame and scales the hand. Args: @@ -164,11 +148,10 @@ def __render_hands_pygame(self, color, x, y, surface, delay_ms, mode): w, h = self.surface.get_size() - if mode: - x *= self.hand_scale - y *= self.hand_scale - x += 0.5 - y += 0.5 + x *= self.hand_scale + y *= self.hand_scale + x += 0.5 + y += 0.5 pygame.draw.circle(surface, color, (x * w, y * h), 5) delay_cam = self.font.render( diff --git a/source/data recorder/Webcam.py b/source/data recorder/Webcam.py new file mode 100644 index 000000000..92140e902 --- /dev/null +++ b/source/data recorder/Webcam.py @@ -0,0 +1,28 @@ +# https://github.com/nrsyed/computer-vision/blob/master/multithread/VideoShow.py + +import cv2 +class Webcam: + """ + Class that continuously gets frames and extracts hand data + with a dedicated thread and Mediapipe + """ + def __init__( + self, + webcam_id=0, + ): + # OpenCV setup + self.stream = cv2.VideoCapture(webcam_id) + # motion JPG format + self.stream.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc("M", "J", "P", "G")) + (self.grabbed, self.frame) = self.stream.read() + self.frame = cv2.flip(self.frame, 1) + + def stop(self): + self.stream.release() + cv2.destroyAllWindows() + + def read(self, flip=True): + (self.grabbed, self.frame) = self.stream.read() + if flip: + self.frame = cv2.flip(self.frame, 1) + return (self.grabbed, self.frame) \ No newline at end of file diff --git a/source/gesture recognition/Writer.py b/source/data recorder/Writer.py similarity index 100% rename from source/gesture recognition/Writer.py rename to source/data recorder/Writer.py diff --git a/source/gesture recognition/hand_landmarker.task b/source/data recorder/hand_landmarker.task similarity index 100% rename from source/gesture recognition/hand_landmarker.task rename to source/data recorder/hand_landmarker.task diff --git a/source/gesture recognition/main.py b/source/data recorder/recordData.py similarity index 79% rename from source/gesture recognition/main.py rename to source/data recorder/recordData.py index f1cd48c6a..f335dfecc 100644 --- a/source/gesture recognition/main.py +++ b/source/data recorder/recordData.py @@ -1,16 +1,12 @@ # https://developers.google.com/mediapipe/framework/getting_started/gpu_support # https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html # https://pygame-menu.readthedocs.io/en/latest/_source/add_widgets.html -import pydoc import pygame import pygame_menu -from GetHands import GetHands +from RecordHands import RecordHands from RenderHands import RenderHands -from Mouse import Mouse from Writer import Writer -from Keyboard import Keyboard import os -from Console import GestureConsole abspath = os.path.abspath(__file__) dname = os.path.dirname(abspath) @@ -22,15 +18,10 @@ clock = pygame.time.Clock() flags = { - "render_hands_mode": True, "gesture_vector": [], - "number_of_hands": 2, - "move_mouse_flag": False, - "run_model_flag": False, + "number_of_hands": 1, } -console = GestureConsole() - def main() -> None: """Main driver method which initilizes all children and starts pygame render pipeline""" @@ -46,10 +37,6 @@ def main() -> None: gesture_list = ["fist", "palm", "pinky"] - myWriter = Writer(gesture_list=gesture_list, write_labels=True) - - mouse_controls = Mouse(mouse_scale=2) - gesture_menu_selection = [] for index, gesture in enumerate(gesture_list): @@ -58,21 +45,16 @@ def main() -> None: flags["gesture_vector"].append(False) - keyboard = Keyboard( - threshold=0, toggle_key_threshold=0.3, toggle_mouse_func=toggle_mouse - ) - + myWriter = Writer(gesture_list=gesture_list, write_labels=True) # control_mouse=mouse_controls.control, - hands = GetHands( + hands = RecordHands( myRenderHands.render_hands, show_window=True, surface=hands_surface, confidence=0.5, - control_mouse=mouse_controls.control, write_csv=myWriter.write, gesture_list=gesture_list, flags=flags, - keyboard=keyboard, ) menu = pygame_menu.Menu( @@ -81,18 +63,22 @@ def main() -> None: window_height * 0.8, theme=pygame_menu.themes.THEME_BLUE, ) + + def change_write(value, tuple): + (write_status, writer) = tuple + writer.write_labels=write_status menu.add.selector( - "Render Mode :", [("Normalized", True), ("World", False)], onchange=set_coords + "Labels :", [("From Selection", (True, myWriter)), ("Don't save labels", (False, myWriter))], onchange=change_write ) + + menu.add.dropselect( "Gesture :", gesture_menu_selection, onchange=set_current_gesture ) menu.add.button("Close Menu", pygame_menu.events.CLOSE) - menu.add.button("Turn On Model", action=toggle_model) - menu.add.button("Turn On Mouse", action=toggle_mouse) menu.add.button("Quit", pygame_menu.events.EXIT) menu.enable() @@ -112,17 +98,6 @@ def main() -> None: pygame.quit() -def toggle_mouse() -> None: - """Enable or disable mouse control""" - console.print("toggling mouse control") - flags["move_mouse_flag"] = not flags["move_mouse_flag"] - - -def toggle_model() -> None: - console.print("toggling model") - flags["run_model_flag"] = not flags["run_model_flag"] - - def set_write_status() -> None: """Tell the the writer class to write data""" flags["gesture_vector"][len(flags["gesture_vector"]) - 1] = not flags[ @@ -130,16 +105,6 @@ def set_write_status() -> None: ][len(flags["gesture_vector"]) - 1] -def set_coords(value, mode) -> None: - """Defines the coordinate space for rendering hands - - Args: - value (_type_): used by pygame_menu - mode (_type_): True for normalized, False for world - """ - flags["render_hands_mode"] = mode - - def set_current_gesture(value, index) -> None: """Define the current gesutre of a matching gesture list @@ -196,8 +161,6 @@ def game_loop( set_write_status() - if event.key == pygame.K_m: - toggle_mouse() if menu.is_enabled(): menu.update(events) diff --git a/source/gesture recognition/Console.py b/source/gesture inference/Console.py similarity index 73% rename from source/gesture recognition/Console.py rename to source/gesture inference/Console.py index 54c227f05..1a152a307 100644 --- a/source/gesture recognition/Console.py +++ b/source/gesture inference/Console.py @@ -1,4 +1,3 @@ - from rich.live import Live from rich.table import Table from rich.layout import Layout @@ -7,10 +6,11 @@ import os class GestureConsole: - #make this class a singleton + # make this class a singleton _initialized = False + def __new__(cls): - if not hasattr(cls, 'instance'): + if not hasattr(cls, "instance"): cls.instance = super(GestureConsole, cls).__new__(cls) return cls.instance @@ -23,27 +23,29 @@ def __init__(self) -> None: self.live = Live(self.layout, auto_refresh=False) self.live.start() - def generate_table(self, outputs: str): + def table(self, headers, rows): table = Table() - table.add_column("Hand") - table.add_column("Confidence") - table.add_column("Gesture") + for header in headers: + table.add_column(header) - for output in outputs: - table.add_row(output[0], output[1], output[2]) + for row in rows: + table_row = [] + for item in row: + table_row.append(str(f"{float(item):.3f}")) + table.add_row(*table_row) self.layout["upper"].update(Panel(table)) - self.update() + self.update() def print(self, string: str): self.console.print(string) self.layout["lower"].update(Panel(self.console)) self.update() - + def update(self): self.live.update(self.layout, refresh=True) -#https://stackoverflow.com/questions/71077706/redirect-print-and-or-logging-to-panel +# https://stackoverflow.com/questions/71077706/redirect-print-and-or-logging-to-panel class ConsolePanel(Console): def __init__(self, *args, **kwargs): console_file = open(os.devnull, "w") diff --git a/source/gesture inference/FeedForward.py b/source/gesture inference/FeedForward.py new file mode 100644 index 000000000..adc21c6d1 --- /dev/null +++ b/source/gesture inference/FeedForward.py @@ -0,0 +1,121 @@ +import torch.nn as nn +import torch +import numpy as np +from Console import GestureConsole + + +class NeuralNet(nn.Module): + + def __init__(self, modelName): + # Device configuration + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model, data = torch.load(modelName, map_location=device) + + # model hyperparameters, saved in the model file with its statedict from my train program + input_size = data[0] + hidden_size = data[1] + num_classes = data[2] + self.labels = data[3] + self.confidence_vector = [] + self.input_size = input_size + self.last_origin = [(0, 0)] + + self.console = GestureConsole() + self.console.print(device) + + # model definition + super(NeuralNet, self).__init__() + self.l1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.l2 = nn.Linear(hidden_size, num_classes) + self.load_state_dict(model) + self.eval() + + def forward(self, x): + """Runs a forward pass of the gesture model + + Args: + x (_type_): _description_ + + Returns: + _type_: _description_ + """ + out = self.l1(x) + out = self.relu(out) + out = self.l2(out) + # no activation and no softmax at the end + return out + + def get_gesture(self, model_input, print_table=True): + """ One hand input shape should be (1,65) + + Two hand input shape should be (2, 65) + """ + hands = torch.from_numpy(np.asarray(model_input, dtype="float32")) + outputs = self(hands) + probs = torch.nn.functional.softmax(outputs.data, dim=1) + + self.confidence_vector = probs + + # print table + if print_table: + self.console.table(self.labels, probs.tolist()) + + confidence, classes = torch.max(probs, 1) + return probs.tolist(), classes.numpy().tolist(), confidence.tolist() + + def find_velocity_and_location(self, result): + """Given a Mediapipe result object, calculates the velocity and origin of hands. + + Args: + result (Mediapipe.hands.result): Direct output object from Mediapipe hands model + + Returns: + (origins, velocity): A tuple containing an array of tuples representing hand origins, and an array of tuples containing hand velocitys + """ + + normalized_origin_offset = [] + hands_location_on_screen = [] + velocity = [] + + for hand in result.hand_world_landmarks: + # take middle finger knuckle + normalized_origin_offset.append(hand[9]) + + for index, hand in enumerate(result.hand_landmarks): + originX = hand[9].x - normalized_origin_offset[index].x + originY = hand[9].y - normalized_origin_offset[index].y + originZ = hand[9].z - normalized_origin_offset[index].z + hands_location_on_screen.append((originX, originY, originZ)) + velocityX = self.last_origin[index][0] - hands_location_on_screen[index][0] + velocityY = self.last_origin[index][1] - hands_location_on_screen[index][1] + velocity.append((velocityX, velocityY)) + self.last_origin = hands_location_on_screen + + return hands_location_on_screen, velocity + + def gesture_input(self, result, velocity): + """Converts Mediapipe landmarks and a velocity into a format usable by the gesture recognition model + + Args: + result (Mediapipe.hands.result): The result object returned by Mediapipe + velocity ([(float, float)]): An array of tuples containing the velocity of hands + + Returns: + array: An array of length 65 + """ + model_inputs = [] + + for index, hand in enumerate(result.hand_world_landmarks): + model_inputs.append([]) + for point in hand: + model_inputs[index].append(point.x) + model_inputs[index].append(point.y) + model_inputs[index].append(point.z) + if velocity != []: + model_inputs[index].append(velocity[index][0]) + model_inputs[index].append(velocity[index][1]) + + out = np.asarray(model_inputs, dtype="float32") + + return out diff --git a/source/gesture inference/GetHands.py b/source/gesture inference/GetHands.py new file mode 100644 index 000000000..6aafacdaf --- /dev/null +++ b/source/gesture inference/GetHands.py @@ -0,0 +1,206 @@ +# https://github.com/nrsyed/computer-vision/blob/master/multithread/VideoShow.py + +from threading import Thread +import mediapipe as mp +import time +import math +from FeedForward import NeuralNet +import traceback +from Console import GestureConsole +from Webcam import Webcam +import os + +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) + +class GetHands(Thread): + """ + Class that continuously gets frames and extracts hand data + with a dedicated thread and Mediapipe + """ + + def __init__( + self, + render_hands, + mediapipe_model="hand_landmarker.task", + gesture_model = "simple.pth", + control_mouse=None, + flags=None, + keyboard=None, + click_sensitinity=0.05, + ): + Thread.__init__(self) + + self.model_path = mediapipe_model + self.render_hands = render_hands + self.confidence = 0.5 + self.stopped = False + self.control_mouse = control_mouse + + self.flags = flags + self.click_sensitinity = click_sensitinity + self.keyboard = keyboard + self.console = GestureConsole() + self.camera = Webcam() + + self.gesture_model = NeuralNet(gesture_model) + self.gesture_list = self.gesture_model.labels + self.confidence_vectors = self.gesture_model.confidence_vector + self.gestures = ['no gesture'] + self.delay = 0 + + (self.grabbed, self.frame) = self.camera.read() + + self.timer = 0 + + self.build_model(flags["number_of_hands"]) + + def build_model(self, hands_num): + """Takes in option parameters for the Mediapipe hands model + + Args: + hands_num (int): max number of hands for Mediapipe to find + """ + # mediapipe setup + self.BaseOptions = mp.tasks.BaseOptions + self.HandLandmarker = mp.tasks.vision.HandLandmarker + self.HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions + self.VisionRunningMode = mp.tasks.vision.RunningMode + self.options = self.HandLandmarkerOptions( + base_options=self.BaseOptions(model_asset_path=self.model_path), + num_hands=hands_num, + min_hand_detection_confidence=self.confidence, + min_hand_presence_confidence=self.confidence, + min_tracking_confidence=self.confidence, + running_mode=self.VisionRunningMode.LIVE_STREAM, + result_callback=self.results_callback, + ) + + # build hands model + self.hands_detector = self.HandLandmarker.create_from_options(self.options) + + def move_mouse(self, location, button: str): + """Wrapper method to control the mouse + + Args: + hands_location_on_screen (origins): The origins result from find_velocity_and_location() + mouse_button_text (str): Type of click + """ + if callable(self.control_mouse): + if location != []: + # (0,0) is the top left corner + self.control_mouse( + location[0][0], + location[0][1], + button, + ) + + def results_callback( + self, + result: mp.tasks.vision.HandLandmarkerResult, + output_image: mp.Image, + timestamp_ms: int, + ): + # this try catch block is for debugging. this code runs in a different thread and doesn't automatically raise its own exceptions + try: + + if len(result.hand_world_landmarks) == 0: + self.render_hands( + result, + None, + None, + None, + ) + return + + location, velocity = self.gesture_model.find_velocity_and_location(result) + + if self.flags["run_model_flag"]: + + # get all the hands and format them + model_inputs = self.gesture_model.gesture_input(result, velocity) + + # for some reason parrellization with batches makes the model super slow + # if len(model_inputs) > 0: + # self.confidence_vector, indexs = self.gesture_model.get_gesture_confidence(model_inputs) + # # only take inputs from the first hand, subsequent hands can't control the keyboard + # self.keyboard.gesture_input(self.confidence_vector[0]) + + # serialized input + hand_confidences = [] #prepare data for console table + gestures = [] #store gesture output as text + for index, hand in enumerate(model_inputs): + confidences, predicted, predicted_confidence = ( + self.gesture_model.get_gesture([hand], print_table=False) + ) + gestures.append(self.gesture_list[predicted[0]]) # save gesture + hand_confidences.append(confidences[0]) + # only take inputs from the first hand, subsequent hands can't control the keyboard + + self.gestures = gestures + self.confidence_vectors = hand_confidences + self.keyboard.gesture_input(self.confidence_vectors[0]) + + self.console.table(self.gesture_list, hand_confidences) + + + if self.flags["move_mouse_flag"] and location != []: + mouse_button_text = "" + hand = result.hand_world_landmarks[0] + if self.is_clicking(hand[8], hand[4]): + mouse_button_text = "left" + self.move_mouse(location, mouse_button_text) + + # timestamps are in microseconds so convert to ms + + current_time = time.time() + self.delay = (current_time - self.timer) * 1000 + self.timer = current_time + + self.render_hands( + result, + self.flags["render_hands_mode"], + location, + velocity, + ) + + except Exception as e: + traceback.print_exc() + quit() + + def is_clicking(self, tip1, tip2): + distance = math.sqrt( + (tip1.x - tip2.x) ** 2 + (tip1.y - tip2.y) ** 2 + (tip1.z - tip2.z) ** 2 + ) + if distance < self.click_sensitinity: + return True + else: + return False + + def run(self): + """Continuously grabs new frames from the webcam and uses Mediapipe to detect hands""" + while not self.stopped: + if not self.grabbed: + self.camera.stop() + self.stop() + else: + (self.grabbed, self.frame) = self.camera.read() + + + # Detect hand landmarks + self.detect_hands(self.frame) + + def detect_hands(self, frame): + """Wrapper function for Mediapipe's hand detector in livestream mode + + Args: + frame (cv2.image): OpenCV webcam frame + """ + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame) + self.hands_detector.detect_async( + mp_image, mp.Timestamp.from_seconds(time.time()).value + ) + + def stop(self): + self.stopped = True diff --git a/source/gesture recognition/Keyboard.py b/source/gesture inference/Keyboard.py similarity index 92% rename from source/gesture recognition/Keyboard.py rename to source/gesture inference/Keyboard.py index ebd82e519..64227798e 100644 --- a/source/gesture recognition/Keyboard.py +++ b/source/gesture inference/Keyboard.py @@ -1,6 +1,7 @@ import pyautogui import time from Console import GestureConsole +import numpy as np class Keyboard: def __init__( self, threshold=0.0, toggle_key_threshold=0.15, toggle_key_toggle_time=1, toggle_mouse_func=None, console=None @@ -28,6 +29,17 @@ def __init__( self.toggle_key_toggle_time = toggle_key_toggle_time self.console = GestureConsole() + def gesture_input(self, confidences): + + max_value = np.max(confidences) + max_index = np.argmax(confidences) + if max_index == 0: + self.press("space") + elif max_index == 1: + self.press("none") + elif max_index == 2: + self.press("toggle") + def press(self, key: str): current_time = time.time() # if it has been longer than threshold time diff --git a/source/gesture recognition/Mouse.py b/source/gesture inference/Mouse.py similarity index 100% rename from source/gesture recognition/Mouse.py rename to source/gesture inference/Mouse.py diff --git a/source/gesture recognition/README.md b/source/gesture inference/README.md similarity index 100% rename from source/gesture recognition/README.md rename to source/gesture inference/README.md diff --git a/source/gesture inference/RenderHands.py b/source/gesture inference/RenderHands.py new file mode 100644 index 000000000..4f1d3af86 --- /dev/null +++ b/source/gesture inference/RenderHands.py @@ -0,0 +1,159 @@ +import pygame + + +class RenderHands: + """Given the Mediapipe hands output data, renders the hands in a normilzed view or camera perspective view on a pygame surface""" + + def __init__(self, surface, render_scale=3): + """Create Render Hand object using a pygame surface and a scaling factor + + Args: + surface (pygame.surface): pygame surface to render a hand on + hand_scale (float): multiplier to change the size at which the hand is rendered at + """ + self.surface = surface + self.hand_scale = render_scale + self.font = pygame.font.Font("freesansbold.ttf", 30) + self.last_velocity = [(0.5, 0.5)] + + def set_render_scale(self, scale:float): + self.hand_scale = scale + + def connections(self, landmarks, mode): + """Renders lines between hand joints + + Args: + landmarks (results): requires the direct output from mediapipe + mode (bool): render either normalized or perspective + """ + + xy = [] + + w, h = self.surface.get_size() + + for index, hand in enumerate(landmarks): + xy.append([]) + for point in hand: + if mode: + xy[index].append( + ( + (point.x * self.hand_scale + 0.5) * w, + (point.y * self.hand_scale + 0.5) * h, + ) + ) + else: + xy[index].append( + ( + point.x * w, + point.y * h, + ) + ) + + for hand in range(len(xy)): + # thumb + self.render_line(xy[hand][0], xy[hand][1]) + self.render_line(xy[hand][1], xy[hand][2]) + self.render_line(xy[hand][2], xy[hand][3]) + self.render_line(xy[hand][3], xy[hand][4]) + # index + self.render_line(xy[hand][0], xy[hand][5]) + self.render_line(xy[hand][5], xy[hand][6]) + self.render_line(xy[hand][6], xy[hand][7]) + self.render_line(xy[hand][7], xy[hand][8]) + # middle + self.render_line(xy[hand][9], xy[hand][10]) + self.render_line(xy[hand][10], xy[hand][11]) + self.render_line(xy[hand][11], xy[hand][12]) + # ring + self.render_line(xy[hand][13], xy[hand][14]) + self.render_line(xy[hand][14], xy[hand][15]) + self.render_line(xy[hand][15], xy[hand][16]) + # pinky + self.render_line(xy[hand][0], xy[hand][17]) + self.render_line(xy[hand][17], xy[hand][18]) + self.render_line(xy[hand][18], xy[hand][19]) + self.render_line(xy[hand][19], xy[hand][20]) + # knuckle + self.render_line(xy[hand][5], xy[hand][9]) + self.render_line(xy[hand][9], xy[hand][13]) + self.render_line(xy[hand][13], xy[hand][17]) + + def render_line(self, start, end): + """Wrapper function for pygame's render line. Will render a white line with width=5 + + Args: + start (int): line start position + end (int): line end position + """ + pygame.draw.line(self.surface, (255, 255, 255), start, end, 5) + + def render_hands( + self, result, mode, origins, velocity + ): + + self.surface.fill((0, 0, 0)) + # Render hand landmarks + + w, h = self.surface.get_size() + if result.handedness != []: + if mode: + hand_points = result.hand_world_landmarks + pygame.draw.circle(self.surface, (255, 0, 255), (0.5 * w, 0.5 * h), 5) + pygame.draw.line( + self.surface, + (255, 255, 0), + ((velocity[0][0] + 0.5) * w, (velocity[0][1] + 0.5) * h), + ((0.5) * w, (0.5) * h), + 3, + ) + self.last_velocity = velocity + + else: + hand_points = result.hand_landmarks + + self.connections(hand_points, mode) + if hand_points: + # define colors for different hands + + hand_color = 0 + colors = [(0, 0, 255), (0, 255, 0), (255, 0, 0), (255, 255, 255)] + # get every hand detected + for index, hand in enumerate(hand_points): + # each hand has 21 landmarks + pygame.draw.circle( + self.surface, + (255, 0, 255), + (origins[index][0] * w, origins[index][1] * h), + 5, + ) + for landmark in hand: + self.__render_hands_pygame( + colors[hand_color], + landmark.x, + landmark.y, + mode, + ) + hand_color += 1 + + def __render_hands_pygame(self, color, x, y, mode): + """Renders a single landmark of a hand in pygame and scales the hand. + + Args: + color (rgb()): color of points in hand + x (float): x coordinant of a point + y (float): y coordinant of a point + surface (pygame.surface): surface to render a hand on + delay_ms ((float, float)): contains webcam latency and Mediapipe hands model latency + mode (bool): True to render in normalized mode. False for world coordinates + """ + + w, h = self.surface.get_size() + + if mode: + x *= self.hand_scale + y *= self.hand_scale + x += 0.5 + y += 0.5 + + pygame.draw.circle(self.surface, color, (x * w, y * h), 5) + diff --git a/source/gesture inference/Webcam.py b/source/gesture inference/Webcam.py new file mode 100644 index 000000000..965b0ddc5 --- /dev/null +++ b/source/gesture inference/Webcam.py @@ -0,0 +1,32 @@ +# https://github.com/nrsyed/computer-vision/blob/master/multithread/VideoShow.py + +import cv2 +class Webcam: + """ + Class that continuously gets frames and extracts hand data + with a dedicated thread and Mediapipe + """ + def __init__( + self, + webcam_id=0, + ): + self.webcam_id = webcam_id + self.start() + + def start(self): + # OpenCV setup + self.stream = cv2.VideoCapture(self.webcam_id) + # motion JPG format + self.stream.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc("M", "J", "P", "G")) + (self.grabbed, self.frame) = self.stream.read() + self.frame = cv2.flip(self.frame, 1) + + def stop(self): + self.stream.release() + cv2.destroyAllWindows() + + def read(self, flip=True): + (self.grabbed, self.frame) = self.stream.read() + if flip: + self.frame = cv2.flip(self.frame, 1) + return (self.grabbed, self.frame) \ No newline at end of file diff --git a/source/gesture inference/hand_landmarker.task b/source/gesture inference/hand_landmarker.task new file mode 100644 index 000000000..0d53faf37 Binary files /dev/null and b/source/gesture inference/hand_landmarker.task differ diff --git a/source/gesture inference/inference.py b/source/gesture inference/inference.py new file mode 100644 index 000000000..e30901f7b --- /dev/null +++ b/source/gesture inference/inference.py @@ -0,0 +1,230 @@ +# https://developers.google.com/mediapipe/framework/getting_started/gpu_support +# https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html +# https://pygame-menu.readthedocs.io/en/latest/_source/add_widgets.html +import pydoc +import pygame +import pygame_menu +from GetHands import GetHands +from RenderHands import RenderHands +from Mouse import Mouse +from Keyboard import Keyboard +import os +from Console import GestureConsole + +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) + +# global variables +pygame.init() +font = pygame.font.Font("freesansbold.ttf", 30) +clock = pygame.time.Clock() + +flags = { + "render_hands_mode": False, + "gesture_vector": [], + "number_of_hands": 2, + "move_mouse_flag": False, + "run_model_flag": True, +} + +console = GestureConsole() + + +def main() -> None: + """Main driver method which initilizes all children and starts pygame render pipeline""" + + window_width = 1200 + window_height = 1000 + window = pygame.display.set_mode((window_width, window_height), pygame.RESIZABLE) + pygame.display.set_caption("Test Hand Tracking Multithreaded") + + hands_surface = pygame.Surface((window_width, window_height)) + hands_surface.set_colorkey((0, 0, 0)) + + myRenderHands = RenderHands(hands_surface, render_scale=3) + + mouse_controls = Mouse(mouse_sensitivity=2) + + keyboard = Keyboard( + threshold=0, toggle_key_threshold=0.3, toggle_mouse_func=toggle_mouse + ) + + # control_mouse=mouse_controls.control, + hands = GetHands( + myRenderHands.render_hands, + control_mouse=mouse_controls.control, + flags=flags, + keyboard=keyboard, + ) + + menu = pygame_menu.Menu( + "Esc to toggle menu", + window_width * 0.5, + window_height * 0.5, + theme=pygame_menu.themes.THEME_BLUE, + ) + + menu.add.selector( + "Render Mode :", [("World", False), ("Normalized", True)], onchange=set_coords + ) + + def change_hands_num(value): + + flags["number_of_hands"] = value[1] + 1 + nonlocal hands + hands.stop() + hands.join() + hands = GetHands( + myRenderHands.render_hands, + control_mouse=mouse_controls.control, + flags=flags, + keyboard=keyboard, + ) + hands.start() + + menu.add.dropselect( + "Number of hands :", ["1", "2", "3", "4"], onchange=change_hands_num + ) + + menu.add.button("Close Menu", pygame_menu.events.CLOSE) + menu.add.button("Turn On Model", action=toggle_model) + menu.add.button("Turn On Mouse", action=toggle_mouse) + menu.add.button("Quit", pygame_menu.events.EXIT) + menu.enable() + + game_loop( + window, + hands, + hands_surface, + menu, + ) + + pygame.quit() + + +def toggle_mouse() -> None: + """Enable or disable mouse control""" + console.print("toggling mouse control") + flags["move_mouse_flag"] = not flags["move_mouse_flag"] + + +def toggle_model() -> None: + console.print("toggling model") + flags["run_model_flag"] = not flags["run_model_flag"] + + +def set_coords(value, mode) -> None: + """Defines the coordinate space for rendering hands + + Args: + value (_type_): used by pygame_menu + mode (_type_): True for normalized, False for world + """ + flags["render_hands_mode"] = mode + + +def game_loop( + window: pygame.display, + hands: GetHands, + hands_surface: pygame.Surface, + menu: pygame_menu.Menu, +): + """Runs the pygame event loop and renders surfaces + + Args: + window (_type_): The main pygame window + hands (_type_): The GetHands class + hands_surface (_type_): The surface that the hands are rendered on + menu (_type_): the main menu + """ + hands.start() + running = True + is_menu_showing = True + is_webcam_fullscreen = False + + is_fullscreen = False + + while running: + # window_width, window_height = menu.get_window_size() # what? idk + window_width, window_height = pygame.display.get_surface().get_size() + window.fill((0, 0, 0)) + events = pygame.event.get() + for event in events: + if event.type == pygame.QUIT: + hands.stop() + hands.join() + running = False + if event.type == pygame.KEYDOWN: + + if event.key == pygame.K_m: + toggle_mouse() + + if event.key == pygame.K_ESCAPE: + if is_menu_showing: + is_menu_showing = False + menu.disable() + else: + is_menu_showing = True + menu.enable() + + if event.key == pygame.K_F1: + is_webcam_fullscreen = not is_webcam_fullscreen + + if event.key == pygame.K_F11: + is_fullscreen = not is_fullscreen + pygame.display.toggle_fullscreen() + + # frames per second + fps = font.render( + str(round(clock.get_fps(), 1)) + "fps", False, (255, 255, 255) + ) + + frame = hands.frame + img_pygame = pygame.image.frombuffer( + frame.tostring(), frame.shape[1::-1], "BGR" + ) + img_width = img_pygame.get_width() + img_height = img_pygame.get_height() + + hand_surface_copy = pygame.transform.scale( + hands_surface.copy(), (img_width * 0.5, img_height * 0.5) + ) + img_pygame = pygame.transform.scale( + img_pygame, (img_width * 0.5, img_height * 0.5) + ) + + + + if is_webcam_fullscreen: + img_pygame = pygame.transform.scale( + img_pygame, (window_width, window_height) + ) + hand_surface_copy = pygame.transform.scale( + hands_surface.copy(), (window_width, window_height) + ) + + window.blit(img_pygame, (0, 0)) + + for index in range(len(hands.gestures)): + gesture_text = font.render(hands.gestures[index], False, (255, 255, 255)) + window.blit(gesture_text, (window_width - window_width // 5, index * 40)) + + delay_AI = font.render( + str(round(hands.delay, 1)) + "ms", False, (255, 255, 255) + ) + window.blit(fps, (0, 0)) + window.blit(delay_AI, (0, 40)) + + if menu.is_enabled(): + menu.update(events) + menu.draw(window) + + window.blit(hand_surface_copy, (0, 0)) + + clock.tick(60) + pygame.display.update() + + +if __name__ == "__main__": + main() diff --git a/source/gesture inference/motion.pth b/source/gesture inference/motion.pth new file mode 100644 index 000000000..4f0d9853e Binary files /dev/null and b/source/gesture inference/motion.pth differ diff --git a/source/gesture inference/simple.pth b/source/gesture inference/simple.pth new file mode 100644 index 000000000..7d84a84ae Binary files /dev/null and b/source/gesture inference/simple.pth differ diff --git a/source/gesture recognition/FeedForward.py b/source/gesture recognition/FeedForward.py deleted file mode 100644 index 17709613e..000000000 --- a/source/gesture recognition/FeedForward.py +++ /dev/null @@ -1,54 +0,0 @@ -import torch.nn as nn -import torch -class NeuralNet(nn.Module): - - def __init__(self, modelName): - # Device configuration - """Defines the model architecture - - Args: - modelName (_type_): PyTorch model with extra parameters defining the model parameters - """ - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model, data = torch.load(modelName, map_location=device) - super(NeuralNet, self).__init__() - input_size = data[0] - hidden_size = data[1] - num_classes = data[2] - self.labels = data[3] - self.input_size = input_size - self.l1 = nn.Linear(input_size, hidden_size) - self.relu = nn.ReLU() - self.l2 = nn.Linear(hidden_size, num_classes) - self.load_state_dict(model) - self.eval() - - def forward(self, x): - """Runs a forward pass of the gesture model - - Args: - x (_type_): _description_ - - Returns: - _type_: _description_ - """ - out = self.l1(x) - out = self.relu(out) - out = self.l2(out) - # no activation and no softmax at the end - return out - - def get_gesture(self, model_input): - """Callable method to run a forward pass of the gesture recognition model. - - Args: - model_input (_type_): Hand location data with velocity - - Returns: - _type_: single softmax output of the model with a confidence value for that classification - """ - hands = torch.from_numpy(model_input) - outputs = self(hands) - probs = torch.nn.functional.softmax(outputs.data, dim=1) - confidence, classes = torch.max(probs, 1) - return confidence.numpy(), classes.numpy() diff --git a/source/gesture recognition/SimpleModel.pth b/source/gesture recognition/SimpleModel.pth deleted file mode 100644 index 9d3831151..000000000 Binary files a/source/gesture recognition/SimpleModel.pth and /dev/null differ diff --git a/source/gesture recognition/make_pydoc.py b/source/gesture recognition/make_pydoc.py deleted file mode 100644 index 15a47af62..000000000 --- a/source/gesture recognition/make_pydoc.py +++ /dev/null @@ -1,128 +0,0 @@ -import pydoc - -pydoc.writedoc("main") -pydoc.writedoc("FeedForward") -pydoc.writedoc("GetHands") -pydoc.writedoc("RenderHands") -pydoc.writedoc("Writer") -pydoc.writedoc("Mouse") - -""" - - - - -""" \ No newline at end of file diff --git a/source/gesture recognition/Train.py b/source/train/Train.py similarity index 87% rename from source/gesture recognition/Train.py rename to source/train/Train.py index d5a3a924f..3b4ddfdf3 100644 --- a/source/gesture recognition/Train.py +++ b/source/train/Train.py @@ -6,24 +6,30 @@ import numpy as np import matplotlib.pyplot as plt import csv +import os + +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) # Device configuration device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Hyper-parameters input_size = 65 -hidden_size = 50 -num_epochs = 40 +hidden_size = 150 +num_epochs = 50 batch_size = 100 learning_rate = 0.001 -filename = "simple.csv" -labels = None +filename = "data.csv" +labels_list = None num_classes = 0 with open(filename, "r", newline="", encoding="utf-8") as dataset_file: - labels = next(csv.reader(dataset_file)) - num_classes = len(labels) - print("labels: "+str(labels)) + labels_list = next(csv.reader(dataset_file)) + num_classes = len(labels_list) + +print("labels: "+str(labels_list)) class HandDataset(Dataset): @@ -47,8 +53,8 @@ def __len__(self): # data size is 3868 train_dataset, test_dataset = (dataset, dataset) -train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False) -test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False) +train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) +test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True) # examples = iter(test_loader) # example_data, example_targets = next(examples) @@ -108,10 +114,9 @@ def forward(self, x): loss.backward() optimizer.step() - if (i + 1) % 100 == 0: - print( - f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}" - ) + print( + f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}" + ) # Test the model # In test phase, we don't need to compute gradients (for memory efficiency) @@ -130,4 +135,4 @@ def forward(self, x): acc = 100.0 * n_correct / n_samples print(f"Accuracy of the network on {int(dataset.__len__())+1} training dataset: {round(acc,3)} %") -torch.save((model.state_dict(),[input_size, hidden_size, num_classes, labels]), "SimpleModel.pth") \ No newline at end of file +torch.save((model.state_dict(),[input_size, hidden_size, num_classes, labels_list]), "simple.pth") \ No newline at end of file