There doesn’t seem to be a built-in function for flattening a nested list in Python, so I implemented one.
List containing lists that have primitive values
If the inner lists don’t contain further nested lists, a double for loop is the simplest way to flatten the list.
list_of_list = [range(5), range(3)]
flat_list = []
for x in list_of_list:
    for y in x:
        flat_list.append(y)
print(flat_list)
# [0, 1, 2, 3, 4, 0, 1, 2]
As a one-liner:
print([y for x in list_of_list for y in x])
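For this one-level case, the standard library’s itertools.chain.from_iterable does the same job without an explicit loop. It yields the items of each inner iterable in turn, so wrapping it in list() gives the same flat list as the loop above. Here is a minimal sketch reusing the same input:

```python
from itertools import chain

list_of_list = [range(5), range(3)]

# chain.from_iterable lazily yields every item of each inner iterable in order,
# which removes exactly one level of nesting.
flat_list = list(chain.from_iterable(list_of_list))
print(flat_list)
# [0, 1, 2, 3, 4, 0, 1, 2]
```

Like the comprehension, this only removes one level of nesting.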
But this solution removes only one level of nesting, so it doesn’t fully flatten a list that contains a nested list.
list_of_list = [range(5), [range(3), range(3)]]
print([y for x in list_of_list for y in x])
# [0, 1, 2, 3, 4, range(0, 3), range(0, 3)]
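The comprehension does raise TypeError when one of the top-level elements is not iterable at all. In that case the inner "for y in x" clause has nothing to iterate over. Here is a small sketch with a made-up input (the bare int 3 is hypothetical):

```python
list_of_list = [range(5), 3]  # hypothetical input: the second element is a bare int

message = None
try:
    print([y for x in list_of_list for y in x])
except TypeError as e:
    # Iterating over the int 3 in the inner "for y in x" clause fails.
    message = str(e)
    print(message)
# 'int' object is not iterable
```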
List containing lists that have nested lists
Let’s define the following complicated list.
TEST_DATASET = [
    [[1, 2, 3], [4, 5], [], [6]],
    [
        [1,
            [
                2, 3, [4, 5], [6], 7
            ]
        ],
        [22, 33, [44, 55]]
    ],
    [1, 2, [3, 4], [[5, 6], [7, 8]], [9, 10]],
    [
        [],
        [
            [
                [
                    [[11, 22], [33, 44]], [55, "66", [77]]
                ],
                [
                    ["first", ["BB"], []], [["CC", ["DD", "FF"]]]
                ]
            ],
            [
                ["99", "88"],
            ]
        ],
        [
            ["E", "F"],
            [55, [66, [77]]]
        ]
    ]
]
Let’s try to flatten this list.
Recursive call with for-in loop and isinstance check
The first solution is the following.
def flat(element) -> list:
    # If no item is a list, the input is already flat.
    has_list = any(isinstance(x, list) for x in element)
    if not has_list:
        return element
    flatten_list = []
    for x in element:
        if isinstance(x, list):
            # Flatten the sublist recursively and splice its items in.
            val = flat(x)
            flatten_list.extend(val)
        else:
            flatten_list.append(x)
    return flatten_list
The result for each element of TEST_DATASET:
for x in TEST_DATASET:
    print(flat(x))
# [1, 2, 3, 4, 5, 6]
# [1, 2, 3, 4, 5, 6, 7, 22, 33, 44, 55]
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# [11, 22, 33, 44, 55, '66', 77, 'first', 'BB', 'CC', 'DD', 'FF', '99', '88', 'E', 'F', 55, 66, 77]
The result for the whole list
print(flat(TEST_DATASET))
# [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 22, 33, 44, 55, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22, 33, 44, 55, '66', 77, 'first', 'BB', 'CC', 'DD', 'FF', '99', '88', 'E', 'F', 55, 66, 77]
Note that flat raises a TypeError if its argument is not iterable (for example, flat(5)), because the any() check tries to iterate over it.
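One way around this is the hypothetical generator iter_flat below, which is not part of the original solutions. It uses yield from instead of building intermediate lists, and it treats any non-list argument as a single leaf. As a result, iter_flat(5) yields 5 instead of raising. As in flat, only list counts as nesting, so tuples, ranges and strings are kept as leaves.

```python
from typing import Any, Iterator


def iter_flat(element: Any) -> Iterator[Any]:
    """Yield the leaves of an arbitrarily nested list, depth first."""
    if isinstance(element, list):
        for x in element:
            # Delegate to the recursive call for each child.
            yield from iter_flat(x)
    else:
        # A non-list value is a leaf, so it is yielded as-is.
        yield element


print(list(iter_flat([1, [2, [3, [4]]], "ab"])))
# [1, 2, 3, 4, 'ab']
print(list(iter_flat(5)))
# [5]
```

Because it is a generator, the caller decides whether to materialize the result with list() or to consume the leaves one at a time.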
Recursive call with square brackets and colon
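The second solution uses slicing (square brackets with a colon). It splits the list into its first element, element[0], and the rest, element[1:], then flattens each part recursively.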
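def flat2(element) -> list:
    # Base case: an empty list is already flat.
    if element == []:
        return element
    # The head is a nested list: flatten it and the tail, then concatenate.
    if isinstance(element[0], list):
        return flat2(element[0]) + flat2(element[1:])
    # The head is a leaf: keep it as a one-item list and flatten the tail.
    return element[:1] + flat2(element[1:])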
for x in TEST_DATASET:
    print(flat2(x))
# [1, 2, 3, 4, 5, 6]
# [1, 2, 3, 4, 5, 6, 7, 22, 33, 44, 55]
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# [11, 22, 33, 44, 55, '66', 77, 'first', 'BB', 'CC', 'DD', 'FF', '99', '88', 'E', 'F', 55, 66, 77]
print(flat2(TEST_DATASET))
# [1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 7, 22, 33, 44, 55, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 22, 33, 44, 55, '66', 77, 'first', 'BB', 'CC', 'DD', 'FF', '99', '88', 'E', 'F', 55, 66, 77]
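One caveat: flat2 makes a recursive call for every element and copies the tail with each slice. On a long flat list it can therefore exceed Python’s recursion limit, which is 1000 by default in CPython. The sketch below compares it with flat_iterative, a hypothetical stack-based alternative that is not from the original. flat_iterative keeps an explicit stack of list iterators instead of recursing.

```python
def flat2(element) -> list:
    if element == []:
        return element
    if isinstance(element[0], list):
        return flat2(element[0]) + flat2(element[1:])
    return element[:1] + flat2(element[1:])


def flat_iterative(element: list) -> list:
    """Flatten nested lists with an explicit stack instead of recursion."""
    result = []
    # Each stack entry is an iterator over one (sub)list still being processed.
    stack = [iter(element)]
    while stack:
        for x in stack[-1]:
            if isinstance(x, list):
                # Descend into the sublist; the parent iterator resumes later.
                stack.append(iter(x))
                break
            result.append(x)
        else:
            # The innermost iterator is exhausted, so go back up one level.
            stack.pop()
    return result


long_list = list(range(3000))  # longer than CPython's default recursion limit of 1000

recursion_error = False
try:
    flat2(long_list)
except RecursionError:
    recursion_error = True
print("flat2 raised RecursionError:", recursion_error)
# flat2 raised RecursionError: True  (with the default limit)

print(len(flat_iterative(long_list)))
# 3000
print(flat_iterative([1, [2, [3, [4, 5]], 6], [], [[7]]]))
# [1, 2, 3, 4, 5, 6, 7]
```

The trade-off is readability. The recursive versions mirror the structure of the data more directly, while the iterative one is limited only by memory.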
If you are not familiar with slicing a list with square brackets and a colon, check the following article.
The recursive process looks like the following. Each yellow circle is a leaf, and the result is obtained by concatenating all the leaves in order.
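The same recursion can also be shown in text form. The hypothetical flat2_traced below repeats flat2’s logic but records each call, indented by its recursion depth, so the head/tail split is visible. The leaves are the one-item element[:1] slices, and concatenating them gives the result.

```python
from typing import List, Optional


def flat2_traced(element: list, depth: int = 0, log: Optional[List[str]] = None) -> list:
    """Same logic as flat2, but records each call indented by recursion depth."""
    if log is None:
        log = []
    log.append("  " * depth + f"flat2({element!r})")
    if element == []:
        return element
    if isinstance(element[0], list):
        # Split into the nested head and the tail; both are flattened recursively.
        return flat2_traced(element[0], depth + 1, log) + flat2_traced(element[1:], depth + 1, log)
    # element[:1] is a one-item list holding a leaf.
    return element[:1] + flat2_traced(element[1:], depth + 1, log)


log: List[str] = []
result = flat2_traced([1, [2, 3]], log=log)
print(result)
# [1, 2, 3]
print("\n".join(log))
# flat2([1, [2, 3]])
#   flat2([[2, 3]])
#     flat2([2, 3])
#       flat2([3])
#         flat2([])
#     flat2([])
```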