/*
 * amduat/src/pel_stack/program_dag/program_dag.c
 *
 * Preparation, validation and execution of PEL DAG programs.
 */

#include "amduat/pel/program_dag.h"
#include "amduat/pel/opreg_kernel.h"
#include "amduat/pel/opreg_kernel_params.h"
#include "amduat/pel/program_dag_desc.h"
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
/* Scratch state filled in by amduat_program_prepare() for one program.
 * When non-NULL, each array has program->nodes_len entries; all three stay
 * NULL for a program with no nodes. Release with amduat_prepared_free(). */
typedef struct {
/* order[k] is the index into program->nodes of the k-th node to evaluate. */
size_t *order;
/* ops[i] is the kernel op descriptor looked up for program->nodes[i]. */
const amduat_pel_kernel_op_desc_t **ops;
/* params[i] receives the decoded parameters of program->nodes[i]. */
amduat_pel_kernel_params_t *params;
} amduat_pel_program_dag_prepared_t;
/* Puts `prepared` into the empty state (every array pointer NULL) without
 * freeing anything. A NULL `prepared` is ignored. */
static void amduat_prepared_reset(amduat_pel_program_dag_prepared_t *prepared) {
  if (prepared != NULL) {
    prepared->params = NULL;
    prepared->ops = NULL;
    prepared->order = NULL;
  }
}
/* Frees the three arrays held by `prepared`, then returns it to the empty
 * state so a second call is harmless. A NULL `prepared` is ignored. */
static void amduat_prepared_free(amduat_pel_program_dag_prepared_t *prepared) {
  if (prepared != NULL) {
    free(prepared->params);
    free(prepared->ops);
    free(prepared->order);
    amduat_prepared_reset(prepared);
  }
}
/* Fills `result` with an execution outcome for the program-DAG scheme.
 *
 * The version field is always 1, the scheme reference comes from
 * amduat_pel_program_dag_scheme_ref(), and no diagnostics are attached.
 * `status`, `kind` and `status_code` are stored as given. A NULL `result`
 * is ignored. */
static void amduat_set_result(amduat_pel_execution_result_value_t *result,
                              amduat_pel_execution_status_t status,
                              amduat_pel_execution_error_kind_t kind,
                              uint32_t status_code) {
  if (result != NULL) {
    /* Fixed header fields. */
    result->pel1_version = 1;
    result->scheme_ref = amduat_pel_program_dag_scheme_ref();

    /* Caller-supplied outcome. */
    result->status = status;
    result->summary.kind = kind;
    result->summary.status_code = status_code;

    /* This module never produces diagnostics. */
    result->diagnostics_len = 0;
    result->diagnostics = NULL;
  }
}
/* Returns true when `value` is a well-formed UTF-8 byte sequence.
 *
 * Rejected: stray continuation bytes, truncated sequences, overlong
 * encodings (lead 0xC0/0xC1, 0xE0 followed by < 0xA0, 0xF0 followed by
 * < 0x90), UTF-16 surrogates (0xED followed by >= 0xA0) and anything above
 * U+10FFFF (0xF4 followed by >= 0x90, or a lead byte above 0xF4).
 * An empty value is valid. */
static bool amduat_utf8_is_valid(amduat_octets_t value) {
  size_t pos = 0;

  while (pos < value.len) {
    const uint8_t lead = value.data[pos];
    /* Allowed range of the byte right after the lead; narrowed below for
     * the lead bytes that have overlong/surrogate/out-of-range forms. */
    uint8_t second_lo = 0x80u;
    uint8_t second_hi = 0xbfu;
    size_t width;
    size_t k;

    if (lead <= 0x7fu) {
      pos += 1;
      continue;
    }

    if (lead >= 0xc2u && lead <= 0xdfu) {
      width = 2;
    } else if (lead >= 0xe0u && lead <= 0xefu) {
      width = 3;
      if (lead == 0xe0u) {
        second_lo = 0xa0u;
      } else if (lead == 0xedu) {
        second_hi = 0x9fu;
      }
    } else if (lead >= 0xf0u && lead <= 0xf4u) {
      width = 4;
      if (lead == 0xf0u) {
        second_lo = 0x90u;
      } else if (lead == 0xf4u) {
        second_hi = 0x8fu;
      }
    } else {
      /* Continuation byte in lead position, 0xC0/0xC1, or above 0xF4. */
      return false;
    }

    /* The whole sequence must fit in the remaining bytes. */
    if (value.len - pos < width) {
      return false;
    }
    if (value.data[pos + 1] < second_lo || value.data[pos + 1] > second_hi) {
      return false;
    }
    for (k = 2; k < width; ++k) {
      if ((value.data[pos + k] & 0xc0u) != 0x80u) {
        return false;
      }
    }
    pos += width;
  }
  return true;
}
/* Linear search for the node whose id equals `node_id`.
 * Returns its position in program->nodes, or -1 when no node matches.
 * `program` must not be NULL. */
