amduat/tests/enc/test_pel_trace_dag.c

238 lines
8 KiB
C
Raw Normal View History

2025-12-20 13:54:18 +01:00
#include "amduat/enc/pel_trace_dag.h"
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/*
 * Golden encoding of the trace built in test_trace_encoding(); the encoder
 * output is compared against it byte for byte. Multi-byte integers in it are
 * big-endian (e.g. the reference length 0x00000022 and the input count
 * 0x00000003).
 *
 * NOTE(review): the field mapping below was read off the bytes together with
 * the values set in test_trace_encoding(); confirm against the PEL trace DAG
 * encoding spec before relying on it.
 *
 *   u16  0x0001              presumably pel1_version (= 1)
 *   ref  scheme_ref          digest = 32 x 0x01
 *   ref  program_ref         digest = 32 x 0x02
 *   6 x 0x00                 presumably status, summary.kind and
 *                            summary.status_code (all zero in this test)
 *   u8   0x01                has_exec_result_ref
 *   ref  exec_result_ref     digest = 32 x 0x03
 *   u32  3                   input_refs_len
 *   ref  x 3                 digests = 32 x 0x10, 0x11, 0x12
 *   u8   0x00                has_params_ref (no params ref)
 *   u32  2                   node_traces_len
 *   then, per node trace:
 *     u32 node_id (1, then 2)
 *     u32 op_name length (5) followed by "add64" / "mul64"
 *     u32 op_version (1)
 *     5 x 0x00               presumably u8 status + u32 status_code
 *     u32 output_refs_len (1)
 *     ref output             digest = 32 x 0x20 / 32 x 0x21
 *     u32 0                  presumably diagnostics_len
 *
 * Each "ref" above is a u32 length (0x22 = 34) followed by a u16 hash id
 * (0x0001) and the 32 digest bytes.
 */
static const uint8_t k_expected_trace_bytes[] = {
0x00, 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01,
0x01, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x01, 0x02, 0x02,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02,
0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x01, 0x03, 0x03, 0x03, 0x03, 0x03,
0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03,
0x03, 0x03, 0x03, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00, 0x22, 0x00,
0x01, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10,
0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x10, 0x00, 0x00, 0x00,
0x22, 0x00, 0x01, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11,
0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x00,
0x00, 0x00, 0x22, 0x00, 0x01, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12, 0x12,
0x12, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00,
0x00, 0x05, 0x61, 0x64, 0x64, 0x36, 0x34, 0x00, 0x00, 0x00, 0x01, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x22,
0x00, 0x01, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20,
0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x20, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x05, 0x6d, 0x75,
0x6c, 0x36, 0x34, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x22, 0x00, 0x01, 0x21, 0x21,
0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21,
0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x21,
0x21, 0x21, 0x21, 0x21, 0x21, 0x21, 0x00, 0x00, 0x00, 0x00,
};
/*
 * Sets every byte of a 32-byte digest buffer to `value`.
 * `out` must point at no fewer than 32 writable bytes.
 */
static void fill_digest(uint8_t *out, uint8_t value) {
  uint8_t *const stop = out + 32;

  while (out != stop) {
    *out++ = value;
  }
}
/*
 * Builds a reference with hash id 0x0001 whose 32-byte digest is `value`
 * repeated. The digest bytes are written into caller-provided `storage`
 * (at least 32 bytes). NOTE(review): the reference presumably borrows
 * `storage` rather than copying it, so keep the buffer alive while the
 * reference is in use.
 */
static amduat_reference_t make_ref(uint8_t value, uint8_t *storage) {
  enum { k_digest_len = 32 };

  fill_digest(storage, value);
  return amduat_reference(0x0001, amduat_octets(storage, k_digest_len));
}
/*
 * Builds a reference with an arbitrary `hash_id` and a digest of `len`
 * bytes, each set to `value`. The digest bytes are written into
 * caller-provided `storage`, which must hold at least `len` bytes.
 */
