# $Id: TestAssertions.py,v 1.11 2010/02/19 20:14:13 eharrison Exp $
# ------------------------
#
"""Assertion helpers to mix into unittest.TestCase subclasses, a suite
builder that collects TestCase classes from a namespace, and a couple of
debugging helpers (DOM outline printer, pretty printer).

Written to run unchanged on Python 2 and Python 3.
"""
import sys
import types  # kept: historically exported by this module
import unittest
import xml.dom
from bz2 import compress
from bz2 import decompress
from pickle import dumps
from pickle import loads

try:  # Python 2
    from urllib import quote_plus, unquote_plus
    from urllib import *  # historical wildcard re-export; other modules may rely on it

    def _unquote_plus_bytes(string):
        # Python 2 strings are byte strings, so unquote_plus is binary-safe.
        return unquote_plus(string)
except ImportError:  # Python 3
    from urllib.parse import quote_plus, unquote_plus, unquote_to_bytes

    def _unquote_plus_bytes(string):
        # unquote_plus() would decode the payload as UTF-8 and corrupt the
        # bz2 bytes; quote_plus() encodes spaces as '+', so undo that first.
        return unquote_to_bytes(string.replace("+", " "))


def unpack(string):
    """Inverse of pack(): URL-unquote, bz2-decompress and unpickle *string*.

    SECURITY: pickle.loads can execute arbitrary code embedded in the
    payload. Only call this on strings produced by pack() from a trusted
    source, never on data received from users or the network.
    """
    return loads(decompress(_unquote_plus_bytes(string)))


def pack(data):
    """Pickle, bz2-compress and quote_plus-encode *data* into a URL-safe string."""
    return quote_plus(compress(dumps(data)))