static int amduat_find_node_index(const amduat_pel_program_t *program,
                                  amduat_pel_node_id_t node_id) {
  size_t idx = 0;

  while (idx < program->nodes_len) {
    if (program->nodes[idx].id == node_id) {
      return (int)idx;
    }
    ++idx;
  }
  return -1;
}
/* Computes a deterministic topological order of the program's nodes.
 *
 * At every step the not-yet-placed node with no outstanding node
 * dependencies and the smallest id is chosen. When `out_order` is non-NULL,
 * out_order[k] receives the index (into program->nodes) of the k-th node.
 *
 * Returns false when node ids are duplicated, an input references an
 * unknown node, an input kind is neither NODE nor EXTERNAL, the graph has a
 * cycle, or an allocation fails. `*out_oom` (if non-NULL) is set to true
 * only in the allocation-failure case. An empty program succeeds
 * immediately. `program` must not be NULL. */
static bool amduat_build_node_order(const amduat_pel_program_t *program,
                                    size_t *out_order,
                                    bool *out_oom) {
  size_t count;
  size_t *pending = NULL; /* outstanding node-input edges per node */
  bool *emitted = NULL;   /* true once a node has been placed */
  bool ok = false;
  size_t a;
  size_t b;
  size_t slot;

  if (out_oom != NULL) {
    *out_oom = false;
  }
  count = program->nodes_len;
  if (count == 0) {
    return true;
  }

  pending = (size_t *)calloc(count, sizeof(*pending));
  emitted = (bool *)calloc(count, sizeof(*emitted));
  if (pending == NULL || emitted == NULL) {
    if (out_oom != NULL) {
      *out_oom = true;
    }
    goto done;
  }

  /* Node ids must be unique. */
  for (a = 0; a < count; ++a) {
    for (b = a + 1; b < count; ++b) {
      if (program->nodes[a].id == program->nodes[b].id) {
        goto done;
      }
    }
  }

  /* Count node-to-node edges per consumer and vet every input kind. */
  for (a = 0; a < count; ++a) {
    const amduat_pel_node_t *consumer = &program->nodes[a];
    for (b = 0; b < consumer->inputs_len; ++b) {
      const amduat_pel_dag_input_t *in = &consumer->inputs[b];
      if (in->kind != AMDUAT_PEL_DAG_INPUT_NODE) {
        if (in->kind != AMDUAT_PEL_DAG_INPUT_EXTERNAL) {
          goto done;
        }
        continue;
      }
      if (amduat_find_node_index(program, in->value.node.node_id) < 0) {
        goto done;
      }
      pending[a] += 1;
    }
  }

  for (slot = 0; slot < count; ++slot) {
    size_t pick = SIZE_MAX;
    amduat_pel_node_id_t pick_id = 0;

    /* Choose the ready node with the smallest id. */
    for (a = 0; a < count; ++a) {
      if (emitted[a] || pending[a] != 0) {
        continue;
      }
      if (pick == SIZE_MAX || program->nodes[a].id < pick_id) {
        pick = a;
        pick_id = program->nodes[a].id;
      }
    }
    if (pick == SIZE_MAX) {
      /* Nothing is ready although nodes remain: the graph has a cycle. */
      goto done;
    }
    if (out_order != NULL) {
      out_order[slot] = pick;
    }
    emitted[pick] = true;

    /* Release one pending edge for every input that consumes `pick`. */
    for (a = 0; a < count; ++a) {
      const amduat_pel_node_t *consumer;
      if (emitted[a]) {
        continue;
      }
      consumer = &program->nodes[a];
      for (b = 0; b < consumer->inputs_len; ++b) {
        const amduat_pel_dag_input_t *in = &consumer->inputs[b];
        if (in->kind != AMDUAT_PEL_DAG_INPUT_NODE ||
            in->value.node.node_id != pick_id) {
          continue;
        }
        if (pending[a] == 0) {
          goto done;
        }
        pending[a] -= 1;
      }
    }
  }
  ok = true;

done:
  free(emitted);
  free(pending);
  return ok;
}
/* Frees an array of `len` node results: the byte buffer of every output,
 * each node's outputs array, and finally the array itself.
 * A NULL `node_results` is ignored. */
