The Perl Toolchain Summit needs more sponsors. If your company depends on Perl, please support this very important event.
package org.maltparser.core.syntaxgraph.writer;

import java.io.BufferedWriter;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.UnsupportedEncodingException;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.SortedMap;
import java.util.TreeMap;
import java.util.regex.PatternSyntaxException;

import org.maltparser.core.exception.MaltChainedException;
import org.maltparser.core.io.dataformat.ColumnDescription;
import org.maltparser.core.io.dataformat.DataFormatException;
import org.maltparser.core.io.dataformat.DataFormatInstance;
import org.maltparser.core.syntaxgraph.PhraseStructure;
import org.maltparser.core.syntaxgraph.TokenStructure;
import org.maltparser.core.syntaxgraph.edge.Edge;
import org.maltparser.core.syntaxgraph.node.NonTerminalNode;
import org.maltparser.core.syntaxgraph.node.PhraseStructureNode;
import org.maltparser.ml.libsvm.LibsvmException;
/**
*
*
* @author Johan Hall
*/
public class NegraWriter implements SyntaxGraphWriter {
	private BufferedWriter writer; 
	private DataFormatInstance dataFormatInstance;
	private String optionString;
	private int sentenceCount;
	private LinkedHashMap<Integer, Integer> nonTerminalIndexMap;
	private int START_ID_OF_NONTERMINALS = 500;
	private boolean closeStream = true;
	
	public NegraWriter() { 
		nonTerminalIndexMap = new LinkedHashMap<Integer, Integer>();
	}
	
	public void open(String fileName, String charsetName) throws MaltChainedException {
		try {
			open(new OutputStreamWriter(new FileOutputStream(fileName),charsetName));
		} catch (FileNotFoundException e) {
			throw new DataFormatException("The output file '"+fileName+"' cannot be found.", e);
		} catch (UnsupportedEncodingException e) {
			throw new DataFormatException("The character encoding set '"+charsetName+"' isn't supported.", e);
		}	
	}
	
	public void open(OutputStream os, String charsetName) throws MaltChainedException {
		try {
			if (os == System.out || os == System.err) {
				closeStream = false;
			}
			open(new OutputStreamWriter(os, charsetName));
		} catch (UnsupportedEncodingException e) {
			throw new DataFormatException("The character encoding set '"+charsetName+"' isn't supported.", e);
		}
	}
	
	private void open(OutputStreamWriter osw) throws MaltChainedException {
		setWriter(new BufferedWriter(osw));
		setSentenceCount(0);
	}
	
	public void writeProlog() throws MaltChainedException { }
	
	public void writeSentence(TokenStructure syntaxGraph) throws MaltChainedException {
		if (syntaxGraph == null || dataFormatInstance == null || !(syntaxGraph instanceof PhraseStructure) || !syntaxGraph.hasTokens()) {
			return;
		}
		PhraseStructure phraseStructure = (PhraseStructure)syntaxGraph;
		sentenceCount++;
		try {
			writer.write("#BOS ");
			if (phraseStructure.getSentenceID() != 0) {
				writer.write(Integer.toString(phraseStructure.getSentenceID()));
			} else {
				writer.write(Integer.toString(sentenceCount));
			}
			writer.write('\n');

			if (phraseStructure.hasNonTerminals()) {
				calculateIndices(phraseStructure);
				writeTerminals(phraseStructure);
				writeNonTerminals(phraseStructure);
			} else {
				writeTerminals(phraseStructure);
			}
			writer.write("#EOS ");
			if (phraseStructure.getSentenceID() != 0) {
				writer.write(Integer.toString(phraseStructure.getSentenceID()));
			} else {
				writer.write(Integer.toString(sentenceCount));
			}
			writer.write('\n');
		} catch (IOException e) {
			throw new DataFormatException("Could not write to the output file. ", e);
		}
	}
	public void writeEpilog() throws MaltChainedException { }
	

