/*
 * Golden-vector test for the PEL program DAG v1 encoder/decoder.
 *
 * Builds a two-node program (add64 feeding mul64), checks its encoding
 * against a fixed byte vector, decodes it, and verifies that the decoded
 * program has the expected node order and re-encodes to the same bytes.
 */
#include "amduat/enc/pel_program_dag.h"

#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Element count of a true array (not valid on pointers). */
#define TEST_ARRAY_LEN(a) (sizeof(a) / sizeof((a)[0]))

/*
 * Expected v1 encoding of the program built in test_program_encoding().
 *
 * NOTE(review): the field annotations below were derived by matching these
 * bytes against the values the test sets; they are not taken from a format
 * specification. Multi-byte values appear big-endian. Node 1 is encoded
 * before node 2 even though the test lists node 2 first.
 */
static const uint8_t k_expected_program_bytes[] = {
    0x00, 0x01,                   /* presumably a format/version tag -- confirm */
    0x00, 0x00, 0x00, 0x02,       /* node count = 2 */

    /* node id 1: add64 */
    0x00, 0x00, 0x00, 0x01,       /* node id = 1 */
    0x00, 0x00, 0x00, 0x05,       /* op name length = 5 */
    0x61, 0x64, 0x64, 0x36, 0x34, /* "add64" */
    0x00, 0x00, 0x00, 0x01,       /* op version = 1 */
    0x00, 0x00, 0x00, 0x02,       /* input count = 2 */
    0x00,                         /* input kind tag (external) */
    0x00, 0x00, 0x00, 0x00,       /*   external input index = 0 */
    0x00,                         /* input kind tag (external) */
    0x00, 0x00, 0x00, 0x01,       /*   external input index = 1 */
    0x00, 0x00, 0x00, 0x00,       /* params length = 0 */

    /* node id 2: mul64 */
    0x00, 0x00, 0x00, 0x02,       /* node id = 2 */
    0x00, 0x00, 0x00, 0x05,       /* op name length = 5 */
    0x6d, 0x75, 0x6c, 0x36, 0x34, /* "mul64" */
    0x00, 0x00, 0x00, 0x01,       /* op version = 1 */
    0x00, 0x00, 0x00, 0x02,       /* input count = 2 */
    0x01,                         /* input kind tag (node) */
    0x00, 0x00, 0x00, 0x01,       /*   node id = 1 */
    0x00, 0x00, 0x00, 0x00,       /*   output index = 0 */
    0x00,                         /* input kind tag (external) */
    0x00, 0x00, 0x00, 0x02,       /*   external input index = 2 */
    0x00, 0x00, 0x00, 0x00,       /* params length = 0 */

    /* roots */
    0x00, 0x00, 0x00, 0x01,       /* root count = 1 */
    0x00, 0x00, 0x00, 0x02,       /* root node id = 2 */
    0x00, 0x00, 0x00, 0x00,       /* root output index = 0 */
};

/*
 * Returns true when `bytes` holds exactly `expected_len` bytes equal to
 * `expected`.
 */
static bool bytes_equal(amduat_octets_t bytes, const uint8_t *expected,
                        size_t expected_len) {
  if (bytes.len != expected_len) {
    return false;
  }
  /* Return before memcmp: passing a NULL pointer to memcmp is undefined
   * even for length 0, and an empty buffer may carry a NULL data pointer. */
  if (bytes.len == 0) {
    return true;
  }
  return memcmp(bytes.data, expected, expected_len) == 0;
}

/*
 * Encodes a fixed two-node program and checks it against
 * k_expected_program_bytes, then decodes the result and checks the decoded
 * program.
 *
 * Returns 0 when every check passes, 1 on the first failure (a diagnostic
 * is written to stderr). All buffers obtained from the codec are released
 * on every path.
 */
