From 4497af12c824f771d8c93e86992181648427ad73 Mon Sep 17 00:00:00 2001 From: David Lin Date: Mon, 13 May 2024 13:15:01 -0700 Subject: [PATCH] Add base for sgd optimizer (#3496) Summary: This adds the SGD optimizer header (sgd.h) to ExecuTorch. Feedback on where to place this file would be appreciated. Reviewed By: JacobSzwejbka Differential Revision: D56888378 --- extension/training/optimizer/TARGETS | 8 +++ extension/training/optimizer/sgd.h | 49 +++++++++++++++++++ extension/training/optimizer/targets.bzl | 20 ++++++++ extension/training/optimizer/test/TARGETS | 8 +++ .../training/optimizer/test/sgd_test.cpp | 28 +++++++++++ extension/training/optimizer/test/targets.bzl | 18 +++++++ 6 files changed, 131 insertions(+) create mode 100644 extension/training/optimizer/TARGETS create mode 100644 extension/training/optimizer/sgd.h create mode 100644 extension/training/optimizer/targets.bzl create mode 100644 extension/training/optimizer/test/TARGETS create mode 100644 extension/training/optimizer/test/sgd_test.cpp create mode 100644 extension/training/optimizer/test/targets.bzl diff --git a/extension/training/optimizer/TARGETS b/extension/training/optimizer/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/training/optimizer/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/training/optimizer/sgd.h b/extension/training/optimizer/sgd.h new file mode 100644 index 00000000000..a5f46b44066 --- /dev/null +++ b/extension/training/optimizer/sgd.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +/** + * SGD (stochastic gradient descent) optimizer to perform on-device training. + * This uses the gradients calculated in the backward pass of the loss function + * and updates the parameters so as to minimize the loss. + * + * This is similar to the Lite Interpreter implementation of the SGD optimizer, + * but without the dependency on ATen Tensors and autograd. + */ +#pragma once + +namespace torch { +namespace executor { +namespace optimizer { + +/** + * SGD optimizer state. Intended to keep track of the state of a given + * parameter for use in later epochs. Currently an empty placeholder. + */ +class SGDParamState {}; + +/** + * SGD optimizer options. Intended to hold the options for performing training + * on a param group, such as the learning rate. Currently an empty placeholder. + */ +class SGDOptions {}; + +/** + * SGD optimizer param group. Intended to hold the parameters and the + * options (presumably SGDOptions) associated with them. Currently empty. + */ +class SGDParamGroup {}; + +/** + * SGD optimizer class. Intended to be responsible for performing the + * optimization step. Currently an empty placeholder. + */ +class SGD {}; + +} // namespace optimizer +} // namespace executor +} // namespace torch diff --git a/extension/training/optimizer/targets.bzl b/extension/training/optimizer/targets.bzl new file mode 100644 index 00000000000..ffe8e30d7b6 --- /dev/null +++ b/extension/training/optimizer/targets.bzl @@ -0,0 +1,20 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. 
+ """ + + runtime.cxx_library( + name = "optimizer", + exported_headers = [ + "sgd.h", + ], + exported_deps = [ + ], + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + ) diff --git a/extension/training/optimizer/test/TARGETS b/extension/training/optimizer/test/TARGETS new file mode 100644 index 00000000000..2341af9282f --- /dev/null +++ b/extension/training/optimizer/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/extension/training/optimizer/test/sgd_test.cpp b/extension/training/optimizer/test/sgd_test.cpp new file mode 100644 index 00000000000..1d35e43458f --- /dev/null +++ b/extension/training/optimizer/test/sgd_test.cpp @@ -0,0 +1,28 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +using namespace ::testing; +using namespace torch::executor::optimizer; + +class SGDOptimizerTest : public ::testing::Test {}; + +TEST_F(SGDOptimizerTest, InstantiateTypes) { + SGDParamState state; + SGDOptions options; + SGDParamGroup param_group; + SGD sgd; + + EXPECT_TRUE(dynamic_cast(&state) != nullptr); + EXPECT_TRUE(dynamic_cast(&options) != nullptr); + EXPECT_TRUE(dynamic_cast(¶m_group) != nullptr); + EXPECT_TRUE(dynamic_cast(&sgd) != nullptr); +} diff --git a/extension/training/optimizer/test/targets.bzl b/extension/training/optimizer/test/targets.bzl new file mode 100644 index 00000000000..9d380f90a14 --- /dev/null +++ b/extension/training/optimizer/test/targets.bzl @@ -0,0 +1,18 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + runtime.cxx_test( + name = "sgd_test", + srcs = [ + "sgd_test.cpp", + ], + deps = [ + "//executorch/extension/training/optimizer:optimizer", + ], + )