static void amduat_node_results_free(
    amduat_pel_program_dag_node_result_t *node_results,
    size_t len) {
  size_t n;

  if (node_results == NULL) {
    return;
  }
  for (n = 0; n < len; ++n) {
    size_t k = 0;
    while (k < node_results[n].outputs_len) {
      free((void *)node_results[n].outputs[k].bytes.data);
      node_results[n].outputs[k].bytes.data = NULL;
      node_results[n].outputs[k].bytes.len = 0;
      ++k;
    }
    free(node_results[n].outputs);
    node_results[n].outputs = NULL;
    node_results[n].outputs_len = 0;
  }
  free(node_results);
}
/* Clears `trace` to "nothing recorded": no nodes, no order, and
 * any_node_executed false. Nothing is freed. A NULL `trace` is ignored. */
static void amduat_trace_reset(amduat_pel_program_dag_trace_t *trace) {
  if (trace != NULL) {
    trace->any_node_executed = false;
    trace->order = NULL;
    trace->nodes_len = 0;
    trace->nodes = NULL;
  }
}
/* Releases everything owned by `trace` (per-node results and the execution
 * order array) and leaves it in the reset state. A NULL `trace` is ignored;
 * the trace structure itself is not freed. */
void amduat_pel_program_dag_trace_free(amduat_pel_program_dag_trace_t *trace) {
  if (trace != NULL) {
    amduat_node_results_free(trace->nodes, trace->nodes_len);
    free(trace->order);
    amduat_trace_reset(trace);
  }
}
/* Deep-copies `src` into `out`.
 *
 * The payload is duplicated into a fresh malloc'd buffer (NULL for an empty
 * payload) and wrapped via amduat_octets(); the type-tag fields are copied
 * verbatim. The caller owns out->bytes.data and releases it with free().
 *
 * Returns false, leaving `out` untouched, when either argument is NULL or
 * the allocation fails. */
