Refactor out source data updating

This commit is contained in:
gfyoung
2017-08-13 04:22:42 -07:00
parent 1ee48fb718
commit 24ab22e139
2 changed files with 138 additions and 28 deletions

View File

@ -13,7 +13,7 @@ from updateHostsFile import (
prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
supports_color, strip_rule, update_all_sources, update_readme_data,
write_data, write_opening_header)
update_sources_data, write_data, write_opening_header)
import updateHostsFile
import unittest
@ -542,6 +542,80 @@ class TestMatchesExclusions(Base):
# Update Logic
class TestUpdateSourcesData(Base):
def setUp(self):
Base.setUp(self)
self.data_path = "data"
self.extensions_path = "extensions"
self.source_data_filename = "update.json"
self.update_kwargs = dict(datapath=self.data_path,
extensionspath=self.extensions_path,
sourcedatafilename=self.source_data_filename)
def update_sources_data(self, sources_data, extensions):
return update_sources_data(sources_data[:], extensions=extensions,
**self.update_kwargs)
@mock.patch("updateHostsFile.recursive_glob", return_value=[])
@mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
@mock.patch(builtins() + ".open", return_value=mock.Mock())
def test_no_update(self, mock_open, mock_join_robust, _):
extensions = []
sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
new_sources_data = self.update_sources_data(sources_data, extensions)
self.assertEqual(new_sources_data, sources_data)
mock_join_robust.assert_not_called()
mock_open.assert_not_called()
extensions = [".json", ".txt"]
new_sources_data = self.update_sources_data(sources_data, extensions)
self.assertEqual(new_sources_data, sources_data)
join_calls = [mock.call(self.extensions_path, ".json"),
mock.call(self.extensions_path, ".txt")]
mock_join_robust.assert_has_calls(join_calls)
mock_open.assert_not_called()
@mock.patch("updateHostsFile.recursive_glob",
side_effect=[[], ["update1.txt", "update2.txt"]])
@mock.patch("json.load", return_value={"mock_source": "mock_source.ext"})
@mock.patch(builtins() + ".open", return_value=mock.Mock())
@mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
def test_update_only_extensions(self, mock_join_robust, *_):
extensions = [".json"]
sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
new_sources_data = self.update_sources_data(sources_data, extensions)
expected = sources_data + [{"mock_source": "mock_source.ext"}] * 2
self.assertEqual(new_sources_data, expected)
self.assert_called_once(mock_join_robust)
@mock.patch("updateHostsFile.recursive_glob",
side_effect=[["update1.txt", "update2.txt"],
["update3.txt", "update4.txt"]])
@mock.patch("json.load", side_effect=[{"mock_source": "mock_source.txt"},
{"mock_source": "mock_source2.txt"},
{"mock_source": "mock_source3.txt"},
{"mock_source": "mock_source4.txt"}])
@mock.patch(builtins() + ".open", return_value=mock.Mock())
@mock.patch("updateHostsFile.path_join_robust", return_value="dirpath")
def test_update_both_pathways(self, mock_join_robust, *_):
extensions = [".json"]
sources_data = [{"source": "source1.txt"}, {"source": "source2.txt"}]
new_sources_data = self.update_sources_data(sources_data, extensions)
expected = sources_data + [{"mock_source": "mock_source.txt"},
{"mock_source": "mock_source2.txt"},
{"mock_source": "mock_source3.txt"},
{"mock_source": "mock_source4.txt"}]
self.assertEqual(new_sources_data, expected)
self.assert_called_once(mock_join_robust)
class TestUpdateAllSources(BaseStdout):
def setUp(self):

View File

