Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/pcg/include/pcg/device_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace FlexFlow {

device_id_t operator+(device_id_t, size_t);

DeviceType get_device_type(device_id_t);
DeviceType get_device_type(device_id_t const &device_id);
gpu_id_t unwrap_gpu(device_id_t);
cpu_id_t unwrap_cpu(device_id_t);

Expand Down
19 changes: 18 additions & 1 deletion lib/pcg/include/pcg/machine_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define _FLEXFLOW_PCG_INCLUDE_PCG_MACHINE_VIEW_H

#include "pcg/cpu_id_t.dtg.h"
#include "pcg/device_id.h"
#include "pcg/device_id_t.dtg.h"
#include "pcg/device_type.dtg.h"
#include "pcg/gpu_id_t.dtg.h"
Expand All @@ -14,15 +15,31 @@
namespace FlexFlow {

std::vector<device_id_t> device_ids(MachineView const &);
std::size_t num_dims(MachineView const &);
size_t num_dims(MachineView const &);
std::size_t num_devices(MachineView const &);
DeviceType get_device_type(MachineView const &);

MachineView make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride = 1);
MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride = 1);
MachineView
make_1d_machine_view(device_id_t start, device_id_t stop, int stride = 1);

MachineView make_1d_machine_view(gpu_id_t start,
num_points_t num_points,
int stride = 1);
MachineView make_1d_machine_view(cpu_id_t start,
num_points_t num_points,
int stride = 1);
MachineView make_1d_machine_view(device_id_t start,
num_points_t num_points,
int stride = 1);

MachineView make_1d_machine_view(gpu_id_t start,
side_size_t interval_size,
int stride = 1);
MachineView make_1d_machine_view(cpu_id_t start,
side_size_t interval_size,
int stride = 1);
MachineView make_1d_machine_view(device_id_t start,
side_size_t interval_size,
int stride = 1);
Expand Down
5 changes: 3 additions & 2 deletions lib/pcg/include/pcg/strided_rectangle.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
namespace FlexFlow {

size_t get_num_dims(StridedRectangle const &);
StridedRectangleSide get_side_at_idx(StridedRectangle const &,
ff_dim_t const &);
StridedRectangleSide get_side_at_idx(StridedRectangle const &rect,
ff_dim_t const &idx);
num_points_t get_num_points(StridedRectangle const &rect);

} // namespace FlexFlow

Expand Down
75 changes: 66 additions & 9 deletions lib/pcg/src/pcg/machine_view.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "pcg/machine_view.h"
#include "pcg/device_id.h"
#include "pcg/strided_rectangle.dtg.h"
#include "pcg/strided_rectangle.h"
#include "pcg/strided_rectangle_side.h"

