#include "ccv_nnc.h"
#include "ccv_nnc_easy.h"
#include "ccv_nnc_internal.h"
#include "ccv_internal.h"
#ifdef HAVE_CUDA
#include "gpu/ccv_nnc_compat.h"
#endif
#include "_ccv_nnc_symbolic_graph.h"

/**
 * Level-4 API
 */

ccv_nnc_graph_exec_symbol_t ccv_nnc_symbolic_graph_while(ccv_nnc_symbolic_graph_t* const graph, ccv_nnc_symbolic_graph_t* const while_graph, const char* const name)
{
	assert(while_graph->p == 0);
	assert(while_graph->p_idx == 0);
	ccv_nnc_cmd_t cmd = ccv_nnc_cmd(CCV_NNC_GRAPH_FORWARD, 0, CMD_GENERIC(), 0);
	// Added one more symbol.
	ccv_nnc_graph_exec_symbol_t symbol = ccv_nnc_graph_exec_symbol_new(graph, cmd, 0, 0, 0, 0, name);
	// Assigning graph_ref to it.
	if (!graph->sub_graphs)
		graph->sub_graphs = ccv_array_new(sizeof(ccv_nnc_symbolic_graph_t*), 1, 0);
	ccv_array_push(graph->sub_graphs, &while_graph);
	ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, symbol.d);
	// Note the extra allocation (the ccv_array_t only holds ccv_nnc_symbolic_graph_t* pointers).
	// This way we can keep a reference to the while graph without worrying that it becomes invalid
	// once the array expands (when another while graph is allocated).
	symbol_info->graph_ref = graph->sub_graphs->rnum;
	while_graph->p_idx = graph->sub_graphs->rnum;
	while_graph->exec_idx = symbol.d + 1;
	while_graph->p = graph;
	return symbol;
}
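
/* A minimal usage sketch (illustrative only; graph, while_graph and the "while" name are
 * assumptions, not part of this file): attach a freshly created sub-graph to a parent graph
 * as the body of a while loop, and keep the returned symbol to refer to the loop.
 *
 *   ccv_nnc_symbolic_graph_t* const graph = ccv_nnc_symbolic_graph_new();
 *   ccv_nnc_symbolic_graph_t* const while_graph = ccv_nnc_symbolic_graph_new();
 *   // The returned exec symbol lives in the parent graph and represents the whole loop.
 *   const ccv_nnc_graph_exec_symbol_t loop = ccv_nnc_symbolic_graph_while(graph, while_graph, "while");
 */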

void ccv_nnc_symbolic_graph_set_while_expr(ccv_nnc_symbolic_graph_t* const while_graph, const ccv_nnc_graph_while_f while_expr, const void* const while_data, const ccv_nnc_graph_exec_symbol_t* const breakpoints, const int breakpoint_size)
{
	while_graph->while_expr = while_expr;
	while_graph->while_data = while_data;
	if (breakpoint_size > 0)
	{
		assert(breakpoints);
		while_graph->breakpoint_size = breakpoint_size;
		while_graph->breakpoints = (ccv_nnc_graph_exec_symbol_t*)ccmalloc(sizeof(ccv_nnc_graph_exec_symbol_t) * breakpoint_size);
		memcpy(while_graph->breakpoints, breakpoints, sizeof(ccv_nnc_graph_exec_symbol_t) * breakpoint_size);
	}
}
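
/* A minimal sketch of a while predicate (hypothetical names; the exact ccv_nnc_graph_while_f
 * signature and how the breakpoint inputs are delivered should be checked against ccv_nnc.h):
 *
 *   static int _run_5_times(ccv_nnc_tensor_t* const* const inputs, const int input_size, const void* const data)
 *   {
 *       // Keep looping while the counter tensor is still below 5.
 *       return inputs[0]->data.i64[0] < 5;
 *   }
 *
 *   // noop is a breakpoint exec symbol assumed to have been created on while_graph elsewhere.
 *   ccv_nnc_symbolic_graph_set_while_expr(while_graph, _run_5_times, 0, &noop, 1);
 */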

ccv_nnc_tensor_symbol_t ccv_nnc_find_tensor_symbol_from_graph(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_tensor_symbol_t symbol)
{
	if (symbol.graph == graph)
		return symbol;
	const ccv_nnc_symbolic_graph_t* curr_graph = symbol.graph;
	assert(symbol.d >= 0 && symbol.d < curr_graph->tensor_symbol_info->rnum);
	ccv_nnc_tensor_symbol_info_t* const symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(symbol.graph->tensor_symbol_info, symbol.d);
	while (curr_graph && curr_graph != graph)
		curr_graph = curr_graph->p;
	if (curr_graph)
	{
		curr_graph = symbol.graph;
		ccv_nnc_tensor_symbol_info_t* curr_symbol_info = symbol_info;
		ccv_nnc_tensor_symbol_t curr_symbol = symbol;
		while (curr_graph != graph)
		{
			ccv_nnc_symbolic_graph_t* const p = curr_graph->p;
			// I need to find the symbol; it must exist.
			assert(curr_symbol_info->p_ref);
			// Move on.
			curr_symbol.d = curr_symbol_info->p_ref - 1;
			curr_symbol.graph = p;
			assert(curr_symbol.d >= 0 && curr_symbol.d < p->tensor_symbol_info->rnum);
			curr_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(p->tensor_symbol_info, curr_symbol.d);
			curr_graph = p;
		}
		return curr_symbol;
	}
	// Otherwise, if the symbol is in the parent graph, this is a bit more expensive because I need to keep a trace stack.
	curr_graph = graph;
	ccv_array_t* trace = ccv_array_new(sizeof(int), 0, 0);
	while (curr_graph && curr_graph != symbol.graph)
	{
		const int p_idx = curr_graph->p_idx - 1;
		ccv_array_push(trace, &p_idx);
		curr_graph = curr_graph->p;
	}
	// If the graph is neither an ancestor nor a descendant of the symbol's graph, the input is invalid.
	assert(curr_graph);
	curr_graph = symbol.graph;
	ccv_nnc_tensor_symbol_info_t* curr_symbol_info = symbol_info;
	ccv_nnc_tensor_symbol_t curr_symbol = symbol;
	// The graph is a sub-graph of the graph the symbol belongs to.
	int i;
	for (i = trace->rnum - 1; i >= 0; i--)
	{
		const int p_idx = *(int*)ccv_array_get(trace, i);
		assert(p_idx >= 0);
		assert(curr_graph->sub_graphs);
		assert(curr_symbol_info->s_ref);
		assert(p_idx >= 0 && p_idx < curr_symbol_info->s_ref->rnum);
		const int s_idx = *(int*)ccv_array_get(curr_symbol_info->s_ref, p_idx);
		ccv_nnc_symbolic_graph_t* const s = *(ccv_nnc_symbolic_graph_t**)ccv_array_get(curr_graph->sub_graphs, p_idx);
		// I need to find the symbol; it must exist.
		assert(s_idx);
		curr_symbol.d = s_idx - 1;
		curr_symbol.graph = s;
		assert(curr_symbol.d >= 0 && curr_symbol.d < s->tensor_symbol_info->rnum);
		curr_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(s->tensor_symbol_info, curr_symbol.d);
		// Move on.
		curr_graph = s;
	}
	ccv_array_free(trace);
	return curr_symbol;
}
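
/* A small sketch of what this resolution enables (hypothetical symbols; it assumes x has
 * already been wired into while_graph, i.e. the p_ref / s_ref links the asserts above rely
 * on are in place):
 *
 *   // x was created on the parent graph; get the matching symbol inside while_graph.
 *   const ccv_nnc_tensor_symbol_t x_in_while = ccv_nnc_find_tensor_symbol_from_graph(while_graph, x);
 *   assert(x_in_while.graph == while_graph);
 */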

void ccv_nnc_symbolic_graph_set_while_params(ccv_nnc_symbolic_graph_t* const while_graph, const ccv_nnc_tensor_symbol_map_t* const symbol_map, const int symbol_map_size)
{
	int i;
	for (i = 0; i < symbol_map_size; i++)
	{
		const ccv_nnc_tensor_symbol_t source = ccv_nnc_find_tensor_symbol_from_graph(while_graph, symbol_map[i].source);
		const ccv_nnc_tensor_symbol_t destination = ccv_nnc_find_tensor_symbol_from_graph(while_graph, symbol_map[i].destination);
		assert(source.graph == while_graph);
		assert(destination.graph == while_graph);
		assert(source.d < while_graph->tensor_symbol_info->rnum);
		assert(destination.d < while_graph->tensor_symbol_info->rnum);
		ccv_nnc_tensor_symbol_info_t* destination_tensor_symbol_info = (ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, destination.d);
		// Parameterizing with an alias is not supported. The reason is that to support a parameterized loop (for SSA), I choose
		// to simply reuse the piece of memory (allocating the same memory region to both, thereby enabling parameter
		// passing). For an alias this is not possible, because an alias can point to tensors of different sizes, and
		// such tensors cannot share the same memory region. The best way to parameterize an alias is to
		// create a new tensor of the same size, transfer the value over, and parameterize on that tensor instead.
		assert(!destination_tensor_symbol_info->alias_ref);
		assert(!((ccv_nnc_tensor_symbol_info_t*)ccv_array_get(while_graph->tensor_symbol_info, source.d))->alias_ref);
		destination_tensor_symbol_info->assign_ref = source.d + 1;
	}
}
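
/* A minimal sketch of setting up a carry-over (hypothetical symbols z and x, both assumed to
 * resolve into while_graph and to be plain, non-alias tensors as required above): after each
 * iteration, z's value is assigned back to x.
 *
 *   const ccv_nnc_tensor_symbol_map_t map = {
 *       .source = z,       // produced inside the loop body
 *       .destination = x,  // consumed at the top of the next iteration
 *   };
 *   ccv_nnc_symbolic_graph_set_while_params(while_graph, &map, 1);
 */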

ccv_nnc_symbolic_graph_t* ccv_nnc_symbolic_graph_from_while_symbol(const ccv_nnc_symbolic_graph_t* const graph, const ccv_nnc_graph_exec_symbol_t while_symbol)
{
	assert(graph->sub_graphs);
	assert(while_symbol.graph == graph);
	assert(while_symbol.d < graph->exec_symbol_info->rnum);
	ccv_nnc_graph_exec_symbol_info_t* symbol_info = (ccv_nnc_graph_exec_symbol_info_t*)ccv_array_get(graph->exec_symbol_info, while_symbol.d);
	assert(symbol_info->graph_ref > 0 && symbol_info->graph_ref <= graph->sub_graphs->rnum);
	return *(ccv_nnc_symbolic_graph_t**)ccv_array_get(graph->sub_graphs, symbol_info->graph_ref - 1);
}
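
/* Sketch of the round trip (assuming graph, while_graph and loop from the earlier sketches):
 * the symbol returned by ccv_nnc_symbolic_graph_while can be mapped back to its sub-graph.
 *
 *   ccv_nnc_symbolic_graph_t* const sub = ccv_nnc_symbolic_graph_from_while_symbol(graph, loop);
 *   assert(sub == while_graph);
 */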