static bool amduat_copy_artifact(amduat_artifact_t *out,
                                 const amduat_artifact_t *src) {
  uint8_t *buffer = NULL;

  if (out == NULL || src == NULL) {
    return false;
  }
  if (src->bytes.len != 0) {
    buffer = (uint8_t *)malloc(src->bytes.len);
    if (buffer == NULL) {
      return false;
    }
    if (src->bytes.data != NULL) {
      memcpy(buffer, src->bytes.data, src->bytes.len);
    } else {
      /* A non-zero length with no data has nothing to copy from; zero-fill
       * so the copy never exposes uninitialised heap contents. */
      memset(buffer, 0, src->bytes.len);
    }
  }
  out->bytes = amduat_octets(buffer, src->bytes.len);
  out->has_type_tag = src->has_type_tag;
  out->type_tag = src->type_tag;
  return true;
}
/* Outcome of amduat_program_prepare(). */
typedef enum {
/* The program passed every check. */
AMDUAT_PEL_PROGRAM_PREP_OK = 0,
/* An argument was NULL or the program failed a structural check. */
AMDUAT_PEL_PROGRAM_PREP_INVALID = 1,
/* A memory allocation failed while preparing. */
AMDUAT_PEL_PROGRAM_PREP_OOM = 2
} amduat_pel_program_prepare_result_t;
/* Validates `program` and builds the per-node tables needed to execute it.
 *
 * Checks performed, in order:
 *   - array pointers are non-NULL wherever the matching length is non-zero;
 *   - a program without nodes has no roots (it is then OK and `prepared`
 *     stays empty);
 *   - every node's op name is valid UTF-8 and resolves through
 *     amduat_pel_kernel_op_lookup(), its input count lies within the op's
 *     [min_inputs, max_inputs], and its params decode;
 *   - the nodes admit a topological order (amduat_build_node_order());
 *   - every node input and every root refers to an existing node and to an
 *     output index below that node's op outputs_len.
 *
 * Returns:
 *   AMDUAT_PEL_PROGRAM_PREP_OK       `prepared` holds order/ops/params; the
 *                                    caller releases them with
 *                                    amduat_prepared_free().
 *   AMDUAT_PEL_PROGRAM_PREP_INVALID  a check failed, or an argument is NULL.
 *   AMDUAT_PEL_PROGRAM_PREP_OOM      an allocation failed.
 * On any non-OK result with a non-NULL `prepared`, `prepared` is left
 * empty. */
static amduat_pel_program_prepare_result_t amduat_program_prepare(
    const amduat_pel_program_t *program,
    amduat_pel_program_dag_prepared_t *prepared) {
  amduat_pel_program_prepare_result_t outcome =
      AMDUAT_PEL_PROGRAM_PREP_INVALID;
  bool oom = false;
  size_t i;

  if (program == NULL || prepared == NULL) {
    return AMDUAT_PEL_PROGRAM_PREP_INVALID;
  }
  amduat_prepared_reset(prepared);

  if (program->nodes_len > 0 && program->nodes == NULL) {
    return AMDUAT_PEL_PROGRAM_PREP_INVALID;
  }
  if (program->roots_len > 0 && program->roots == NULL) {
    return AMDUAT_PEL_PROGRAM_PREP_INVALID;
  }
  if (program->nodes_len == 0) {
    /* Roots would have nothing to point at. */
    return program->roots_len != 0 ? AMDUAT_PEL_PROGRAM_PREP_INVALID
                                   : AMDUAT_PEL_PROGRAM_PREP_OK;
  }

  /* calloc for all three tables so the count * size product is
   * overflow-checked by the allocator in every case. */
  prepared->order = (size_t *)calloc(program->nodes_len,
                                     sizeof(*prepared->order));
  prepared->ops = (const amduat_pel_kernel_op_desc_t **)calloc(
      program->nodes_len, sizeof(*prepared->ops));
  prepared->params = (amduat_pel_kernel_params_t *)calloc(
      program->nodes_len, sizeof(*prepared->params));
  if (prepared->order == NULL || prepared->ops == NULL ||
      prepared->params == NULL) {
    outcome = AMDUAT_PEL_PROGRAM_PREP_OOM;
    goto fail;
  }

  /* Per-node checks: resolve the op and decode its parameters. */
  for (i = 0; i < program->nodes_len; ++i) {
    const amduat_pel_node_t *node = &program->nodes[i];
    const amduat_pel_kernel_op_desc_t *desc;

    if (node->op.name.len > 0 && node->op.name.data == NULL) {
      goto fail;
    }
    if (!amduat_utf8_is_valid(node->op.name)) {
      goto fail;
    }
    if (node->inputs_len > 0 && node->inputs == NULL) {
      goto fail;
    }
    if (node->params.len > 0 && node->params.data == NULL) {
      goto fail;
    }
    desc = amduat_pel_kernel_op_lookup(node->op.name, node->op.version);
    if (desc == NULL) {
      goto fail;
    }
    if (node->inputs_len < desc->min_inputs ||
        node->inputs_len > desc->max_inputs) {
      goto fail;
    }
    if (!amduat_pel_kernel_params_decode(desc, node->params,
                                         &prepared->params[i])) {
      goto fail;
    }
    prepared->ops[i] = desc;
  }

  if (!amduat_build_node_order(program, prepared->order, &oom)) {
    if (oom) {
      outcome = AMDUAT_PEL_PROGRAM_PREP_OOM;
    }
    goto fail;
  }

  /* Node inputs must select an output the producing op actually has. This
   * needs ops[] fully populated, hence the separate pass. */
  for (i = 0; i < program->nodes_len; ++i) {
    const amduat_pel_node_t *node = &program->nodes[i];
    size_t j;

    for (j = 0; j < node->inputs_len; ++j) {
      const amduat_pel_dag_input_t *input = &node->inputs[j];

      if (input->kind == AMDUAT_PEL_DAG_INPUT_NODE) {
        int dep_index = amduat_find_node_index(program,
                                               input->value.node.node_id);
        if (dep_index < 0) {
          goto fail;
        }
        if (input->value.node.output_index >=
            prepared->ops[dep_index]->outputs_len) {
          goto fail;
        }
      } else if (input->kind != AMDUAT_PEL_DAG_INPUT_EXTERNAL) {
        goto fail;
      }
    }
  }

  /* Roots follow the same rule as node inputs. */
  for (i = 0; i < program->roots_len; ++i) {
    int root_index = amduat_find_node_index(program,
                                            program->roots[i].node_id);
    if (root_index < 0) {
      goto fail;
    }
    if (program->roots[i].output_index >=
        prepared->ops[root_index]->outputs_len) {
      goto fail;
    }
  }
  return AMDUAT_PEL_PROGRAM_PREP_OK;

fail:
  amduat_prepared_free(prepared);
  return outcome;
}
/* Returns true when `program` passes every check in
 * amduat_program_prepare(). NULL, malformed programs and allocation
 * failure all yield false. No state is retained after the call. */
