Tutorial inference Q7

This tutorial shows, based on an example, how to perform an inference in AIfES on an integer-quantized Feed-Forward Neural Network (FNN). It is assumed that the trained weights are already available or will be calculated with external tools on a PC. If you want to train the neural network with AIfES, switch to the training tutorial.

Example

As an example, we take a robot with two powered wheels and an RGB color sensor that should follow a black line on white paper. To fulfill the task, we map the color sensor values with an FNN directly to the control commands for the two wheel motors of the robot. The inputs of the FNN are the RGB color values scaled to the interval [0, 1], and the outputs are either "on" (1) or "off" (0) to control the motors.

The following cases should be considered:

  1. The sensor points to a black area (RGB = [0, 0, 0]): The robot is too far on the left and should turn on the left wheel-motor while removing power from the right motor.
  2. The sensor points to a white area (RGB = [1, 1, 1]): The robot is too far on the right and should turn on the right wheel-motor while removing power from the left motor.
  3. The sensor points to a red area (RGB = [1, 0, 0]): The robot has reached the stop mark and should switch off both motors.

The resulting input data of the FNN is then

R  G  B
0  0  0
1  1  1
1  0  0

and the output should be

left motor  right motor
1           0
0           1
0           0

Design the neural network

To set up the FNN in AIfES 2, we need to design the structure of the neural network. It needs three inputs for the RGB color values and two outputs for the two motors. Because the task is rather easy, we use just one hidden layer with three neurons.

To create the network in AIfES, it must be divided into logical layers like the fully-connected (dense) layer and the activation functions. We choose a Leaky ReLU activation for the hidden layer and a Sigmoid activation for the output layer.

Get the pre-trained weights and biases

To perform an inference you need the trained weights and biases of the model. For example, you can train your model with Keras or PyTorch, extract the weights and biases, and copy them into your AIfES model.
For a dense layer, AIfES expects the weights as a matrix of shape [Inputs x Outputs] and the bias as a matrix of shape [1 x Outputs].

Example model in Keras:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Activation

model = Sequential()
model.add(Input(shape=(3,)))
model.add(Dense(3))
model.add(LeakyReLU(alpha=0.01))
model.add(Dense(2))
model.add(Activation('sigmoid'))

Example model in PyTorch:

from torch import nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense_layer_1 = nn.Linear(3, 3)
        self.leaky_relu_layer = nn.LeakyReLU(0.01)
        self.dense_layer_2 = nn.Linear(3, 2)
        self.sigmoid_layer = nn.Sigmoid()

    def forward(self, x):
        x = self.dense_layer_1(x)
        x = self.leaky_relu_layer(x)
        x = self.dense_layer_2(x)
        x = self.sigmoid_layer(x)
        return x

Our example weights and biases for the two dense layers after training are:

\[ w_1 = \left( \begin{array}{rrr} 3.64540 & -3.60981 & 1.57631 \\ -2.98952 & -1.91465 & 3.06150 \\ -2.76578 & -1.24335 & 0.71257 \end{array}\right) \]

\[ b_1 = \left( \begin{array}{rrr} 0.72655 & 2.67281 & -0.21291 \end{array}\right) \]

\[ w_2 = \left( \begin{array}{rr} -1.09249 & -2.44526 \\ 3.23528 & -2.88023 \\ -2.51201 & 2.52683 \end{array}\right) \]

\[ b_2 = \left( \begin{array}{rr} 0.14391 & -1.34459 \end{array}\right) \]

Q7 quantization

General

The Q7 quantization is an asymmetric 8 bit integer quantization that allows integer-only calculations on real values. The quantization procedure is closely related to the techniques proposed in the paper "Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference". A real value \( r \) is represented by an integer value \( q \), the scaling factor / shift \( s \) and the zero point \( z \) according to the following formula:

\[ r = 2^{-s} * (q - z) \]

To get the quantized value \( q \) out of a real value \( r \) you have to calculate

\[ q = round(\frac{r}{2^{-s}} + z) \]
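
For example, with a shift of \( s = 4 \) and a zero point of \( z = 0 \), the real value \( r = 4.2 \) is stored as

\[ q = round\left(\frac{4.2}{2^{-4}} + 0\right) = round(67.2) = 67 \]

and recovered (up to the quantization error) as \( r \approx 2^{-4} \cdot (67 - 0) = 4.1875 \).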

To store the zero point and the scaling factor, a Q7 tensor needs additional parameters. These are stored in a structure of type aimath_q7_params_t and set to the tensor_params field of an aitensor_t. For example the tensor

\[ \left( \begin{array}{rrr} 0 & 1 & 2 \\ 3 & 4 & 5 \end{array}\right) \]

can be created with

int8_t example_data[] = {0, 2, 4,
                         6, 8, 10};
uint16_t example_shape[] = {2, 3};
aimath_q7_params_t example_q_params = { .shift = 1, .zero_point = 0 };
aitensor_t example_tensor = {
    .dtype = aiq7,
    .dim = 2,
    .shape = example_shape,
    .tensor_params = &example_q_params,
    .data = example_data
};

or

int8_t example_data[] = {0, 2, 4,
                         6, 8, 10};
uint16_t example_shape[] = {2, 3};
aimath_q7_params_t example_q_params = { .shift = 1, .zero_point = 0 };
aitensor_t example_tensor = AITENSOR_2D_Q7(example_shape, &example_q_params, example_data);

Scalar values can be created with the structure aiscalar_q7_t and initialized either automatically with

aiscalar_q7_t scalar = AISCALAR_Q7(4.2f, 4, 0);

or manually with

aiscalar_q7_t scalar = { .value = 67, .shift = 4, .zero_point = 0};

The macro functions FLOAT_TO_Q7(float_value, shift, zero_point) and Q7_TO_FLOAT(integer_value, shift, zero_point) are also available for quick and easy conversion between float and integer values.
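
For example, the scalar value from above can be converted with these macros as follows (a short sketch; the exact result of FLOAT_TO_Q7 depends on its rounding behavior):

// Quantize the real value 4.2 with shift 4 and zero point 0 (gives approximately 67)
int8_t q_value = FLOAT_TO_Q7(4.2f, 4, 0);
// Convert it back to a float (approximately 4.1875 due to the quantization error)
float f_value = Q7_TO_FLOAT(q_value, 4, 0);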

Tensor quantization helper

AIfES provides some helper functions to quantize 32 bit float values to 8 bit integer values. The following example shows how to automatically quantize a tensor with AIfES:

Asymmetric quantization (zero point != 0):

// Source tensor with 32 bit float values
float example_data_f32[] = {0.0f, 1.0f, 2.0f,
3.0f, 4.0f, 5.0f};
uint16_t example_shape_f32[] = {2, 3};
aitensor_t example_tensor_f32 = AITENSOR_2D_F32(example_shape_f32, example_data_f32);
// Destination tensor to write the 8 bit integer quantized values in
int8_t example_data_q7[2*3];
uint16_t example_shape_q7[] = {2, 3};
aimath_q7_params_t example_q_params_q7;
aitensor_t example_tensor_q7 = AITENSOR_2D_Q7(example_shape_q7, &example_q_params_q7, example_data_q7);
float min_value, max_value;
aimath_f32_default_min(&example_tensor_f32, &min_value);
aimath_f32_default_max(&example_tensor_f32, &max_value);
aimath_q7_calc_q_params_from_f32(min_value, max_value, &example_q_params_q7); // Calculate the quantization parameters for a tensor with values in the given range
aimath_q7_quantize_tensor_from_f32(&example_tensor_f32, &example_tensor_q7); // Quantize the F32 tensor and write the results to the Q7 tensor
// Print the quantized tensor to the console
print_aitensor(&example_tensor_q7);

Symmetric quantization (zero point = 0):

// Source tensor with 32 bit float values
float example_data_f32[] = {0.0f, 1.0f, 2.0f,
3.0f, 4.0f, 5.0f};
uint16_t example_shape_f32[] = {2, 3};
aitensor_t example_tensor_f32 = AITENSOR_2D_F32(example_shape_f32, example_data_f32);
// Destination tensor to write the 8 bit integer quantized values in
int8_t example_data_q7[2*3];
uint16_t example_shape_q7[] = {2, 3};
aimath_q7_params_t example_q_params_q7;
aitensor_t example_tensor_q7 = AITENSOR_2D_Q7(example_shape_q7, &example_q_params_q7, example_data_q7);
float min_value, max_value;
aimath_f32_default_min(&example_tensor_f32, &min_value);
aimath_f32_default_max(&example_tensor_f32, &max_value);
if(max_value < -min_value){
    max_value = -min_value;
}
aimath_q7_calc_q_params_from_f32(-max_value, max_value, &example_q_params_q7); // Calculate the quantization parameters for a tensor with values in the given range
aimath_q7_quantize_tensor_from_f32(&example_tensor_f32, &example_tensor_q7); // Quantize the F32 tensor and write the results to the Q7 tensor
// Print the quantized tensor to the console
print_aitensor(&example_tensor_q7);

Manual quantization of the neural network

Quantization of the intermediate results

To perform integer-only calculations, the first things we need to quantize in an ANN are the intermediate results of the layers. Every layer needs to know in which range the values of its result lie in order to do the right calculations without any overflow. Therefore we have to perform some inferences on a representative part of the dataset (or on the whole dataset) and record the min and max values of each layer's result tensor. Afterwards we can calculate the quantization parameters (shift and zero point) and assign them to the result tensors.

Example:

Our input dataset is given by the tensor

\[ \left( \begin{array}{rrr} 0 & 0 & 0 \\ 1 & 1 & 1 \\ 1 & 0 & 0 \end{array}\right) \]

The inference / forward pass for the layers gives us the following values:

Dense 1:

\[ \left( \begin{array}{rrr} 0.72655 & 2.67281 & -0.21291 \\ -1.38335 & -4.09500 & 5.13747 \\ 4.37195 & -0.93700 & 1.36340 \end{array}\right) \]

Leaky ReLU:

\[ \left( \begin{array}{rrr} 0.72655 & 2.67281 & -0.00213 \\ -0.01383 & -0.04095 & 5.13747 \\ 4.37195 & -0.00937 & 1.36340 \end{array}\right) \]

Dense 2:

\[ \left( \begin{array}{rrr} 8.00280 & -10.82488 \\ -12.87884 & 11.78870 \\ -8.08759 & -8.56308 \end{array}\right) \]

Sigmoid:

\[ \left( \begin{array}{rrr} 0.99967 & 0.00002 \\ 0.00000 & 0.99999 \\ 0.00031 & 0.00019 \end{array}\right) \]

Now we can calculate the quantization parameters (shift and zero point) from the value range (minimum and maximum values) of the result tensors of the input and dense layers. (The activation layers can calculate their quantization parameters on their own, because they just perform fixed transformations. Check the documentation to see whether a layer needs manual calculation of the parameters or not.) For this we can use the aimath_q7_calc_q_params_from_f32() function. It is recommended to add a small safety margin to the min and max values to deal with unseen data that results in slightly larger values.

The min and max values and the resulting quantization parameters of the intermediate results of the layers are:

layer    min         max        margin  shift  zero point
Input    0           1          10 %    7      -70
Dense 1  -4.095      5.13747    10 %    4      -9
Dense 2  -12.878839  11.788695  10 %    3      5
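
As a quick check for the input layer: with a 10 % margin the covered range is [0, 1.1]. Using a shift of 7 and a zero point of -70, the quantized boundary values are

\[ q_{min} = round(0 \cdot 2^{7}) - 70 = -70 \qquad q_{max} = round(1.1 \cdot 2^{7}) - 70 = 71 \]

so the whole margin-extended range fits into the int8 range [-128, 127].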

Example for quantization parameter calculation with margin:

float min_value = 0.0f;
float max_value = 1.0f;
float margin = 0.1f;
aimath_q7_params_t example_q_params;
aimath_q7_calc_q_params_from_f32(min_value * (1.0f + margin), max_value * (1.0f + margin), &example_q_params);

Quantization of the weights and biases

In our simple example, we have two dense layers with weights and biases. For the integer model, we need to quantize the floating point values (given in the section above) to the integer type. To simplify the calculations of the inference, we use a symmetric quantization (the zero point is zero). We perform an 8 bit integer quantization (Q7) on the weights and a 32 bit integer quantization (Q31) on the bias to get an efficient tradeoff between performance, size and accuracy.

The maximum absolute value of the weight tensors and the resulting quantization parameters are

tensor     max_abs  shift  zero point
Weights 1  3.6454   5      0
Weights 2  3.23528  5      0
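
As a worked check for the first weight tensor: its largest absolute value is 3.6454, and a shift of 5 is the largest one that still keeps it inside the int8 range, since \( 3.6454 \cdot 2^{5} = 116.65 \le 127 \) while a shift of 6 would already overflow. The first weight is therefore stored as \( round(3.64540 \cdot 2^{5}) = 117 \), which is exactly the first entry of the quantized \( w_1 \) below.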

The shift of the (32 bit quantized) bias tensor is the sum of the corresponding weights shift and the result shift of the previous layer, because the bias value is simply added to the internal accumulator without any shift correction. The resulting quantization parameters are

tensor  weights shift of same layer  result shift of previous layer  shift  zero point
Bias 1  5                            7                               12     0
Bias 2  5                            4                               9      0
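
As a worked check for the first bias value: with a total shift of \( 5 + 7 = 12 \), the value 0.72655 is stored as \( round(0.72655 \cdot 2^{12}) = 2976 \), matching the first entry of the quantized \( b_1 \) below.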

Tip: Check the documentation of the layers in the Q7 data type (e.g. ailayer_sigmoid_q7_default()) to find the result shift of the activation layers. In addition, the function ailayer.calc_result_tensor_params can calculate the value.

The resulting integer values of the quantized weights and bias tensors are

\[ w_1 = \left( \begin{array}{rrr} 117 & -116 & 50 \\ -96 & -61 & 98 \\ -89 & -40 & 23 \end{array}\right) \]

\[ b_1 = \left( \begin{array}{rrr} 2976 & 10948 & -872 \end{array}\right) \]

\[ w_2 = \left( \begin{array}{rr} -35 & -78 \\ 104 & -92 \\ -80 & 81 \end{array}\right) \]

\[ b_2 = \left( \begin{array}{rr} 74 & -688 \end{array}\right) \]

Create the neural network in AIfES

AIfES provides implementations of the layers for different data types, which can be optimized for several hardware platforms. An overview of the layers that are available for inference can be found in the overview section of the main page. In the overview table you can click on a layer in the first column for a description of how the layer works. To see how to create the layer in code, choose one of the implementations for your data type and click on the link.
In this tutorial we work with the int8 data type (Q7) and use the default implementations (without any hardware specific optimizations) of the layers.

Used layer types: Input layer, Dense layer, Leaky ReLU layer, Sigmoid layer

Used implementations: ailayer_input_q7_default(), ailayer_dense_q7_default(), ailayer_leaky_relu_q7_default(), ailayer_sigmoid_q7_default()

Manual declaration and configuration of the layers

For every layer we need to create a variable of the specific layer type and configure it for our needs. See the documentation of the data type and hardware specific implementations (for example ailayer_dense_q7_default()) for code examples on how to configure the layers.

Our designed network can be declared with the following code.

// The main model structure that holds the whole neural network
aimodel_t model;
// The layer structures for Q7 data type and their configurations
// Input
uint16_t input_layer_shape[] = {1, 3};
aimath_q7_params_t input_layer_result_q_params = { .shift = 7, .zero_point = -70 };
ailayer_input_q7_t input_layer = AILAYER_INPUT_Q7_M(2, input_layer_shape, &input_layer_result_q_params);
// Dense
// Weights (8 bit quantized)
aimath_q7_params_t dense_layer_1_weights_q_params = { .shift = 5, .zero_point = 0 };
int8_t dense_layer_1_weights[3*3] = {117, -116, 50,
                                     -96, -61, 98,
                                     -89, -40, 23};
// Bias (32 bit quantized)
aimath_q31_params_t dense_layer_1_bias_q_params = { .shift = 12, .zero_point = 0 };
int32_t dense_layer_1_bias[1*3] = {2976, 10948, -872};
// Result (8 bit quantized)
aimath_q7_params_t dense_layer_1_result_q_params = { .shift = 4, .zero_point = -9 };
ailayer_dense_q7_t dense_layer_1 = AILAYER_DENSE_Q7_M(3,
        dense_layer_1_weights, &dense_layer_1_weights_q_params,
        dense_layer_1_bias, &dense_layer_1_bias_q_params,
        &dense_layer_1_result_q_params);
// Leaky ReLU
// Result (8 bit quantized)
ailayer_leaky_relu_q7_t leaky_relu_layer = AILAYER_LEAKY_RELU_Q7_M(AISCALAR_Q7(0.01f,10,0));
// Dense
// Weights (8 bit quantized)
aimath_q7_params_t dense_layer_2_weights_q_params = { .shift = 5, .zero_point = 0 };
int8_t dense_layer_2_weights[3*2] = {-35, -78,
                                     104, -92,
                                     -80, 81};
// Bias (32 bit quantized)
aimath_q31_params_t dense_layer_2_bias_q_params = { .shift = 9, .zero_point = 0 };
int32_t dense_layer_2_bias[1*2] = {74, -688};
// Result (8 bit quantized)
aimath_q7_params_t dense_layer_2_result_q_params = { .shift = 3, .zero_point = 5 };
ailayer_dense_q7_t dense_layer_2 = AILAYER_DENSE_Q7_M(2,
        dense_layer_2_weights, &dense_layer_2_weights_q_params,
        dense_layer_2_bias, &dense_layer_2_bias_q_params,
        &dense_layer_2_result_q_params);
// Sigmoid
// Result (8 bit quantized)
ailayer_sigmoid_q7_t sigmoid_layer = AILAYER_SIGMOID_Q7_M();

We use the initializer macros with the "_M" (for "Manually") at the end, because we need to set our parameters (like the weights) to the layers.

Automatic declaration and configuration of the layers

If you already have the parameters of your model as a flat byte array (for example from quantization in Python using the aifes-pytools or from loading the parameters from a stored buffer created with AIfES), you can use the automatic configuration of the layers. This approach makes it easier to configure the neural network and to update new weight sets. On the other hand, you lose some flexibility regarding where you store your data, and it is harder to distinguish between the different values. Also, the buffer format might change slightly with future updates of AIfES, so make sure that everything still works after updating.

// The main model structure that holds the whole neural network
aimodel_t model;
// The layer structures for Q7 data type and their configurations
uint16_t input_layer_shape[] = {1, 3};
ailayer_input_q7_t input_layer = AILAYER_INPUT_Q7_A(2, input_layer_shape);
ailayer_dense_q7_t dense_layer_1 = AILAYER_DENSE_Q7_A(3);
ailayer_leaky_relu_q7_t leaky_relu_layer = AILAYER_LEAKY_RELU_Q7_A(AISCALAR_Q7(0.01f,10,0));
ailayer_dense_q7_t dense_layer_2 = AILAYER_DENSE_Q7_A(2);
ailayer_sigmoid_q7_t sigmoid_layer = AILAYER_SIGMOID_Q7_A();
// ... do the layer connection and model compilation of the following chapter here ...
// The parameter memory may vary on different architectures because the parameters have to be memory aligned
#define MEMORY_ALIGNMENT 4 // Must have the same size as AIFES_MEMORY_ALIGNMENT
#if MEMORY_ALIGNMENT == 4
// This is an example parameter buffer for a memory alignment of 4 bytes where integers are stored in little-endian byte order
const uint32_t parameter_memory_size = 76;
const uint32_t model_parameters[ 19 ] = {
0x00BA0007, 0x00F70004, 0x00050003, 0x00000005, 0xA0328C75, 0xD8A762C3, 0x00000017, 0x0000000C,
0x00000000, 0x00000BA0, 0x00002AC4, 0xFFFFFC98, 0x00000005, 0xA468B2DD, 0x000051B0, 0x00000009,
0x00000000, 0x0000004A, 0xFFFFFD50
};
#elif MEMORY_ALIGNMENT == 2
// This is an example parameter buffer for a memory alignment of 2 bytes where integers are stored in little-endian byte order
const uint32_t parameter_memory_size = 68;
const uint16_t model_parameters[ 34 ] = {
0x0007, 0x00BA, 0x0004, 0x00F7, 0x0003, 0x0005, 0x0005, 0x0000, 0x8C75, 0xA032, 0x62C3, 0xD8A7, 0x0017, 0x000C, 0x0000, 0x0000,
0x0BA0, 0x0000, 0x2AC4, 0x0000, 0xFC98, 0xFFFF, 0x0005, 0x0000, 0xB2DD, 0xA468, 0x51B0, 0x0009, 0x0000, 0x0000, 0x004A, 0x0000,
0xFD50, 0xFFFF
};
#endif
// Set the model parameters to the layers of the model
aialgo_distribute_parameter_memory(&model, model_parameters, parameter_memory_size);

We use the initializer macros with the "_A" (for "Automatically") at the end, because our parameters (like the weights) will be set by an aifes function to the layers.

Connection and initialization of the layers

Afterwards, the layers are connected and initialized with the data type and hardware specific implementations:

// Layer pointer to perform the connection
ailayer_t *x;
model.input_layer = ailayer_input_q7_default(&input_layer);
x = ailayer_dense_q7_default(&dense_layer_1, model.input_layer);
x = ailayer_leaky_relu_q7_default(&leaky_relu_layer, x);
x = ailayer_dense_q7_default(&dense_layer_2, x);
x = ailayer_sigmoid_q7_default(&sigmoid_layer, x);
model.output_layer = x;
// Finish the model creation by checking the connections and setting some parameters for further processing
aialgo_compile_model(&model);

Print the layer structure to the console

To see the structure of your created model, you can print a model summary to the console

aiprint("\n-------------- Model structure ---------------\n");
aiprint("----------------------------------------------\n\n");
void aialgo_print_model_structure(aimodel_t *model)
Print the layer structure of the model with the configured parameters.

Perform the inference

Allocate and initialize the working memory

Because AIfES doesn't allocate any memory on its own, you have to provide the memory buffers for the inference manually. This memory is required, for example, for the intermediate results of the layers. You are therefore completely free to choose where the buffer is located in memory. To calculate the required amount of memory for the inference, the aialgo_sizeof_inference_memory() function can be used.
With aialgo_schedule_inference_memory() a memory block of the required size can be distributed and scheduled (memory regions might be shared over time) to the model.

A dynamic allocation of the memory using malloc could look like the following:

uint32_t inference_memory_size = aialgo_sizeof_inference_memory(&model);
void *inference_memory = malloc(inference_memory_size);
// Schedule the memory to the model
aialgo_schedule_inference_memory(&model, inference_memory, inference_memory_size);

You could also pre-define a memory buffer if you know the size in advance, for example

const uint32_t inference_memory_size = 16;
char inference_memory[inference_memory_size];
...
// Schedule the memory to the model
aialgo_schedule_inference_memory(&model, inference_memory, inference_memory_size);

Run the inference

To perform the inference, the input data must be packed into a tensor to be processed by AIfES. A tensor in AIfES is just an N-dimensional array that holds the data values in a structured way. For this example, create a 2D tensor of the used data type. The shape describes the size of the dimensions of the tensor. The first dimension (the rows) is the batch dimension, i.e. the dimension of the different input samples. If you process just one sample at a time, this dimension is 1. The second dimension equals the number of inputs of the neural network.

uint16_t in_shape[2] = {1, 3};
aimath_q7_params_t in_q_params = { .shift = 7, .zero_point = -70 }; // Same as configured in the input layer
int8_t in_data[1*3] = {58, -70, -70}; // RGB = [1, 0, 0] quantized with shift 7 and zero point -70
aitensor_t in = AITENSOR_2D_Q7(in_shape, &in_q_params, in_data);

You can also create a float 32 tensor and let it quantize by AIfES:

// F32 tensor that has to be quantized
uint16_t in_f32_shape[2] = {1, 3};
float in_f32_data[1*3] = {1.0f, 0.0f, 0.0f};
aitensor_t in_f32 = AITENSOR_2D_F32(in_f32_shape, in_f32_data);
// Automatically quantized Q7 tensor to feed into the neural network
uint16_t in_shape[] = {1, 3};
aimath_q7_params_t in_q_params = { .shift = 7, .zero_point = -70 };
int8_t in_data[1*3];
aitensor_t in = AITENSOR_2D_Q7(in_shape, &in_q_params, in_data);
aimath_q7_quantize_tensor_from_f32(&in_f32, &in); // Quantize the F32 tensor and write the results to the Q7 tensor

Now everything is ready to perform the actual inference. For this you can use the function aialgo_inference_model().

// Create an empty tensor for the inference results
uint16_t out_shape[2] = {1, 2};
aimath_q7_params_t out_q_params;
int8_t out_data[1*2];
aitensor_t out = AITENSOR_2D_Q7(out_shape, &out_q_params, out_data);
aialgo_inference_model(&model, &in, &out);

Alternatively, you can run the inference without creating an empty tensor for the result, using the function aialgo_forward_model(). The results of this function are stored in the inference memory. If you want to perform another inference or free the inference memory, you have to save the results to another tensor / array first. Otherwise you will lose the data.

aitensor_t *y = aialgo_forward_model(&model, &in);
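
A minimal sketch of saving the result before the inference memory is reused (the names y_data_copy and y_q_params_copy are just illustrative; the output of this example model is a 1x2 Q7 tensor):

#include <string.h> // for memcpy

// Copy the quantized output values and their quantization parameters
// out of the inference memory before it gets overwritten
int8_t y_data_copy[1*2];
aimath_q7_params_t y_q_params_copy = *((aimath_q7_params_t *) y->tensor_params);
memcpy(y_data_copy, y->data, sizeof(y_data_copy));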

Afterwards you can print the results to the console for debugging purposes:

aiprint("input:\n");
aiprint("NN output:\n");

If you want to get F32 values out of the Q7 tensor, you can use the macro Q7_TO_FLOAT(integer_value, shift, zero_point) as explained earlier:

int8_t value_q7 = -127;
aimath_q7_params_t value_q_params = {.shift = 8, .zero_point = -128};
float value_f32 = Q7_TO_FLOAT(value_q7, value_q_params.shift, value_q_params.zero_point);
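
Applied to the inference result from above, the two motor outputs can be converted back to floats like this (a short sketch, assuming the inference has filled in the output quantization parameters of the out tensor):

// Dequantize both outputs of the neural network for interpretation
float left_motor = Q7_TO_FLOAT(out_data[0], out_q_params.shift, out_q_params.zero_point);
float right_motor = Q7_TO_FLOAT(out_data[1], out_q_params.shift, out_q_params.zero_point);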

Automatic quantization in AIfES

If you have a trained F32 model in AIfES and you want to convert it to a Q7 quantized model (for example to save memory needed for the weights or to speed up the inference), you can use the provided helper function aialgo_quantize_model_f32_to_q7(). To use this function, you need the F32 model and a Q7 model skeleton with the same structure as the F32 model. You also need a dataset that is representative of the input data of the model (for example a fraction of, or the whole, training dataset) to calculate the quantization ranges.

Example:

If you have the example model described in the F32 training tutorial or the F32 inference tutorial, you have to build a Q7 model skeleton that looks like this:

// The main model structure that holds the whole neural network
aimodel_t model_q7;
// The layer structures for Q7 data type and their configurations (Same as in the F32 model, indicated with the "_f32" ending)
ailayer_input_q7_t input_layer_q7 = AILAYER_INPUT_Q7_A(2, input_layer_f32.input_shape);
ailayer_dense_q7_t dense_layer_1_q7 = AILAYER_DENSE_Q7_A(dense_layer_1_f32.neurons);
ailayer_leaky_relu_q7_t leaky_relu_layer_q7 = AILAYER_LEAKY_RELU_Q7_A(AISCALAR_Q7(leaky_relu_layer_f32.alpha,10,0));
ailayer_dense_q7_t dense_layer_2_q7 = AILAYER_DENSE_Q7_A(dense_layer_2_f32.neurons);
ailayer_sigmoid_q7_t sigmoid_layer_q7 = AILAYER_SIGMOID_Q7_A();
ailayer_t *x;
model_q7.input_layer = ailayer_input_q7_default(&input_layer_q7);
x = ailayer_dense_q7_default(&dense_layer_1_q7, model_q7.input_layer);
x = ailayer_leaky_relu_q7_default(&leaky_relu_layer_q7, x);
x = ailayer_dense_q7_default(&dense_layer_2_q7, x);
x = ailayer_sigmoid_q7_default(&sigmoid_layer_q7, x);
model_q7.output_layer = x;
// Finish the model creation by checking the connections and setting some parameters for further processing
aialgo_compile_model(&model_q7);

In addition, you need to assign the parameter memory and the inference memory to the model. The parameter memory will hold the quantized parameters (weights, biases, quantization parameters), and the inference memory is needed to perform the quantization.

uint32_t parameter_memory_size_q7 = aialgo_sizeof_parameter_memory(&model_q7);
void *parameter_memory_q7 = malloc(parameter_memory_size_q7);
// Distribute the memory to the trainable parameters of the model
aialgo_distribute_parameter_memory(&model_q7, parameter_memory_q7, parameter_memory_size_q7);
uint32_t inference_memory_size_q7 = aialgo_sizeof_inference_memory(&model_q7);
void *inference_memory_q7 = malloc(inference_memory_size_q7);
// Schedule the memory to the model
aialgo_schedule_inference_memory(&model_q7, inference_memory_q7, inference_memory_size_q7);

Then you need the representative input dataset mentioned above. In this case we simply take the whole training dataset, because it is very small and has no representative subset.

uint16_t x_repr_shape[2] = {3, 3};
float x_repr_data[3*3] = {0.0f, 0.0f, 0.0f,
1.0f, 1.0f, 1.0f,
1.0f, 0.0f, 0.0f};
aitensor_t x_repr = AITENSOR_2D_F32(x_repr_shape, x_repr_data);

Now you can perform the quantization by calling aialgo_quantize_model_f32_to_q7().

aialgo_quantize_model_f32_to_q7(&model_f32, &model_q7, &x_repr);

The Q7 model is now ready to use, for example to run an inference. You can also store the parameter memory buffer in non-volatile memory after the model is quantized and load it again later using the layer declaration described above.
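
A minimal sketch of exporting the quantized parameter buffer so that it can later be pasted into a source file and assigned again with aialgo_distribute_parameter_memory() (assumes a standard printf is available on the target):

#include <stdio.h> // for printf

// Dump the parameter memory as a C array definition
uint8_t *params = (uint8_t *) parameter_memory_q7;
printf("const uint8_t model_parameters[%lu] = {", (unsigned long) parameter_memory_size_q7);
for (uint32_t i = 0; i < parameter_memory_size_q7; i++) {
    printf("%s0x%02X", (i == 0) ? "" : ", ", params[i]);
}
printf("};\n");

When pasting such a byte dump back into your code, make sure the array has the same memory alignment as the original buffer (see AIFES_MEMORY_ALIGNMENT in the automatic declaration example above).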

Automatic quantization in Python

To automatically quantize a model using Python (for example a Keras or PyTorch model), you can use our AIfES Python tools. You can install the tools via pip with:

pip install https://github.com/Fraunhofer-IMS/AIfES_for_Arduino/raw/main/etc/python/aifes_tools.zip

The quantized model can then either be set up manually with the calculated weights and quantization parameters (result_q_params, weights_q_params and weights_q7) or set up automatically by using the buffer printed by the print_flatbuffer_c_style() function.

Example: Quantize a tf.keras model:

import tensorflow as tf
from tensorflow import keras
import numpy as np
from tensorflow.keras import layers
from aifes.tools import quantize_model_q7, Layer, create_flatbuffer_q7, create_flatbuffer_f32, print_flatbuffer_c_style
# ------------------------------------------------- Create and train the model in tf.keras --------------------------------------------------------
model = keras.Sequential()
model.add(keras.Input(shape=(3,)))
model.add(layers.Dense(3, activation="leaky_relu"))
model.add(layers.Dense(2, activation="sigmoid"))
optimizer = keras.optimizers.Adam(lr=0.1)
model.compile(optimizer=optimizer, loss="binary_crossentropy")
model.summary()
X = np.array([[0., 0., 0.],
[1., 1., 1.],
[1., 0., 0.]])
T = np.array([[1., 0.],
[0., 1.],
[0., 0.]])
model.fit(X, T, batch_size=4, epochs=5)
'''
# You may set the weights manually instead of training the model.
w1 = np.array([3.64540, -3.60981, 1.57631,
-2.98952, -1.91465, 3.06150,
-2.76578, -1.24335, 0.71257]).reshape(3, 3)
b1 = np.array([0.72655, 2.67281, -0.21291])
w2 = np.array([-1.09249, -2.44526,
3.23528, -2.88023,
-2.51201, 2.52683]).reshape(3, 2)
b2 = np.array([0.14391, -1.34459])
weights = [w1, b1, w2, b2]
model.set_weights(weights)
'''
# ------------------------------------------------- Convert and quantize the model to a flatbuffer --------------------------------------------------------
# Representation of the model for AIfES pytools
layers = [
Layer.DENSE_WT, # Dense / Fully connected layer with transposed weights (WT)
Layer.LEAKY_RELU, # Leaky ReLU layer. Add the alpha parameter to the act_params list
Layer.DENSE_WT, # Dense / Fully connected layer with transposed weights (WT)
Layer.SIGMOID # Sigmoid layer
]
act_params = [0.01] # Append additional parameters for the activation functions here (e.g. alpha value for Leaky ReLU)
weights = model.get_weights()
# Platform specific settings
ALIGNMENT = 4 # For example ALIGNMENT = 4 on ARM Cortex-M or ESP32 controllers and ALIGNMENT = 2 on AVR controllers
BYTEORDER = 'little' # 'little' for little-endian or 'big' for big-endian representation of the target system
result_q_params, weights_q_params, weights_q7 = quantize_model_q7(layers, weights, X, act_params=act_params)
# Print required parameters for a Q7 model to console
print()
print("Layer result quantization parameters (shift, zero point):")
print(result_q_params)
print("Weight and bias quantization parameters (shift, zero point):")
print(weights_q_params)
print("Q7 weights:")
print(weights_q7)
print()
flatbuffer_q7 = create_flatbuffer_q7(result_q_params, weights_q_params, weights_q7, target_alignment=ALIGNMENT, byteorder=BYTEORDER)
flatbuffer_f32 = create_flatbuffer_f32(weights)
# Print the parameter memory buffer to the console
print("\nQ7:")
print_flatbuffer_c_style(flatbuffer_q7, elements_per_line=12, target_alignment=ALIGNMENT, byteorder=BYTEORDER, mutable=True)
print("\nF32:")
print_flatbuffer_c_style(flatbuffer_f32, elements_per_line=8)

Example: Quantize a PyTorch model

import numpy as np
import torch
from torch import nn
from aifes.tools import quantize_model_q7, Layer, create_flatbuffer_q7, create_flatbuffer_f32, print_flatbuffer_c_style
# ------------------------------------------------- Create and train the model in PyTorch --------------------------------------------------------
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.dense_layer_1 = nn.Linear(3, 3)
        self.leaky_relu_layer = nn.LeakyReLU(0.01)
        self.dense_layer_2 = nn.Linear(3, 2)
        self.sigmoid_layer = nn.Sigmoid()

    def forward(self, x):
        x = self.dense_layer_1(x)
        x = self.leaky_relu_layer(x)
        x = self.dense_layer_2(x)
        x = self.sigmoid_layer(x)
        return x
X = np.array([[0., 0., 0.],
[1., 1., 1.],
[1., 0., 0.]])
Y = np.array([[1., 0.],
[0., 1.],
[0., 0.]])
X_tensor = torch.FloatTensor(X)
Y_tensor = torch.FloatTensor(Y)
model = Net()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
epochs = 200
model.train()
for epoch in range(epochs):
    optimizer.zero_grad()
    pred = model(X_tensor)
    loss = criterion(pred, Y_tensor)
    loss.backward()
    optimizer.step()
# ------------------------------------------------- Convert and quantize the model to a flatbuffer --------------------------------------------------------
# Representation of the model for AIfES pytools
layers = [
Layer.DENSE_WT, # Dense / Fully connected layer with transposed weights (WT)
Layer.LEAKY_RELU, # Leaky ReLU layer. Add the alpha parameter to the act_params list
Layer.DENSE_WT, # Dense / Fully connected layer with transposed weights (WT)
Layer.SIGMOID # Sigmoid layer
]
act_params = [0.01] # Append additional parameters for the activation functions here (e.g. alpha value for Leaky ReLU)
weights = [param.detach().numpy().T for param in model.parameters()]
# Platform specific settings
ALIGNMENT = 4 # For example ALIGNMENT = 4 on ARM Cortex-M or ESP32 controllers and ALIGNMENT = 2 on AVR controllers (must equal AIFES_MEMORY_ALIGNMENT)
BYTEORDER = 'little' # 'little' for little-endian or 'big' for big-endian representation of the target system
result_q_params, weights_q_params, weights_q7 = quantize_model_q7(layers, weights, X, act_params=act_params)
# Print required parameters for a Q7 model to console
print()
print("Layer result quantization parameters (shift, zero point):")
print(result_q_params)
print("Weight and bias quantization parameters (shift, zero point):")
print(weights_q_params)
print("Q7 weights:")
print(weights_q7)
print()
flatbuffer_q7 = create_flatbuffer_q7(result_q_params, weights_q_params, weights_q7, target_alignment=ALIGNMENT, byteorder=BYTEORDER)
flatbuffer_f32 = create_flatbuffer_f32(weights)
# Print the parameter memory buffer to the console
print("\nQ7:")
print_flatbuffer_c_style(flatbuffer_q7, elements_per_line=10, target_alignment=ALIGNMENT, byteorder=BYTEORDER, mutable=True)
print("\nF32:")
print_flatbuffer_c_style(flatbuffer_f32, elements_per_line=8)