@ -129,25 +129,27 @@ def main():
settings = get_defaults()
settings.update(options)
settings["sources"] = list_dir_no_hidden(settings["datapath"])
settings["extensionsources"] = list_dir_no_hidden(
settings["extensionspath"])
data_path = settings["datapath"]
extensions_path = settings["extensionspath"]
settings["sources"] = list_dir_no_hidden(data_path)
settings["extensionsources"] = list_dir_no_hidden(extensions_path)
# All our extensions folders...
settings["extensions"] = [os.path.basename(item) for item in
list_dir_no_hidden(settings["extensionspath"])]
list_dir_no_hidden(extensions_path)]
# ... intersected with the extensions passed-in as arguments, then sorted.
settings["extensions"] = sorted(list(
set(options["extensions"]).intersection(settings["extensions"])))
auto = settings["auto"]
exclusion_regexes = settings["exclusionregexs"]
source_data_filename = settings["sourcedatafilename"]
update_sources = prompt_for_update(freshen=settings["freshen"],
update_auto=auto)
if update_sources:
update_all_sources(settings["sourcedatafilename"],
settings["hostfilename"])
update_all_sources(source_data_filename, settings["hostfilename"])
gather_exclusions = prompt_for_exclusions(skip_prompt=auto)
@ -159,15 +161,19 @@ def main():
exclusion_pattern=exclusion_pattern,
exclusion_regexes=exclusion_regexes)
extensions = settings["extensions"]
sources_data = update_sources_data(settings["sourcesdata"],
datapath=data_path,
extensions=extensions,
extensionspath=extensions_path,
sourcedatafilename=source_data_filename)
merge_file = create_initial_file()
remove_old_hosts_file(settings["backup"])
extensions = settings["extensions"]
output_subfolder = settings["outputsubfolder"]
final_file = remove_dups_and_excl(merge_file, exclusion_regexes)
number_of_rules = settings["numberofrules"]
output_subfolder = settings["outputsubfolder"]
skip_static_hosts = settings["skipstatichosts"]
write_opening_header(final_file, extensions=extensions,
@ -180,7 +186,7 @@ def main():
extensions=extensions,
numberofrules=number_of_rules,
outputsubfolder=output_subfolder,
sourcesdata=settings["sourcesdata"])
sourcesdata=sources_data)
print_success("Success! The hosts file has been saved in folder " +
output_subfolder + "\nIt contains " +
@ -477,6 +483,52 @@ def matches_exclusions(stripped_rule, exclusion_regexes):
# Update Logic
def update_sources_data(sources_data, **sources_params):
"""
Update the sources data and information for each source.
Parameters
----------
sources_data : list
The list of sources data that we are to update.
sources_params : kwargs
Dictionary providing additional parameters for updating the
sources data. Currently, those fields are:
1) datapath
2) extensions
3) extensionspath
4) sourcedatafilename
Returns
-------
update_sources_data : list
The original source data list with new source data appended.
"""
source_data_filename = sources_params["sourcedatafilename"]
for source in recursive_glob(sources_params["datapath"],
source_data_filename):
update_file = open(source, "r")
update_data = json.load(update_file)
sources_data.append(update_data)
update_file.close()
for source in sources_params["extensions"]:
source_dir = path_join_robust(
sources_params["extensionspath"], source)
for update_file_path in recursive_glob(source_dir,
source_data_filename):
update_file = open(update_file_path, "r")
update_data = json.load(update_file)
sources_data.append(update_data)
update_file.close()
return sources_data
def update_all_sources(source_data_filename, host_filename):
"""
Update all host files, regardless of folder depth.
@ -534,13 +586,6 @@ def create_initial_file():
with open(source, "r") as curFile:
write_data(merge_file, curFile.read())
for source in recursive_glob(settings["datapath"],
settings["sourcedatafilename"]):
update_file = open(source, "r")
update_data = json.load(update_file)
settings["sourcesdata"].append(update_data)
update_file.close()
# spin the sources for extensions to the base file
for source in settings["extensions"]:
for filename in recursive_glob(path_join_robust(
@ -548,15 +593,6 @@ def create_initial_file():
with open(filename, "r") as curFile:
write_data(merge_file, curFile.read())
for update_file_path in recursive_glob(path_join_robust(
settings["extensionspath"], source),
settings["sourcedatafilename"]):
update_file = open(update_file_path, "r")
update_data = json.load(update_file)
settings["sourcesdata"].append(update_data)
update_file.close()
if os.path.isfile(settings["blacklistfile"]):
with open(settings["blacklistfile"], "r") as curFile:
write_data(merge_file, curFile.read())