bool amduat_pel_program_dag_validate(const amduat_pel_program_t *program) {
  amduat_pel_program_dag_prepared_t scratch;
  bool valid;

  amduat_prepared_reset(&scratch);
  valid = amduat_program_prepare(program, &scratch) ==
          AMDUAT_PEL_PROGRAM_PREP_OK;
  amduat_prepared_free(&scratch);
  return valid;
}
static bool amduat_pel_program_dag_exec_internal(
const amduat_pel_program_t *program,
const amduat_artifact_t *inputs,
size_t inputs_len,
amduat_artifact_t **out_outputs,
size_t *out_outputs_len,
amduat_pel_execution_result_value_t *out_result,
amduat_pel_program_dag_trace_t *out_trace) {
amduat_pel_program_dag_prepared_t prepared;
amduat_pel_program_prepare_result_t prep_result;
amduat_pel_program_dag_node_result_t *node_results;
amduat_artifact_t *resolved_inputs;
size_t max_inputs;
size_t i;
bool any_node_executed = false;
bool wants_trace = out_trace != NULL;
if (out_outputs == NULL || out_outputs_len == NULL || out_result == NULL) {
return false;
}
*out_outputs = NULL;
*out_outputs_len = 0;
if (wants_trace) {
amduat_trace_reset(out_trace);
}
amduat_prepared_reset(&prepared);
prep_result = amduat_program_prepare(program, &prepared);
if (prep_result == AMDUAT_PEL_PROGRAM_PREP_OOM) {
return false;
}
if (prep_result != AMDUAT_PEL_PROGRAM_PREP_OK) {
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_INVALID_PROGRAM,
AMDUAT_PEL_EXEC_ERROR_PROGRAM, 2);
amduat_prepared_free(&prepared);
return true;
}
if (inputs_len > 0 && inputs == NULL) {
bool needs_inputs = false;
for (i = 0; i < program->nodes_len; ++i) {
size_t j;
const amduat_pel_node_t *node = &program->nodes[i];
for (j = 0; j < node->inputs_len; ++j) {
if (node->inputs[j].kind == AMDUAT_PEL_DAG_INPUT_EXTERNAL) {
needs_inputs = true;
break;
}
}
if (needs_inputs) {
break;
}
}
if (needs_inputs) {
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_INVALID_INPUTS,
AMDUAT_PEL_EXEC_ERROR_INPUTS, 3);
amduat_prepared_free(&prepared);
return true;
}
}
if (program->nodes_len == 0) {
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_OK,
AMDUAT_PEL_EXEC_ERROR_NONE, 0);
amduat_prepared_free(&prepared);
return true;
}
node_results = (amduat_pel_program_dag_node_result_t *)calloc(
program->nodes_len, sizeof(*node_results));
if (node_results == NULL) {
amduat_prepared_free(&prepared);
return false;
}
for (i = 0; i < program->nodes_len; ++i) {
node_results[i].status = AMDUAT_PEL_NODE_TRACE_SKIPPED;
node_results[i].status_code = 0;
}
max_inputs = 0;
for (i = 0; i < program->nodes_len; ++i) {
if (program->nodes[i].inputs_len > max_inputs) {
max_inputs = program->nodes[i].inputs_len;
}
}
resolved_inputs = NULL;
if (max_inputs != 0) {
resolved_inputs = (amduat_artifact_t *)malloc(
max_inputs * sizeof(*resolved_inputs));
if (resolved_inputs == NULL) {
amduat_node_results_free(node_results, program->nodes_len);
amduat_prepared_free(&prepared);
return false;
}
}
for (i = 0; i < program->nodes_len; ++i) {
size_t node_index = prepared.order[i];
const amduat_pel_node_t *node = &program->nodes[node_index];
size_t j;
uint32_t status_code = 0;
for (j = 0; j < node->inputs_len; ++j) {
const amduat_pel_dag_input_t *input = &node->inputs[j];
if (input->kind == AMDUAT_PEL_DAG_INPUT_EXTERNAL) {
if (input->value.external.input_index >= inputs_len) {
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_INVALID_INPUTS,
AMDUAT_PEL_EXEC_ERROR_INPUTS, 3);
free(resolved_inputs);
goto finish;
}
resolved_inputs[j] = inputs[input->value.external.input_index];
} else if (input->kind == AMDUAT_PEL_DAG_INPUT_NODE) {
int dep_index = amduat_find_node_index(program,
input->value.node.node_id);
if (dep_index < 0 ||
input->value.node.output_index >=
node_results[dep_index].outputs_len) {
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_INVALID_PROGRAM,
AMDUAT_PEL_EXEC_ERROR_PROGRAM, 2);
free(resolved_inputs);
goto finish;
}
resolved_inputs[j] =
node_results[dep_index].outputs[input->value.node.output_index];
} else {
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_INVALID_PROGRAM,
AMDUAT_PEL_EXEC_ERROR_PROGRAM, 2);
free(resolved_inputs);
goto finish;
}
}
if (!amduat_pel_kernel_op_eval(
prepared.ops[node_index], resolved_inputs, node->inputs_len,
&prepared.params[node_index], &node_results[node_index].outputs,
&node_results[node_index].outputs_len, &status_code)) {
if (status_code == 2 || status_code == 3 || status_code == 0) {
status_code = 1;
}
node_results[node_index].status = AMDUAT_PEL_NODE_TRACE_FAILED;
node_results[node_index].status_code = status_code;
any_node_executed = true;
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_RUNTIME_FAILED,
AMDUAT_PEL_EXEC_ERROR_RUNTIME, status_code);
free(resolved_inputs);
goto finish;
}
node_results[node_index].status = AMDUAT_PEL_NODE_TRACE_OK;
node_results[node_index].status_code = 0;
any_node_executed = true;
}
if (program->roots_len != 0) {
*out_outputs = (amduat_artifact_t *)calloc(
program->roots_len, sizeof(**out_outputs));
if (*out_outputs == NULL) {
free(resolved_inputs);
goto oom_finish;
}
}
for (i = 0; i < program->roots_len; ++i) {
int root_index = amduat_find_node_index(program, program->roots[i].node_id);
if (root_index < 0 ||
program->roots[i].output_index >=
node_results[root_index].outputs_len) {
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_INVALID_PROGRAM,
AMDUAT_PEL_EXEC_ERROR_PROGRAM, 2);
amduat_pel_program_dag_free_outputs(*out_outputs,
program->roots_len);
*out_outputs = NULL;
*out_outputs_len = 0;
free(resolved_inputs);
goto finish;
}
if (!amduat_copy_artifact(
&(*out_outputs)[i],
&node_results[root_index]
.outputs[program->roots[i].output_index])) {
amduat_pel_program_dag_free_outputs(*out_outputs,
program->roots_len);
*out_outputs = NULL;
*out_outputs_len = 0;
free(resolved_inputs);
goto oom_finish;
}
}
*out_outputs_len = program->roots_len;
amduat_set_result(out_result, AMDUAT_PEL_EXEC_STATUS_OK,
AMDUAT_PEL_EXEC_ERROR_NONE, 0);
free(resolved_inputs);
finish:
if (wants_trace) {
out_trace->any_node_executed = any_node_executed;
if (any_node_executed) {
out_trace->nodes = node_results;
out_trace->nodes_len = program->nodes_len;
out_trace->order = prepared.order;
prepared.order = NULL;
} else {
amduat_node_results_free(node_results, program->nodes_len);
node_results = NULL;
}
} else {
amduat_node_results_free(node_results, program->nodes_len);
node_results = NULL;
}
amduat_prepared_free(&prepared);
return true;
oom_finish:
if (wants_trace) {
out_trace->any_node_executed = any_node_executed;
if (any_node_executed) {
amduat_node_results_free(node_results, program->nodes_len);
node_results = NULL;
} else {
amduat_node_results_free(node_results, program->nodes_len);
node_results = NULL;
}
} else {
amduat_node_results_free(node_results, program->nodes_len);
node_results = NULL;
}
amduat_prepared_free(&prepared);
return false;
}
/* Executes `program` without recording a trace.
 * Forwards every argument to amduat_pel_program_dag_exec_internal() with a
 * NULL trace and returns its result unchanged. */