	private void calculateIndices(PhraseStructure phraseStructure) throws MaltChainedException {
		final SortedMap<Integer,Integer> heights = new TreeMap<Integer,Integer>();
		for (int index : phraseStructure.getNonTerminalIndices()) {
			heights.put(index, ((NonTerminalNode)phraseStructure.getNonTerminalNode(index)).getHeight());
		}
		
		boolean done = false;
		int h = 1;
		int ntid = START_ID_OF_NONTERMINALS;
		nonTerminalIndexMap.clear();
		while (!done) {
			done = true;
			for (int index : phraseStructure.getNonTerminalIndices()) {
				if (heights.get(index) == h) {
					NonTerminalNode nt = (NonTerminalNode)phraseStructure.getNonTerminalNode(index);
					nonTerminalIndexMap.put(nt.getIndex(), ntid++);
//					nonTerminalIndexMap.put(nt.getIndex(), nt.getIndex()+START_ID_OF_NONTERMINALS-1);
					done = false;
				}
			}
			h++;
		}
		
//		boolean done = false;
//		int h = 1;
////		int ntid = START_ID_OF_NONTERMINALS;
////		nonTerminalIndexMap.clear();
//		while (!done) {
//			done = true;
//			for (int index : phraseStructure.getNonTerminalIndices()) {
//				if (heights.get(index) == h) {
//					NonTerminalNode nt = (NonTerminalNode)phraseStructure.getNonTerminalNode(index);
////					nonTerminalIndexMap.put(nt.getIndex(), ntid++);
//					nonTerminalIndexMap.put(nt.getIndex(), nt.getIndex()+START_ID_OF_NONTERMINALS-1);
//					done = false;
//				}
//			}
//			h++;
//		}
	}
	
	private void writeTerminals(PhraseStructure phraseStructure) throws MaltChainedException {
		try {
			for (int index : phraseStructure.getTokenIndices()) {
				final PhraseStructureNode terminal = phraseStructure.getTokenNode(index);
				final Iterator<ColumnDescription> columns = dataFormatInstance.iterator();
				ColumnDescription column = null;
				int ti = 1;
				while (columns.hasNext()) {
					column = columns.next();
					if (column.getCategory() == ColumnDescription.INPUT) {
						writer.write(terminal.getLabelSymbol(column.getSymbolTable()));
						int nTabs = 1;
						if (ti == 1 || ti == 2) {
							nTabs = 3 - (terminal.getLabelSymbol(column.getSymbolTable()).length() / 8);
						} else if (ti == 3) {
							nTabs = 1;
						} else if (ti == 4) {
							nTabs = 2 - (terminal.getLabelSymbol(column.getSymbolTable()).length() / 8);
						}
						if (nTabs < 1) {
							nTabs = 1;
						}
						for (int j = 0; j < nTabs; j++) {
							writer.write('\t');
						}
						ti++;
					} else if (column.getCategory() == ColumnDescription.PHRASE_STRUCTURE_EDGE_LABEL) {
						if (terminal.getParent() != null && terminal.hasParentEdgeLabel(column.getSymbolTable())) {
							writer.write(terminal.getParentEdgeLabelSymbol(column.getSymbolTable()));
							writer.write('\t');
						} else {
							writer.write("--\t");
						}
					} else if (column.getCategory() == ColumnDescription.PHRASE_STRUCTURE_NODE_LABEL) { 
						if (terminal.getParent() == null || terminal.getParent() == phraseStructure.getPhraseStructureRoot()) {
							writer.write('0');
						} else {
							writer.write(Integer.toString(nonTerminalIndexMap.get(terminal.getParent().getIndex())));
//							writer.write(Integer.toString(terminal.getParent().getIndex()+START_ID_OF_NONTERMINALS-1));
						}
					}
				}
				for (Edge e : terminal.getIncomingSecondaryEdges()) {
					if (e.hasLabel(column.getSymbolTable())) {
						writer.write('\t');
						writer.write(e.getLabelSymbol(column.getSymbolTable()));
						writer.write('\t');
						if (e.getSource() instanceof NonTerminalNode) {
							writer.write(Integer.toString(nonTerminalIndexMap.get(e.getSource().getIndex())));
//							writer.write(Integer.toString(e.getSource().getIndex()+START_ID_OF_NONTERMINALS-1));
						} else {
							writer.write(Integer.toString(e.getSource().getIndex()));
						}
					}
				}
				writer.write("\n");
			}

		} catch (IOException e) {
			throw new DataFormatException("The Negra writer is not able to write. ", e);
		}
	}
	