static int test_program_encoding(void) {
  const char add_name[] = "add64";
  const char mul_name[] = "mul64";
  int exit_code = 1;

  /* Designated initializers zero any members not named here, so the encoder
   * never reads indeterminate values. */

  /* Node 1 (add64): external inputs 0 and 1. */
  amduat_pel_dag_input_t node1_inputs[] = {
      {.kind = AMDUAT_PEL_DAG_INPUT_EXTERNAL, .value.external.input_index = 0},
      {.kind = AMDUAT_PEL_DAG_INPUT_EXTERNAL, .value.external.input_index = 1},
  };
  /* Node 2 (mul64): output 0 of node 1, then external input 2. */
  amduat_pel_dag_input_t node2_inputs[] = {
      {.kind = AMDUAT_PEL_DAG_INPUT_NODE,
       .value.node.node_id = 1,
       .value.node.output_index = 0},
      {.kind = AMDUAT_PEL_DAG_INPUT_EXTERNAL, .value.external.input_index = 2},
  };

  /* Node 2 is deliberately listed first; both the expected bytes and the
   * decode check below have node 1 first. */
  amduat_pel_node_t nodes[] = {
      {.id = 2,
       .op.name = amduat_octets(mul_name, strlen(mul_name)),
       .op.version = 1,
       .inputs = node2_inputs,
       .inputs_len = TEST_ARRAY_LEN(node2_inputs),
       .params = amduat_octets(NULL, 0)},
      {.id = 1,
       .op.name = amduat_octets(add_name, strlen(add_name)),
       .op.version = 1,
       .inputs = node1_inputs,
       .inputs_len = TEST_ARRAY_LEN(node1_inputs),
       .params = amduat_octets(NULL, 0)},
  };
  amduat_pel_root_ref_t roots[] = {
      {.node_id = 2, .output_index = 0},
  };
  amduat_pel_program_t program = {
      .nodes = nodes,
      .nodes_len = TEST_ARRAY_LEN(nodes),
      .roots = roots,
      .roots_len = TEST_ARRAY_LEN(roots),
  };
  amduat_octets_t encoded;
  amduat_octets_t reencoded;
  amduat_pel_program_t decoded;

  if (!amduat_enc_pel_program_dag_encode_v1(&program, &encoded)) {
    fprintf(stderr, "encode failed\n");
    return exit_code;
  }

  if (!bytes_equal(encoded, k_expected_program_bytes,
                   sizeof(k_expected_program_bytes))) {
    fprintf(stderr, "encoded bytes mismatch\n");
    goto cleanup;
  }

  if (!amduat_enc_pel_program_dag_decode_v1(encoded, &decoded)) {
    fprintf(stderr, "decode failed\n");
    goto cleanup;
  }

  if (decoded.nodes_len != 2 || decoded.roots_len != 1) {
    fprintf(stderr, "decoded lengths mismatch\n");
    goto cleanup_decoded;
  }

  if (decoded.nodes[0].id != 1 || decoded.nodes[1].id != 2) {
    fprintf(stderr, "decoded node order mismatch\n");
    goto cleanup_decoded;
  }

  /* Re-encoding the decoded program must reproduce the golden bytes. This
   * covers the decoded op names, versions, inputs, params and roots, not
   * just the node ids checked above. */
  if (!amduat_enc_pel_program_dag_encode_v1(&decoded, &reencoded)) {
    fprintf(stderr, "re-encode failed\n");
    goto cleanup_decoded;
  }

  if (!bytes_equal(reencoded, k_expected_program_bytes,
                   sizeof(k_expected_program_bytes))) {
    fprintf(stderr, "re-encoded bytes mismatch\n");
    goto cleanup_reencoded;
  }

  exit_code = 0;

cleanup_reencoded:
  free((void *)reencoded.data);
cleanup_decoded:
  amduat_enc_pel_program_dag_free(&decoded);
cleanup:
  free((void *)encoded.data);
  return exit_code;
}

int main(void) { return test_program_encoding(); }