static amduat_reference_t make_ref_custom(amduat_hash_id_t hash_id,
                                          uint8_t *storage,
                                          size_t len,
                                          uint8_t value) {
  size_t idx;

  for (idx = 0; idx < len; ++idx) {
    storage[idx] = value;
  }
  return amduat_reference(hash_id, amduat_octets(storage, len));
}
2025-12-20 13:54:18 +01:00
/*
 * Reports whether `bytes` holds exactly the `expected_len` bytes at
 * `expected`. Two empty sequences compare equal without touching either
 * data pointer.
 */
static bool bytes_equal(amduat_octets_t bytes,
                        const uint8_t *expected,
                        size_t expected_len) {
  if (bytes.len == expected_len) {
    return expected_len == 0 ||
           memcmp(bytes.data, expected, expected_len) == 0;
  }
  return false;
}
/*
 * Returns true when references `got` and `want` carry the same hash id and
 * the same digest bytes.
 */
static bool trace_ref_equal(amduat_reference_t got, amduat_reference_t want) {
  if (got.hash_id != want.hash_id || got.digest.len != want.digest.len) {
    return false;
  }
  if (got.digest.len == 0) {
    return true;
  }
  return memcmp(got.digest.data, want.digest.data, got.digest.len) == 0;
}

/* Returns true when the first `len` references of `got` and `want` pair up equal. */
static bool trace_ref_arrays_equal(const amduat_reference_t *got,
                                   const amduat_reference_t *want,
                                   size_t len) {
  size_t i;

  for (i = 0; i < len; ++i) {
    if (!trace_ref_equal(got[i], want[i])) {
      return false;
    }
  }
  return true;
}

/*
 * Returns true when node trace `got` matches `want` in every field this test
 * sets. Diagnostics are compared by count only (zero in this test).
 */
static bool trace_node_equal(const amduat_pel_node_trace_dag_t *got,
                             const amduat_pel_node_trace_dag_t *want) {
  if (got->node_id != want->node_id ||
      got->op_version != want->op_version ||
      got->status != want->status ||
      got->status_code != want->status_code ||
      got->output_refs_len != want->output_refs_len ||
      got->diagnostics_len != want->diagnostics_len ||
      got->op_name.len != want->op_name.len) {
    return false;
  }
  if (want->op_name.len != 0 &&
      memcmp(got->op_name.data, want->op_name.data, want->op_name.len) != 0) {
    return false;
  }
  return trace_ref_arrays_equal(got->output_refs, want->output_refs,
                                want->output_refs_len);
}

/*
 * Returns true when decoded trace `got` reproduces `want` in every field this
 * test sets. The params ref is compared only via its presence flag, because
 * the test encodes no params ref.
 */
static bool trace_value_equal(const amduat_pel_trace_dag_value_t *got,
                              const amduat_pel_trace_dag_value_t *want) {
  size_t i;

  if (got->pel1_version != want->pel1_version ||
      got->status != want->status ||
      got->summary.kind != want->summary.kind ||
      got->summary.status_code != want->summary.status_code ||
      got->has_exec_result_ref != want->has_exec_result_ref ||
      got->has_params_ref != want->has_params_ref ||
      got->input_refs_len != want->input_refs_len ||
      got->node_traces_len != want->node_traces_len) {
    return false;
  }
  if (!trace_ref_equal(got->scheme_ref, want->scheme_ref) ||
      !trace_ref_equal(got->program_ref, want->program_ref)) {
    return false;
  }
  if (want->has_exec_result_ref &&
      !trace_ref_equal(got->exec_result_ref, want->exec_result_ref)) {
    return false;
  }
  if (!trace_ref_arrays_equal(got->input_refs, want->input_refs,
                              want->input_refs_len)) {
    return false;
  }
  for (i = 0; i < want->node_traces_len; ++i) {
    if (!trace_node_equal(&got->node_traces[i], &want->node_traces[i])) {
      return false;
    }
  }
  return true;
}

