Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/infiniop.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
#define __INFINIOP_API_H__

#include "infiniop/handle.h"
// Unified headers for elementwise operators
#include "infiniop/ops/unary_ops_api.h"
#include "infiniop/ops/binary_ops_api.h"
// Other operators
#include "infiniop/ops/add.h"
#include "infiniop/ops/add_rms_norm.h"
#include "infiniop/ops/attention.h"
Expand Down
50 changes: 50 additions & 0 deletions include/infiniop/ops/binary_op_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#ifndef __INFINIOP_BINARY_OP_API_H__
#define __INFINIOP_BINARY_OP_API_H__

#include "../operator_descriptor.h"

/**
* @brief Macro to generate the C API header for a binary operator.
*
* This macro generates all the necessary declarations for a binary operator:
* - Descriptor type definition
* - Create descriptor function
* - Get workspace size function
* - Execute operator function
* - Destroy descriptor function
*
* Usage:
* BINARY_OP_API_DECLARE(div, Div)
* BINARY_OP_API_DECLARE(pow, Pow)
*
* @param OP_NAME Lowercase operator name (e.g., div, pow, mod)
* @param OP_NAME_UPPER Uppercase operator name (e.g., Div, Pow, Mod)
*/
#define BINARY_OP_API_DECLARE(OP_NAME, OP_NAME_UPPER) \
\
typedef struct InfiniopDescriptor *infiniop##OP_NAME_UPPER##Descriptor_t; \
\
__C __export infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
infiniopHandle_t handle, \
infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
infiniopTensorDescriptor_t c, \
infiniopTensorDescriptor_t a, \
infiniopTensorDescriptor_t b); \
\
__C __export infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
size_t *size); \
\
__C __export infiniStatus_t infiniop##OP_NAME_UPPER( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
void *workspace, \
size_t workspace_size, \
void *c, \
const void *a, \
const void *b, \
void *stream); \
\
__C __export infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
infiniop##OP_NAME_UPPER##Descriptor_t desc);

#endif // __INFINIOP_BINARY_OP_API_H__
23 changes: 23 additions & 0 deletions include/infiniop/ops/binary_ops_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
#ifndef __INFINIOP_BINARY_OPS_API_H__
#define __INFINIOP_BINARY_OPS_API_H__

#include "binary_op_api.h"

/**
* @brief Unified API declarations for all binary operators.
*
* This header contains API declarations for all binary operators in a single file,
* eliminating the need for individual header files for each operator.
*
* All binary operator APIs are declared here:
* - div, pow, mod, max, min
*/

// Declare all binary operator APIs
BINARY_OP_API_DECLARE(div, Div)
BINARY_OP_API_DECLARE(pow, Pow)
BINARY_OP_API_DECLARE(mod, Mod)
BINARY_OP_API_DECLARE(max, Max)
BINARY_OP_API_DECLARE(min, Min)

#endif // __INFINIOP_BINARY_OPS_API_H__
48 changes: 48 additions & 0 deletions include/infiniop/ops/unary_op_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#ifndef __INFINIOP_UNARY_OP_API_H__
#define __INFINIOP_UNARY_OP_API_H__

#include "../operator_descriptor.h"

/**
* @brief Macro to generate the C API header for a unary operator.
*
* This macro generates all the necessary declarations for a unary operator:
* - Descriptor type definition
* - Create descriptor function
* - Get workspace size function
* - Execute operator function
* - Destroy descriptor function
*
* Usage:
* UNARY_OP_API_DECLARE(abs, Abs)
* UNARY_OP_API_DECLARE(log, Log)
*
* @param OP_NAME Lowercase operator name (e.g., abs, log, sin)
* @param OP_NAME_UPPER Uppercase operator name (e.g., Abs, Log, Sin)
*/
#define UNARY_OP_API_DECLARE(OP_NAME, OP_NAME_UPPER) \
\
typedef struct InfiniopDescriptor *infiniop##OP_NAME_UPPER##Descriptor_t; \
\
__C __export infiniStatus_t infiniopCreate##OP_NAME_UPPER##Descriptor( \
infiniopHandle_t handle, \
infiniop##OP_NAME_UPPER##Descriptor_t *desc_ptr, \
infiniopTensorDescriptor_t y, \
infiniopTensorDescriptor_t x); \
\
__C __export infiniStatus_t infiniopGet##OP_NAME_UPPER##WorkspaceSize( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
size_t *size); \
\
__C __export infiniStatus_t infiniop##OP_NAME_UPPER( \
infiniop##OP_NAME_UPPER##Descriptor_t desc, \
void *workspace, \
size_t workspace_size, \
void *y, \
const void *x, \
void *stream); \
\
__C __export infiniStatus_t infiniopDestroy##OP_NAME_UPPER##Descriptor( \
infiniop##OP_NAME_UPPER##Descriptor_t desc);

#endif // __INFINIOP_UNARY_OP_API_H__
39 changes: 39 additions & 0 deletions include/infiniop/ops/unary_ops_api.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef __INFINIOP_UNARY_OPS_API_H__
#define __INFINIOP_UNARY_OPS_API_H__

#include "unary_op_api.h"

/**
* @brief Unified API declarations for all unary operators.
*
* This header contains API declarations for all unary operators in a single file,
* eliminating the need for individual header files for each operator.
*
* All unary operator APIs are declared here:
* - abs, log, sqrt, reciprocal, neg, round, sinh, sign, tan
* - acosh, asinh, cos, atanh, asin, floor, cosh, erf, atan, acos, ceil
*/

// Declare all unary operator APIs
UNARY_OP_API_DECLARE(abs, Abs)
UNARY_OP_API_DECLARE(log, Log)
UNARY_OP_API_DECLARE(sqrt, Sqrt)
UNARY_OP_API_DECLARE(reciprocal, Reciprocal)
UNARY_OP_API_DECLARE(neg, Neg)
UNARY_OP_API_DECLARE(round, Round)
UNARY_OP_API_DECLARE(sinh, Sinh)
UNARY_OP_API_DECLARE(sign, Sign)
UNARY_OP_API_DECLARE(tan, Tan)
UNARY_OP_API_DECLARE(acosh, Acosh)
UNARY_OP_API_DECLARE(asinh, Asinh)
UNARY_OP_API_DECLARE(cos, Cos)
UNARY_OP_API_DECLARE(atanh, Atanh)
UNARY_OP_API_DECLARE(asin, Asin)
UNARY_OP_API_DECLARE(floor, Floor)
UNARY_OP_API_DECLARE(cosh, Cosh)
UNARY_OP_API_DECLARE(erf, Erf)
UNARY_OP_API_DECLARE(atan, Atan)
UNARY_OP_API_DECLARE(acos, Acos)
UNARY_OP_API_DECLARE(ceil, Ceil)

#endif // __INFINIOP_UNARY_OPS_API_H__
Loading