class TestAssertions:
    """Structural-equality assertions for lists, dicts, DOM trees and money.

    This mix-in MUST be combined with a unittest.TestCase (or anything else
    providing assertEqual/assertTrue/assertFalse/fail); it defines none of
    those itself.
    """

    ABBREVIATION_LENGTH = 25

    def abb(self, obj):
        """Return str(obj), cut to ABBREVIATION_LENGTH chars plus '...' if longer."""
        s = str(obj)
        if len(s) > self.ABBREVIATION_LENGTH:
            return "%s..." % s[:self.ABBREVIATION_LENGTH]
        return s

    def __str__(self):
        # Override the default test-name printout with the dotted id, which
        # can be copy/pasted to re-run exactly one test.
        if hasattr(self, "id"):
            return self.id()
        return "CANNOT FIND self.id(): %s" % self.__class__

    def findDivergence(self, l1, l2):
        """Describe the first index at which sequences *l1* and *l2* differ.

        If one sequence is a prefix of the other, the reported location is
        the length of the shorter one (the first index only one side has).
        """
        shared = min(len(l1), len(l2))
        for i in range(shared):
            if l1[i] != l2[i]:
                return "(loc %d), %s != %s" % (i, l1[i], l2[i])
        return "(loc %d), %s != %s" % (shared, self.abb(l1), self.abb(l2))

    def assertEqualLists(self, expected, actual, msg=""):
        """Assert two plain lists are equal, recursing into nested lists/dicts.

        Both arguments must be exactly of type list. *msg* prefixes failure
        messages and grows with "[index]" on recursion.
        """
        self.assertEqual(type(expected), list)
        self.assertEqual(type(actual), list)
        self.assertEqual(len(expected), len(actual),
                         "%s: length e'%d' != a'%d', divergence=%s, **DETAILS**:e%s!=a%s"
                         % (msg, len(expected), len(actual),
                            self.findDivergence(expected, actual), expected, actual))
        for i, (ev, av) in enumerate(zip(expected, actual)):
            sub = msg + "[%s]" % (i,)
            if type(ev) is dict:
                self.assertEqualDictionaries(ev, av, msg=sub)
            elif type(ev) is list:
                self.assertEqualLists(ev, av, msg=sub)
            else:
                self.assertEqual(ev, av, "%s (%d): e'%s' != a'%s'" % (msg, i, str(ev), str(av)))

    def convertRealDictRowToDictionary(self, item):
        """Copy any mapping exposing keys()/[] (e.g. a RealDictRow) into a plain dict."""
        return dict((k, item[k]) for k in item.keys())

    def assertEqualDictionaries(self, expected, actual, msg=""):
        """Assert two plain dicts are equal, recursing into nested lists/dicts.

        Values whose class is named RealDictRow (psycopg2's dict-like row
        type, going by the helper's name) are first copied into plain dicts
        so the exact-type checks below accept them. *msg* prefixes failure
        messages and grows with "[key]" on recursion.
        """
        # The previous check compared str(type(x)) with "", which is never
        # true, so RealDictRows were never converted.
        if type(actual).__name__ == "RealDictRow":
            actual = self.convertRealDictRowToDictionary(actual)
        if type(expected).__name__ == "RealDictRow":
            expected = self.convertRealDictRowToDictionary(expected)
        self.assertEqual(type(expected), dict)
        self.assertEqual(type(actual), dict)
        expected_keys = set(expected.keys())
        actual_keys = set(actual.keys())
        self.assertEqual(len(expected), len(actual),
                         "%s: length e'%d' != a'%d' ak-ek=%s, ek-ak=%s"
                         % (msg, len(expected), len(actual),
                            list(actual_keys - expected_keys),
                            list(expected_keys - actual_keys)))
        for key in expected.keys():
            self.assertTrue(key in actual,
                            "%s expected key '%s' is missing %s != %s"
                            % (msg, key, str(actual), str(expected)))
            ev = expected[key]
            av = actual[key]
            # (key,) so tuple keys are formatted instead of raising TypeError.
            sub = msg + "[%s]" % (key,)
            if type(ev) is dict:
                self.assertEqualDictionaries(ev, av, msg=sub)
            elif type(ev) is list:
                self.assertEqualLists(ev, av, msg=sub)
            else:
                self.assertEqual(ev, av, "%s[%s] expected='%s'(%s), actual='%s'(%s)"
                                 % (msg, key, ev, type(ev), av, type(av)))

    def assertEqualAttributes(self, attrs1, attrs2, path=""):
        """Assert two DOM attribute maps hold the same names and node values.

        Either both maps are None, or both must exist. Key order does not
        matter. *path* is appended to failure messages.
        """
        if (attrs1 is None) != (attrs2 is None):
            self.fail("one of the attributes sets is null, and not the other one: %s" % path)
        if attrs1 is None:  # both None
            return
        # Sort so that attribute order in the source documents is irrelevant.
        keys1 = sorted(attrs1.keys())
        keys2 = sorted(attrs2.keys())
        self.assertEqual(keys1, keys2, "mismatched keys: %s, %s: %s" % (keys1, keys2, path))
        for key in keys1:
            value1 = attrs1[key].nodeValue
            value2 = attrs2[key].nodeValue
            self.assertEqual(value1, value2,
                             "wrong value, '%s' != '%s': %s" % (value1, value2, path))

    def assertEqualDoms(self, dom1, dom2, path=""):
        """Recursively assert two DOM nodes match.

        Compares attributes, node types, whitespace-stripped node values and
        children. *path* accumulates "/nodeName" segments for failure messages.
        """
        self.assertEqualAttributes(dom1.attributes, dom2.attributes, path)
        self.assertEqual(dom1.nodeType, dom2.nodeType,
                         "wrong types '%s' != '%s': %s" % (dom1.nodeType, dom2.nodeType, path))
        v1 = ("%s" % dom1.nodeValue).strip()
        v2 = ("%s" % dom2.nodeValue).strip()
        self.assertEqual(v1, v2, "wrong values '%s' != '%s': %s" % (v1, v2, path))
        children1 = dom1.childNodes
        children2 = dom2.childNodes
        if (children1 is None) != (children2 is None):
            self.fail("one of the child sets is null, and not the other one: %s" % path)
        if children1 is None:  # both None: nothing further to compare
            return
        self.assertEqual(len(children1), len(children2),
                         "childcount, there are %d in the first set %s, and %d in the second set %s: %s"
                         % (len(children1), children1, len(children2), children2, path))
        child_path = "%s/%s" % (path, dom1.nodeName)
        for child1, child2 in zip(children1, children2):
            self.assertEqualDoms(child1, child2, child_path)

    def assertEqualMoney(self, expectedValue, actualValue):
        """Assert *actualValue* equals one of *expectedValue* to the cent.

        *expectedValue* may be a single value or a list/tuple of acceptable
        values. Everything is converted with float() and rounded to 2
        decimals. The caller's list is not modified.

        NOTE: round() rounds halves away from zero on Python 2 but to even on
        Python 3, so exact half-cent inputs may round differently.
        """
        self.assertFalse(expectedValue is None, "expected arg is None!")
        self.assertFalse(actualValue is None, "actual arg is None!")
        if not isinstance(expectedValue, (list, tuple)):
            expectedValue = [expectedValue]
        # Build a new list: the old code rounded the caller's list in place.
        expectedCents = [round(float(v) * 100) / 100.0 for v in expectedValue]
        actualCents = round(float(actualValue) * 100) / 100.0
        self.assertTrue(actualCents in expectedCents,
                        "%s must be one of %s" % (actualCents, expectedCents))

    def extractNodeValue(self, value):
        """Return the first non-empty text value found in *value*.

        - None stays None.
        - Strings are stripped and returned as str.
        - For a node, leaf values (nodeValue) and child values are searched
          depth-first.
        - Any other object is returned unchanged.
        - Returns None when nothing non-empty is found.
        """
        if value is None:
            return None
        if type(value) in (type(""), type(u"")):
            return str(value.strip())
        possible = []
        if hasattr(value, 'childNodes'):
            if len(value.childNodes) > 0:
                possible.extend([self.extractNodeValue(c) for c in value.childNodes])
            else:
                leaf = value.nodeValue
                if type(leaf) == type(u""):
                    leaf = str(leaf)
                possible.append(leaf)
        else:
            possible.append(value)
        for p in possible:
            if p is not None and p != "":
                return p
        return None