bool amduat_pel_program_dag_exec(
    const amduat_pel_program_t *program,
    const amduat_artifact_t *inputs,
    size_t inputs_len,
    amduat_artifact_t **out_outputs,
    size_t *out_outputs_len,
    amduat_pel_execution_result_value_t *out_result) {
  amduat_pel_program_dag_trace_t *const no_trace = NULL;

  return amduat_pel_program_dag_exec_internal(program, inputs, inputs_len,
                                              out_outputs, out_outputs_len,
                                              out_result, no_trace);
}
/* Executes `program` and records a per-node trace in `out_trace`.
 * Returns false immediately, touching nothing, when `out_trace` is NULL;
 * otherwise returns whatever amduat_pel_program_dag_exec_internal()
 * returns for the same arguments. */
bool amduat_pel_program_dag_exec_trace(
    const amduat_pel_program_t *program,
    const amduat_artifact_t *inputs,
    size_t inputs_len,
    amduat_artifact_t **out_outputs,
    size_t *out_outputs_len,
    amduat_pel_execution_result_value_t *out_result,
    amduat_pel_program_dag_trace_t *out_trace) {
  /* Short-circuit keeps the internal call from running without a trace. */
  return out_trace != NULL &&
         amduat_pel_program_dag_exec_internal(program, inputs, inputs_len,
                                              out_outputs, out_outputs_len,
                                              out_result, out_trace);
}
/* Frees an outputs array: the byte buffer of each of the `outputs_len`
 * artifacts, then the array itself. A NULL `outputs` is ignored. */
void amduat_pel_program_dag_free_outputs(amduat_artifact_t *outputs,
                                         size_t outputs_len) {
  amduat_artifact_t *item;
  size_t remaining;

  if (outputs == NULL) {
    return;
  }
  item = outputs;
  for (remaining = outputs_len; remaining > 0; --remaining, ++item) {
    free((void *)item->bytes.data);
    item->bytes.data = NULL;
    item->bytes.len = 0;
  }
  free(outputs);
}