2.py (1181B) [raw]
#!/usr/bin/env python3
"""Count antinode cells on a square character grid read from stdin.

Every character other than '.' marks an antenna; antennas that share a
character form a group. For each pair of antennas in a group, every grid
cell lying exactly on the straight line through the pair is an antinode.
The script prints the number of distinct antinode cells.
"""
import sys
from collections import defaultdict
from itertools import combinations
from math import gcd


def parse():
    """Read the grid from stdin and group antenna positions by character.

    Blank lines (for example a trailing empty line at EOF) are skipped so
    they do not inflate the grid size.

    Returns:
        (ants, n): ``ants`` maps each non-'.' character to the list of
        ``(row, col)`` positions where it appears, in reading order; ``n`` is
        the number of non-blank rows, used as the side length of the grid.
    """
    ants = defaultdict(list)
    n = 0
    for line in sys.stdin:
        # rstrip only: stripping the left side as well would shift the
        # column indices away from the characters actually read.
        row = line.rstrip()
        if not row:
            continue
        for col, c in enumerate(row):
            if c != '.':
                ants[c].append((n, col))
        n += 1
    return ants, n


def onmap(i, j, n):
    """Return True if ``(i, j)`` lies inside the ``n`` x ``n`` grid."""
    return 0 <= i < n and 0 <= j < n


def get_anodes(a, b, n):
    """Return every in-grid cell on the line through positions ``a`` and ``b``.

    Args:
        a, b: ``(row, col)`` positions of two antennas.
        n: side length of the square grid.

    Returns:
        Set of ``(row, col)`` cells on the line, including ``a`` and ``b``
        themselves when they are on the grid. If ``a == b`` the result is
        just that point (or empty if it is off the grid).
    """
    dy = b[0] - a[0]
    dx = b[1] - a[1]
    g = gcd(dy, dx)
    if g == 0:
        # a == b: the line degenerates to a single point; without this
        # guard the walks below would never advance.
        return {a} if onmap(a[0], a[1], n) else set()
    # Reduce the offset to the smallest integer step along the line so that
    # every lattice cell on it is visited, not only multiples of the full
    # a->b spacing.
    dy //= g
    dx //= g
    anodes = set()
    # Walk outward in both directions from each endpoint. Either endpoint
    # alone suffices when it is on the grid; using both keeps the result
    # complete if one of them is off the grid.
    for start in (a, b):
        for sy, sx in ((dy, dx), (-dy, -dx)):
            i, j = start
            while onmap(i, j, n):
                anodes.add((i, j))
                i += sy
                j += sx
    return anodes


def count(ants, n):
    """Return the number of distinct antinode cells over all antenna groups.

    Args:
        ants: mapping from antenna character to its list of ``(row, col)``
            positions, as produced by ``parse``.
        n: side length of the square grid.

    Returns:
        Size of the union of ``get_anodes`` over every unordered pair of
        positions within each group. Groups with fewer than two positions
        contribute nothing.
    """
    locs = set()
    for nodes in ants.values():
        for a, b in combinations(nodes, 2):
            locs |= get_anodes(a, b, n)
    return len(locs)


if __name__ == '__main__':
    # Input is assumed to be a square n x n grid; n comes from the row count.
    ants, n = parse()
    print(count(ants, n))