amduat/tests/enc/test_pel_program_dag.c

114 lines
3.4 KiB
C
Raw Normal View History

2025-12-20 13:54:18 +01:00
#include "amduat/enc/pel_program_dag.h"
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Golden encoding of the two-node program assembled in
 * test_program_encoding(); the encoder output is compared against it byte
 * for byte.
 *
 * The per-field labels below were inferred by lining the bytes up with the
 * values that test assigns (big-endian integers, node 1 listed before
 * node 2). They are not taken from a format specification -- TODO confirm
 * against the encoder before relying on them.
 */
static const uint8_t k_expected_program_bytes[] = {
    /* header */
    0x00, 0x01,                   /* leading 16-bit value: 1 */
    0x00, 0x00, 0x00, 0x02,       /* node count: 2 */

    /* node with id 1 */
    0x00, 0x00, 0x00, 0x01,       /* node id */
    0x00, 0x00, 0x00, 0x05,       /* op name length */
    0x61, 0x64, 0x64, 0x36, 0x34, /* "add64" */
    0x00, 0x00, 0x00, 0x01,       /* op version */
    0x00, 0x00, 0x00, 0x02,       /* input count */
    0x00, 0x00, 0x00, 0x00, 0x00, /* input 0: kind 0, index 0 */
    0x00, 0x00, 0x00, 0x00, 0x01, /* input 1: kind 0, index 1 */
    0x00, 0x00, 0x00, 0x00,       /* params length */

    /* node with id 2 */
    0x00, 0x00, 0x00, 0x02,       /* node id */
    0x00, 0x00, 0x00, 0x05,       /* op name length */
    0x6d, 0x75, 0x6c, 0x36, 0x34, /* "mul64" */
    0x00, 0x00, 0x00, 0x01,       /* op version */
    0x00, 0x00, 0x00, 0x02,       /* input count */
    0x01,                         /* input 0: kind 1 ... */
    0x00, 0x00, 0x00, 0x01,       /* ... node id 1 ... */
    0x00, 0x00, 0x00, 0x00,       /* ... output index 0 */
    0x00, 0x00, 0x00, 0x00, 0x02, /* input 1: kind 0, index 2 */
    0x00, 0x00, 0x00, 0x00,       /* params length */

    /* roots */
    0x00, 0x00, 0x00, 0x01,       /* root count */
    0x00, 0x00, 0x00, 0x02,       /* root node id */
    0x00, 0x00, 0x00, 0x00,       /* root output index */
};
/*
 * Reports whether `bytes` holds exactly the `expected_len` bytes found at
 * `expected`.
 *
 * Buffers of different length never match. Two zero-length buffers always
 * match, and in that case memcmp is not called at all, so neither pointer is
 * dereferenced.
 */
static bool bytes_equal(amduat_octets_t bytes,
                        const uint8_t *expected,
                        size_t expected_len) {
  if (bytes.len != expected_len) {
    return false;
  }
  /* Lengths are equal from here on, so testing expected_len is the same as
   * testing bytes.len. */
  return expected_len == 0 ||
         memcmp(bytes.data, expected, expected_len) == 0;
}
/*
 * Round-trip test for the v1 PEL program DAG encoding.
 *
 * Builds a two-node program:
 *   - node 1, op "add64" v1, fed by external inputs 0 and 1;
 *   - node 2, op "mul64" v1, fed by output 0 of node 1 and external input 2;
 *   - a single root referring to output 0 of node 2.
 * The program is encoded, the result is compared with
 * k_expected_program_bytes, and the bytes are decoded again so the decoded
 * shape can be checked.
 *
 * The nodes are passed to the encoder with id 2 first, while both the
 * expected bytes and the decoded-order check require node 1 to come first,
 * so the test also exercises the node ordering of the encoded form.
 *
 * Returns 0 on success and 1 on any failure; each failure prints a one-line
 * reason to stderr.
 */
