Source code for kindred.evalFunctions


import kindred
from collections import Counter

def evaluate(goldCorpus,testCorpus,metric='f1score',display=False):
    """
    Compares the gold corpus with the test corpus and calculates appropriate metrics.

    :param goldCorpus: The gold standard set of data
    :param testCorpus: The test set for comparison
    :param metric: Which metric to use (precision/recall/f1score). 'all' will provide all three as a tuple
    :param display: Whether to print (to stdout) specific statistics for each relation type
    :type goldCorpus: kindred.Corpus
    :type testCorpus: kindred.Corpus
    :type metric: str
    :type display: bool
    :return: The value of the corresponding metric (or metrics)
    :rtype: float (or tuple of floats)
    """
    assert isinstance(goldCorpus,kindred.Corpus)
    assert isinstance(testCorpus,kindred.Corpus)

    # Both corpora must contain the same documents; only the annotations may differ
    mismatchMessage = "Mismatch between the corpora. Expected the same documents in each corpus with only annotations differing"
    assert len(goldCorpus.documents) == len(testCorpus.documents), mismatchMessage
    for d1,d2 in zip(goldCorpus.documents, testCorpus.documents):
        assert d1.text == d2.text, mismatchMessage

    TPs,FPs,FNs = Counter(),Counter(),Counter()

    # Represent each relation as a (relationType, entities) tuple so gold and test relations can be compared directly
    goldTuples = [ (r.relationType,tuple(r.entities)) for r in goldCorpus.getRelations() ]
    testTuples = [ (r.relationType,tuple(r.entities)) for r in testCorpus.getRelations() ]

    # Count true positives, false negatives and false positives per relation type
    totalSet = set(goldTuples + testTuples)
    for relation in totalSet:
        inGold = relation in goldTuples
        inTest = relation in testTuples

        relType = relation[0]

        if inGold and inTest:
            TPs[relType] += 1
        elif inGold:
            FNs[relType] += 1
        elif inTest:
            FPs[relType] += 1

    # Per-relation-type precision, recall and F1 score
    sortedRelTypes = sorted( list(set( [ relation[0] for relation in totalSet ] )))

    maxLen = max( [ len(rt) for rt in sortedRelTypes ] )
    formatString = '%-' + str(maxLen) + 's\tTP:%d FP:%d FN:%d\tP:%f R:%f F1:%f'
    for relType in sortedRelTypes:
        TP,FP,FN = TPs[relType],FPs[relType],FNs[relType]
        precision = 0.0 if (TP+FP) == 0 else TP / float(TP+FP)
        recall = 0.0 if (TP+FN) == 0 else TP / float(TP+FN)
        f1score = 0.0 if precision==0 or recall == 0 else 2 * (precision*recall) / (precision+recall)

        if display:
            print(formatString % (relType,TP,FP,FN,precision,recall,f1score))

    # Micro-averaged precision, recall and F1 score across all relation types
    TP,FP,FN = sum(TPs.values()),sum(FPs.values()),sum(FNs.values())
    precision = 0.0 if (TP+FP) == 0 else TP / float(TP+FP)
    recall = 0.0 if (TP+FN) == 0 else TP / float(TP+FN)
    f1score = 0.0 if precision==0 or recall == 0 else 2 * (precision*recall) / (precision+recall)

    if display:
        print("-"*50)
        print(formatString % ("All",TP,FP,FN,precision,recall,f1score))

    if metric == 'f1score':
        return f1score
    elif metric == 'precision':
        return precision
    elif metric == 'recall':
        return recall
    elif metric == 'all':
        return precision,recall,f1score
    else:
        raise RuntimeError('Unknown metric: %s' % metric)
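
A short usage sketch follows. It assumes two kindred.Corpus objects have already been built elsewhere (for example, a gold-annotated corpus and a copy of the same documents with predicted relations); the loadGoldCorpus and loadPredictedCorpus helpers are hypothetical placeholders and are not part of kindred. Only the evaluate calls reflect this module.

import kindred
from kindred.evalFunctions import evaluate

# Placeholder helpers: however the corpora are obtained, both must be
# kindred.Corpus objects over the same documents, differing only in
# their relation annotations.
goldCorpus = loadGoldCorpus()        # hypothetical, not part of kindred
testCorpus = loadPredictedCorpus()   # hypothetical, not part of kindred

# Default behaviour: return the micro-averaged F1 score across all relation types.
f1 = evaluate(goldCorpus, testCorpus)

# 'all' returns (precision, recall, f1score); display=True also prints
# a per-relation-type breakdown of TP/FP/FN counts and scores.
precision, recall, f1score = evaluate(goldCorpus, testCorpus, metric='all', display=True)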