namespace FlexFlow {
Expand All @@ -8,16 +10,16 @@ std::vector<device_id_t> device_ids(MachineView const &) {
NOT_IMPLEMENTED();
}

std::size_t num_dims(MachineView const &) {
NOT_IMPLEMENTED();
std::size_t num_dims(MachineView const &mv) {
return get_num_dims(mv.rect);
}

std::size_t num_devices(MachineView const &) {
NOT_IMPLEMENTED();
size_t num_devices(MachineView const &mv) {
return get_num_points(mv.rect).unwrapped;
}

DeviceType get_device_type(MachineView const &) {
NOT_IMPLEMENTED();
DeviceType get_device_type(MachineView const &mv) {
return get_device_type(mv.start);
}

static StridedRectangle make_1d_rect(int start, int stop, int stride) {
Expand All @@ -40,18 +42,73 @@ MachineView make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride) {
return MachineView{device_id_t{start}, rect};
}

MachineView
make_1d_machine_view(device_id_t start, device_id_t stop, int stride) {
assert(get_device_type(start) == get_device_type(stop));
if (get_device_type(start) == DeviceType::CPU) {
return make_1d_machine_view(unwrap_cpu(start), unwrap_cpu(stop), stride);
}
assert(get_device_type(start) == DeviceType::GPU);
return make_1d_machine_view(unwrap_gpu(start), unwrap_gpu(stop), stride);
}

static StridedRectangle
make_1d_rect(int start, num_points_t num_points, int stride) {
return make_1d_rect(start, start + num_points.unwrapped * stride, stride);
}

MachineView
make_1d_machine_view(cpu_id_t start, num_points_t num_points, int stride) {
StridedRectangle rect = make_1d_rect(start.cpu_index, num_points, stride);
return MachineView{device_id_t{start}, rect};
}

MachineView
make_1d_machine_view(gpu_id_t start, num_points_t num_points, int stride) {
StridedRectangle rect = make_1d_rect(start.gpu_index, num_points, stride);
return MachineView{device_id_t{start}, rect};
}

MachineView make_1d_machine_view(device_id_t start,
num_points_t num_points,
int stride) {
NOT_IMPLEMENTED();
if (get_device_type(start) == DeviceType::CPU) {
return make_1d_machine_view(unwrap_cpu(start), num_points, stride);
} else {
assert(get_device_type(start) == DeviceType::GPU);
return make_1d_machine_view(unwrap_gpu(start), num_points, stride);
}
}

MachineView make_1d_machine_view(device_id_t start,
static StridedRectangle
make_1d_rect(int start, side_size_t interval_size, int stride) {
return make_1d_rect(start, start + interval_size.unwrapped, stride);
}

MachineView make_1d_machine_view(cpu_id_t start,
side_size_t interval_size,
int stride) {
NOT_IMPLEMENTED();
StridedRectangle rect = make_1d_rect(start.cpu_index, interval_size, stride);
return MachineView{device_id_t{start}, rect};
}

MachineView make_1d_machine_view(gpu_id_t start,
side_size_t interval_size,
int stride) {
StridedRectangle rect = make_1d_rect(start.gpu_index, interval_size, stride);
return MachineView{device_id_t{start}, rect};
}
MachineView make_1d_machine_view(device_id_t start,
side_size_t interval_size,
int stride) {

if (get_device_type(start) == DeviceType::CPU) {
return make_1d_machine_view(unwrap_cpu(start), interval_size, stride);
} else {
assert(get_device_type(start) == DeviceType::GPU);
return make_1d_machine_view(unwrap_gpu(start), interval_size, stride);
}
}
MachineView make_1d_machine_view(device_id_t start, size_t interval_size) {
NOT_IMPLEMENTED();
}
Expand Down
6 changes: 4 additions & 2 deletions lib/pcg/src/pcg/strided_rectangle_side.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

namespace FlexFlow {

StridedRectangleSide strided_side_from_size_and_stride(side_size_t,
StridedRectangleSide strided_side_from_size_and_stride(side_size_t side_size,
int stride) {
NOT_IMPLEMENTED();
assert((side_size.unwrapped % stride) == 0);
return StridedRectangleSide{num_points_t{side_size.unwrapped / stride},
stride};
}

side_size_t get_side_size(StridedRectangleSide const &s) {
Expand Down
17 changes: 13 additions & 4 deletions lib/pcg/src/strided_rectangle.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "pcg/strided_rectangle.h"
#include "op-attrs/dim_ordered/transform.h"
#include "utils/containers.h"

namespace FlexFlow {
Expand All @@ -15,12 +16,20 @@ namespace FlexFlow {
/* return idx; */
/* } */

size_t get_num_dims(StridedRectangle const &) {
NOT_IMPLEMENTED();
size_t get_num_dims(StridedRectangle const &rect) {
return rect.sides.size();
}

size_t get_side_at_idx(StridedRectangle const &) {
NOT_IMPLEMENTED();
num_points_t get_num_points(StridedRectangle const &rect) {
return num_points_t{
product(transform(rect.sides, [](StridedRectangleSide const &side) {
return side.num_points.unwrapped;
}))};
}

StridedRectangleSide get_side_at_idx(StridedRectangle const &rect,
ff_dim_t const &idx) {
return rect.sides.at(idx);
}

} // namespace FlexFlow
74 changes: 74 additions & 0 deletions lib/pcg/test/src/test_machine_view.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#include "doctest/doctest.h"
#include "pcg/machine_view.h"
#include "pcg/strided_rectangle.h"
#include "pcg/strided_rectangle_side.h"

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("MachineView general util functions") {
StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5},
StridedRectangleSide{num_points_t{10}, 2}}};
gpu_id_t start(1);
MachineView mv{device_id_t{start}, rect};
SUBCASE("num_dims") {
CHECK(num_dims(mv) == 2);
}
SUBCASE("num_devices") {
CHECK(num_devices(mv) == 7 * 10);
}
SUBCASE("get_device_type") {
CHECK(get_device_type(mv) == DeviceType::GPU);
}
}

TEST_CASE("MachineView make_1d_machine_view - GPU") {
StridedRectangle rect{{StridedRectangleSide{num_points_t{7}, 5}}};
device_id_t start_gpu{gpu_id_t{1}};
MachineView gpu_mv{start_gpu, rect};

SUBCASE("make_1d_machine_view(gpu_id_t start, gpu_id_t stop, int stride)") {
MachineView result =
make_1d_machine_view(start_gpu, device_id_t{gpu_id_t(1 + 7 * 5)}, 5);
MachineView correct = gpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(gpu_id_t start, num_points_t num_points, int "
"stride)") {
MachineView result = make_1d_machine_view(start_gpu, num_points_t{7}, 5);
MachineView correct = gpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(gpu_id_t start, side_size_t interval_size, "
"int stride)") {
MachineView result = make_1d_machine_view(
start_gpu, get_side_size(rect.sides.at(ff_dim_t{0})), 5);
MachineView correct = gpu_mv;
CHECK(result == correct);
}
}

TEST_CASE("MachineView make_1d_machine_view - CPU") {
StridedRectangle rect{{StridedRectangleSide{num_points_t{11}, 4}}};
device_id_t start_cpu{cpu_id_t{2}};
MachineView cpu_mv{start_cpu, rect};

SUBCASE("make_1d_machine_view(cpu_id_t start, cpu_id_t stop, int stride)") {
MachineView result =
make_1d_machine_view(start_cpu, device_id_t{cpu_id_t(2 + 11 * 4)}, 4);
MachineView correct = cpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(cpu_id_t start, num_points_t num_points, int "
"stride)") {
MachineView result = make_1d_machine_view(start_cpu, num_points_t{11}, 4);
MachineView correct = cpu_mv;
CHECK(result == correct);
}
SUBCASE("make_1d_machine_view(cpu_id_t start, side_size_t interval_size, "
"int stride)") {
MachineView result = make_1d_machine_view(
start_cpu, get_side_size(rect.sides.at(ff_dim_t{0})), 4);
MachineView correct = cpu_mv;
CHECK(result == correct);
}
}
}
37 changes: 37 additions & 0 deletions lib/pcg/test/src/test_strided_rectangle.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include "doctest/doctest.h"
#include "pcg/strided_rectangle.h"
#include "pcg/strided_rectangle_side.h"

TEST_SUITE(FF_TEST_SUITE) {
TEST_CASE("get_side_size(StridedRectangleSide)") {
StridedRectangleSide side{num_points_t{7}, 5};

CHECK(get_side_size(side) == side_size_t{7 * 5});
}
TEST_CASE("strided_side_from_size_and_stride") {
StridedRectangleSide correct{num_points_t{10}, 3};
StridedRectangleSide result =
strided_side_from_size_and_stride(side_size_t{10 * 3}, 3);
CHECK(result == correct);
}

TEST_CASE("StridedRectangle - helper functions") {

StridedRectangleSide s0{num_points_t{7}, 5};
StridedRectangleSide s1{num_points_t{10}, 2};
StridedRectangleSide s2{num_points_t{8}, 1};
StridedRectangle rect{{s0, s1, s2}};

SUBCASE("get_num_dims") {
CHECK(get_num_dims(rect) == 3);
}
SUBCASE("get_num_points") {
CHECK(get_num_points(rect) == num_points_t{7 * 8 * 10});
}
SUBCASE("get_side_at_idx") {
CHECK(get_side_at_idx(rect, ff_dim_t{0}) == s0);
CHECK(get_side_at_idx(rect, ff_dim_t{1}) == s1);
CHECK(get_side_at_idx(rect, ff_dim_t{2}) == s2);
}
}
}