/*
 * Builds a trace with three input refs and two node traces ("add64" and
 * "mul64"). It then:
 *   1. encodes it and requires the output to equal k_expected_trace_bytes
 *      byte for byte;
 *   2. decodes those bytes and requires the decoded value to match the
 *      original in every field set here (a real round-trip check, not just
 *      array lengths).
 *
 * Returns 0 on success, 1 on failure (after printing a diagnostic).
 */
static int test_trace_encoding(void) {
  amduat_pel_trace_dag_value_t trace;
  amduat_pel_node_trace_dag_t nodes[2];
  amduat_reference_t input_refs[3];
  amduat_reference_t output_refs0[1];
  amduat_reference_t output_refs1[1];
  amduat_octets_t encoded;
  amduat_pel_trace_dag_value_t decoded;
  uint8_t s[32], p[32], r[32], i0[32], i1[32], i2[32], o0[32], o1[32];
  const char add_name[] = "add64";
  const char mul_name[] = "mul64";
  int exit_code = 1;

  memset(&trace, 0, sizeof(trace));
  /* Zero the node array as well, so struct members not assigned below hold
     zero rather than indeterminate stack contents when the encoder reads them. */
  memset(nodes, 0, sizeof(nodes));

  trace.pel1_version = 1;
  trace.scheme_ref = make_ref(0x01, s);
  trace.program_ref = make_ref(0x02, p);
  trace.status = AMDUAT_PEL_EXEC_STATUS_OK;
  trace.summary.kind = AMDUAT_PEL_EXEC_ERROR_NONE;
  trace.summary.status_code = 0;
  trace.has_exec_result_ref = true;
  trace.exec_result_ref = make_ref(0x03, r);
  input_refs[0] = make_ref(0x10, i0);
  input_refs[1] = make_ref(0x11, i1);
  input_refs[2] = make_ref(0x12, i2);
  trace.input_refs = input_refs;
  trace.input_refs_len = 3;
  trace.has_params_ref = false;

  output_refs0[0] = make_ref(0x20, o0);
  output_refs1[0] = make_ref(0x21, o1);

  nodes[0].node_id = 1;
  nodes[0].op_name = amduat_octets(add_name, strlen(add_name));
  nodes[0].op_version = 1;
  nodes[0].status = AMDUAT_PEL_NODE_TRACE_OK;
  nodes[0].status_code = 0;
  nodes[0].output_refs = output_refs0;
  nodes[0].output_refs_len = 1;
  nodes[0].diagnostics = NULL;
  nodes[0].diagnostics_len = 0;

  nodes[1].node_id = 2;
  nodes[1].op_name = amduat_octets(mul_name, strlen(mul_name));
  nodes[1].op_version = 1;
  nodes[1].status = AMDUAT_PEL_NODE_TRACE_OK;
  nodes[1].status_code = 0;
  nodes[1].output_refs = output_refs1;
  nodes[1].output_refs_len = 1;
  nodes[1].diagnostics = NULL;
  nodes[1].diagnostics_len = 0;

  trace.node_traces = nodes;
  trace.node_traces_len = 2;

  if (!amduat_enc_pel_trace_dag_encode_v1(&trace, &encoded)) {
    fprintf(stderr, "encode failed\n");
    return exit_code;
  }
  if (!bytes_equal(encoded, k_expected_trace_bytes,
                   sizeof(k_expected_trace_bytes))) {
    fprintf(stderr, "encoded bytes mismatch\n");
    goto cleanup;
  }
  if (!amduat_enc_pel_trace_dag_decode_v1(encoded, &decoded)) {
    fprintf(stderr, "decode failed\n");
    goto cleanup;
  }
  if (decoded.node_traces_len != 2 || decoded.input_refs_len != 3) {
    fprintf(stderr, "decoded lengths mismatch\n");
    goto cleanup_decoded;
  }
  if (!trace_value_equal(&decoded, &trace)) {
    fprintf(stderr, "decoded trace mismatch\n");
    goto cleanup_decoded;
  }
  exit_code = 0;

cleanup_decoded:
  amduat_enc_pel_trace_dag_free(&decoded);
cleanup:
  /* The encoder's output buffer is released with free() here. */
  free((void *)encoded.data);
  return exit_code;
}
/*
 * Encodes a trace with a single input reference that uses hash id 0x1234
 * and a 5-byte digest (0x5a repeated). Per the test name, 0x1234 is
 * presumably not a registered hash id. The test then decodes the result and
 * checks that this reference comes back with the same hash id, length and
 * digest bytes.
 *
 * Returns 0 on success, 1 on failure (after printing a diagnostic).
 */