	private void writeNonTerminals(PhraseStructure phraseStructure) throws MaltChainedException {
		for (int index : nonTerminalIndexMap.keySet()) {
//		for (int index : phraseStructure.getNonTerminalIndices()) {
			NonTerminalNode nonTerminal = (NonTerminalNode)phraseStructure.getNonTerminalNode(index);
	
			if (nonTerminal == null || nonTerminal.isRoot()) {
				return;
			}
			try {
				writer.write('#');
//				writer.write(Integer.toString(index+START_ID_OF_NONTERMINALS-1));
				writer.write(Integer.toString(nonTerminalIndexMap.get(index)));
				writer.write("\t\t\t--\t\t\t");
				if (nonTerminal.hasLabel(dataFormatInstance.getColumnDescriptionByName("CAT").getSymbolTable())) {
					writer.write(nonTerminal.getLabelSymbol(dataFormatInstance.getColumnDescriptionByName("CAT").getSymbolTable()));
				} else {
					writer.write("--");
				}
				writer.write("\t--\t\t");
				if (nonTerminal.hasParentEdgeLabel(dataFormatInstance.getColumnDescriptionByName("LABEL").getSymbolTable())) {
					writer.write(nonTerminal.getParentEdgeLabelSymbol(dataFormatInstance.getColumnDescriptionByName("LABEL").getSymbolTable()));
				} else {
					writer.write("--");
				}
				writer.write('\t');
				if (nonTerminal.getParent() == null || nonTerminal.getParent().isRoot()) {
					writer.write('0');
				} else {
//					writer.write(Integer.toString(nonTerminal.getParent().getIndex()+START_ID_OF_NONTERMINALS-1));
					writer.write(Integer.toString(nonTerminalIndexMap.get(nonTerminal.getParent().getIndex())));
				}
				for (Edge e : nonTerminal.getIncomingSecondaryEdges()) {
					if (e.hasLabel(dataFormatInstance.getColumnDescriptionByName("SECEDGELABEL").getSymbolTable())) {
						writer.write('\t');
						writer.write(e.getLabelSymbol(dataFormatInstance.getColumnDescriptionByName("SECEDGELABEL").getSymbolTable()));
						writer.write('\t');
						if (e.getSource() instanceof NonTerminalNode) {
//							writer.write(Integer.toString(e.getSource().getIndex()+START_ID_OF_NONTERMINALS-1));
							writer.write(Integer.toString(nonTerminalIndexMap.get(e.getSource().getIndex())));
						} else {
							writer.write(Integer.toString(e.getSource().getIndex()));
						}
					}
				}
				writer.write("\n");
			} catch (IOException e) {
				throw new DataFormatException("The Negra writer is not able to write the non-terminals. ", e);
			}
		}
	}
	
	public BufferedWriter getWriter() {
		return writer;
	}

	public void setWriter(BufferedWriter writer) {
		this.writer = writer;
	}
	
	public int getSentenceCount() {
		return sentenceCount;
	}

	public void setSentenceCount(int sentenceCount) {
		this.sentenceCount = sentenceCount;
	}
	
	public DataFormatInstance getDataFormatInstance() {
		return dataFormatInstance;
	}

	public void setDataFormatInstance(DataFormatInstance dataFormatInstance) {
		this.dataFormatInstance = dataFormatInstance;
	}

	public String getOptions() {
		return optionString;
	}
	
	public void setOptions(String optionString) throws MaltChainedException {
		this.optionString = optionString;
		String[] argv;
		try {
			argv = optionString.split("[_\\p{Blank}]");
		} catch (PatternSyntaxException e) {
			throw new DataFormatException("Could not split the penn writer option '"+optionString+"'. ", e);
		}
		for (int i=0; i < argv.length-1; i++) {
			if(argv[i].charAt(0) != '-') {
				throw new DataFormatException("The argument flag should start with the following character '-', not with "+argv[i].charAt(0));
			}
			if(++i>=argv.length) {
				throw new DataFormatException("The last argument does not have any value. ");
			}
			switch(argv[i-1].charAt(1)) {
			case 's': 
				try {
					START_ID_OF_NONTERMINALS = Integer.parseInt(argv[i]);
				} catch (NumberFormatException e){
					throw new MaltChainedException("The TigerXML Reader option -s must be an integer value. ");
				}
				break;
			default:
				throw new LibsvmException("Unknown svm parameter: '"+argv[i-1]+"' with value '"+argv[i]+"'. ");		
			}
		}	
	}
	
	public void close() throws MaltChainedException {
		try {
			if (writer != null) {
				writer.flush();
				if (closeStream) {
					writer.close();
				}
				writer = null;
			}
		}   catch (IOException e) {
			throw new DataFormatException("Could not close the output file. ", e);
		} 
	}
}