static int test_program_encoding(void) {
  /* Zero-initialize every struct so that any field this test does not assign
   * explicitly has a defined value rather than indeterminate stack contents. */
  amduat_pel_dag_input_t node1_inputs[2] = {0};
  amduat_pel_dag_input_t node2_inputs[2] = {0};
  amduat_pel_node_t nodes[2] = {0};
  amduat_pel_root_ref_t roots[1] = {0};
  amduat_pel_program_t program = {0};
  amduat_octets_t encoded = {0};
  amduat_pel_program_t decoded = {0};
  const char add_name[] = "add64";
  const char mul_name[] = "mul64";
  int exit_code = 1;

  /* Node 1 consumes external inputs 0 and 1. */
  node1_inputs[0].kind = AMDUAT_PEL_DAG_INPUT_EXTERNAL;
  node1_inputs[0].value.external.input_index = 0;
  node1_inputs[1].kind = AMDUAT_PEL_DAG_INPUT_EXTERNAL;
  node1_inputs[1].value.external.input_index = 1;

  /* Node 2 consumes output 0 of node 1 plus external input 2. */
  node2_inputs[0].kind = AMDUAT_PEL_DAG_INPUT_NODE;
  node2_inputs[0].value.node.node_id = 1;
  node2_inputs[0].value.node.output_index = 0;
  node2_inputs[1].kind = AMDUAT_PEL_DAG_INPUT_EXTERNAL;
  node2_inputs[1].value.external.input_index = 2;

  /* Node 2 is stored ahead of node 1 in the input array. */
  nodes[0].id = 2;
  nodes[0].op.name = amduat_octets(mul_name, strlen(mul_name));
  nodes[0].op.version = 1;
  nodes[0].inputs = node2_inputs;
  nodes[0].inputs_len = sizeof node2_inputs / sizeof node2_inputs[0];
  nodes[0].params = amduat_octets(NULL, 0);

  nodes[1].id = 1;
  nodes[1].op.name = amduat_octets(add_name, strlen(add_name));
  nodes[1].op.version = 1;
  nodes[1].inputs = node1_inputs;
  nodes[1].inputs_len = sizeof node1_inputs / sizeof node1_inputs[0];
  nodes[1].params = amduat_octets(NULL, 0);

  roots[0].node_id = 2;
  roots[0].output_index = 0;

  program.nodes = nodes;
  program.nodes_len = sizeof nodes / sizeof nodes[0];
  program.roots = roots;
  program.roots_len = sizeof roots / sizeof roots[0];

  if (!amduat_enc_pel_program_dag_encode_v1(&program, &encoded)) {
    fprintf(stderr, "encode failed\n");
    /* Return directly: nothing is released here after a failed encode. */
    return exit_code;
  }

  if (!bytes_equal(encoded, k_expected_program_bytes,
                   sizeof(k_expected_program_bytes))) {
    fprintf(stderr, "encoded bytes mismatch\n");
    goto cleanup;
  }

  if (!amduat_enc_pel_program_dag_decode_v1(encoded, &decoded)) {
    fprintf(stderr, "decode failed\n");
    /* Skip cleanup_decoded: `decoded` is only released after a successful
     * decode. */
    goto cleanup;
  }

  /* Lengths are checked first so the indexed reads below stay in bounds. */
  if (decoded.nodes_len != 2 || decoded.roots_len != 1) {
    fprintf(stderr, "decoded lengths mismatch\n");
    goto cleanup_decoded;
  }

  if (decoded.nodes[0].id != 1 || decoded.nodes[1].id != 2) {
    fprintf(stderr, "decoded node order mismatch\n");
    goto cleanup_decoded;
  }

  /* Both nodes were built with two inputs each. */
  if (decoded.nodes[0].inputs_len != 2 || decoded.nodes[1].inputs_len != 2) {
    fprintf(stderr, "decoded input counts mismatch\n");
    goto cleanup_decoded;
  }

  /* The single root must still point at output 0 of node 2. */
  if (decoded.roots[0].node_id != 2 || decoded.roots[0].output_index != 0) {
    fprintf(stderr, "decoded root mismatch\n");
    goto cleanup_decoded;
  }

  exit_code = 0;

cleanup_decoded:
  amduat_enc_pel_program_dag_free(&decoded);
cleanup:
  /* The encoded buffer is released with free() by this test. */
  free((void *)encoded.data);
  return exit_code;
}
/*
 * Test entry point: runs the single encoding test and uses its result
 * (0 on success, 1 on failure) as the process exit status.
 */
int main(void) {
  const int status = test_program_encoding();
  return status;
}