static int test_trace_unknown_hash_id(void) {
  amduat_pel_trace_dag_value_t trace;
  amduat_pel_trace_dag_value_t decoded;
  amduat_pel_node_trace_dag_t node;
  amduat_reference_t inputs[1];
  amduat_octets_t encoded;
  uint8_t scheme_digest[32];
  uint8_t program_digest[32];
  uint8_t result_digest[32];
  uint8_t odd_digest[5];
  const char noop_name[] = "noop";
  int rc;

  memset(&trace, 0, sizeof(trace));
  memset(&node, 0, sizeof(node));

  trace.pel1_version = 1;
  trace.scheme_ref = make_ref(0x01, scheme_digest);
  trace.program_ref = make_ref(0x02, program_digest);
  trace.status = AMDUAT_PEL_EXEC_STATUS_OK;
  trace.summary.kind = AMDUAT_PEL_EXEC_ERROR_NONE;
  trace.summary.status_code = 0;
  trace.has_exec_result_ref = true;
  trace.exec_result_ref = make_ref(0x03, result_digest);

  /* The reference under test: non-default hash id and non-32-byte digest. */
  inputs[0] = make_ref_custom(0x1234, odd_digest, sizeof(odd_digest), 0x5a);
  trace.input_refs = inputs;
  trace.input_refs_len = 1;
  trace.has_params_ref = false;

  node.node_id = 0;
  node.op_name = amduat_octets(noop_name, strlen(noop_name));
  node.op_version = 1;
  node.status = AMDUAT_PEL_NODE_TRACE_OK;
  node.status_code = 0;
  node.output_refs = NULL;
  node.output_refs_len = 0;
  node.diagnostics = NULL;
  node.diagnostics_len = 0;
  trace.node_traces = &node;
  trace.node_traces_len = 1;

  if (!amduat_enc_pel_trace_dag_encode_v1(&trace, &encoded)) {
    fprintf(stderr, "encode failed (unknown hash)\n");
    return 1;
  }

  rc = 1;
  if (!amduat_enc_pel_trace_dag_decode_v1(encoded, &decoded)) {
    fprintf(stderr, "decode failed (unknown hash)\n");
  } else {
    if (decoded.input_refs_len != 1) {
      fprintf(stderr, "decoded input count mismatch\n");
    } else if (decoded.input_refs[0].hash_id != 0x1234 ||
               decoded.input_refs[0].digest.len != sizeof(odd_digest) ||
               memcmp(decoded.input_refs[0].digest.data, odd_digest,
                      sizeof(odd_digest)) != 0) {
      fprintf(stderr, "decoded unknown hash mismatch\n");
    } else {
      rc = 0;
    }
    amduat_enc_pel_trace_dag_free(&decoded);
  }

  free((void *)encoded.data);
  return rc;
}
2025-12-20 13:54:18 +01:00
/*
 * Test entry point: runs every test case and exits with 1 if any of them
 * failed, 0 otherwise. Every case runs even after an earlier failure, so a
 * single invocation reports all broken cases.
 */
int main(void) {
  int failures = 0;

  if (test_trace_encoding() != 0) {
    failures++;
  }
  if (test_trace_unknown_hash_id() != 0) {
    failures++;
  }
  return failures == 0 ? 0 : 1;
}