class DynamicTestSuite(list):
    """Automatically gather the unittest.TestCase classes found in a namespace.

    modulename: the name of the file where the tests will be found; a full
        path such as __file__ is accepted (directory and extension are
        stripped). The result prefixes every test name given to the loader.
    namespace: a mapping of name => object to scan, usually the caller's
        globals() (or locals()).
    allowfrom: module name(s) a class's __module__ must match to be
        collected. A single string is treated as a one-element tuple. The
        default "__main__" expands to (modulename, "__main__"), since tests
        are usually run from the command line or from AllTests.
    skiptests: names in *namespace* to leave out.
    """

    def __init__(self, modulename, namespace, allowfrom="__main__", skiptests=()):
        # If the caller passed __file__, reduce it to the bare module name.
        self.modulename = self.parseFileName(modulename)
        self.namespace = namespace
        if allowfrom == "__main__":
            # Default: also allow classes defined in the module itself.
            allowfrom = (self.modulename, "__main__")
        elif isinstance(allowfrom, (type(""), type(u""))):
            # A bare string would otherwise be matched by substring via `in`.
            allowfrom = (allowfrom,)
        self.allowfrom = allowfrom
        self.skiptests = skiptests

    def parseFileName(self, file):
        """Strip any '/' or '\\' directory prefix and the last extension from *file*."""
        slash = max(file.rfind("/"), file.rfind("\\"))
        if slash > -1:
            file = file[slash + 1:]
        dot = file.rfind(".")
        if dot > -1:
            file = file[:dot]
        return file

    def makeSuite(self):
        """Load every eligible TestCase class in the namespace as a test suite."""
        tests = []
        for key, candidate in self.namespace.items():
            if (isinstance(candidate, type)
                    and candidate.__module__ in self.allowfrom
                    and key not in self.skiptests
                    and issubclass(candidate, unittest.TestCase)):
                tests.append(self.modulename + "." + key)
        return unittest.defaultTestLoader.loadTestsFromNames(tests)


def _writeOrFallback(text, fallback):
    """Write *text* to stdout, or *fallback* if the stream cannot encode it."""
    try:
        sys.stdout.write(text)
    except UnicodeEncodeError:
        sys.stdout.write(fallback)


def printDomTree(node, tab="", show_labels=False):
    """Write an indented outline of a DOM tree to sys.stdout (debugging aid).

    Elements named 'label' (and their subtrees) are skipped unless
    *show_labels* is true.
    """
    if not node:
        return
    if xml.dom.Node.ELEMENT_NODE == node.nodeType:
        if not show_labels and node.localName == 'label':
            return
        sys.stdout.write(tab + node.localName)
        # hasAttributes is a method; the old code tested the bound method
        # (always truthy) and printed " []" for attribute-less elements.
        if node.hasAttributes():
            sys.stdout.write(" [")
            for i in range(node.attributes.length):
                attr = node.attributes.item(i)
                sys.stdout.write(str(attr.localName) + "='")
                sys.stdout.write(str(attr.value) + "', ")
            sys.stdout.write("]")
        # Direct text content is shown on the element's own line.
        for child in node.childNodes:
            if xml.dom.Node.TEXT_NODE == child.nodeType:
                s = child.nodeValue.strip()
                if len(s) > 0:
                    _writeOrFallback(tab + "#" + s + "#", tab + "#<>#")
        sys.stdout.write("\n")
        # NOTE(review): text children are printed again by this recursion
        # (TEXT_NODE branch below); confirm whether that duplication is intended.
        for child in node.childNodes:
            printDomTree(child, tab + " ", show_labels)
    elif xml.dom.Node.DOCUMENT_NODE == node.nodeType:
        sys.stdout.write("DOCUMENT ROOT **\n")
        for child in node.childNodes:
            printDomTree(child, tab + " ", show_labels)
    elif xml.dom.Node.TEXT_NODE == node.nodeType:
        s = node.nodeValue.strip()
        if len(s) > 0:
            _writeOrFallback(tab + "#" + s + "#\n", tab + "#<>#\n")


def same_type(obj1, obj2):
    """Return True when *obj1* and *obj2* have exactly the same type."""
    return type(obj1) == type(obj2)


def pretty_str(obj, prefix=""):
    """Render dicts, lists/tuples and strings as a multi-line, indented string.

    Other objects fall back to str(). *prefix* is the indentation used for
    nested levels.
    """
    if same_type({}, obj):
        h = "\n" + prefix + "{\n" + prefix
        for k, v in obj.items():
            # Indent the key; the old code prepended " " to the whole string.
            h = h + " '%s':" % (k,) + pretty_str(v, prefix + " ") + ", \n" + prefix
        h = h + "\n" + prefix + "}"
    elif same_type([], obj) or same_type(tuple(), obj):
        if len(obj) < 1:
            h = "[]"
        else:
            h = "[\n" + prefix
            for item in obj:
                h = h + pretty_str(item, prefix + " ") + ", \n" + prefix
            h = h + "\n" + prefix + "]"
    elif same_type("", obj):
        h = "\"" + obj + "\""
    else:
        h = str(obj)
    return h