1.py (1021B) [raw]
#!/usr/bin/env python3
import sys
from collections import defaultdict
from itertools import combinations


def _read_grid(stream):
    """Scan grid lines and return ``(ants, rows, cols)``.

    ``ants`` maps each antenna character to the list of ``(row, col)``
    positions where it appears; every character other than '.' is an antenna.
    ``rows`` counts the non-blank lines read. ``cols`` is the length of the
    longest such line (both are 0 for empty input).

    Blank lines are skipped entirely, so they do not contribute to ``rows``.
    """
    ants = defaultdict(list)
    rows = cols = 0
    for line in stream:
        # rstrip() removes the newline (and any '\r') while keeping column
        # indices aligned with the characters actually inspected below.
        row = line.rstrip()
        if not row:
            # e.g. a trailing empty line: counting it would inflate the height
            continue
        for j, c in enumerate(row):
            if c != '.':
                ants[c].append((rows, j))
        cols = max(cols, len(row))
        rows += 1
    return ants, rows, cols


def parse(stream=None):
    """Read the antenna grid and group antenna positions by character.

    Args:
        stream: iterable of text lines; defaults to ``sys.stdin``.

    Returns:
        ``(ants, n)``: ``ants`` maps each antenna character to a list of
        ``(row, col)`` positions, and ``n`` is the number of non-blank rows.
    """
    ants, rows, _ = _read_grid(sys.stdin if stream is None else stream)
    return ants, rows


def onmap(i, j, n, m=None):
    """Return True if ``(i, j)`` lies inside an ``n``-row by ``m``-column grid.

    ``m`` defaults to ``n`` (a square grid).
    """
    cols = n if m is None else m
    return 0 <= i < n and 0 <= j < cols


def get_anodes(a, b, n, m=None):
    """Yield the on-map antinodes of the antenna pair ``a``, ``b``.

    The candidates are ``b + (b - a)`` and ``a - (b - a)``: the points that
    extend the segment a->b by its own length on either side. Only
    candidates inside the ``n`` x ``m`` grid (``m`` defaults to ``n``) are
    yielded.
    """
    dy = b[0] - a[0]
    dx = b[1] - a[1]
    cands = (
        (b[0] + dy, b[1] + dx),
        (a[0] - dy, a[1] - dx),
    )
    return (p for p in cands if onmap(p[0], p[1], n, m))


def count(ants, n, m=None):
    """Return the number of distinct on-map antinode locations.

    Every unordered pair of same-character antennas contributes its
    antinodes. Locations produced by several pairs or characters are counted
    once.

    Args:
        ants: mapping of antenna character -> list of ``(row, col)``.
        n: number of grid rows.
        m: number of grid columns; defaults to ``n`` (square grid).
    """
    locs = set()
    for nodes in ants.values():
        for a, b in combinations(nodes, 2):
            locs.update(get_anodes(a, b, n, m))
    return len(locs)


if __name__ == '__main__':
    # Track height and width separately so non-square inputs also work.
    ants, rows, cols = _read_grid(sys.stdin)
    print(count(ants, rows, cols))