Merge the resolv.conf branch in

preparation for having a new branch
for all of these changes.
This commit is contained in:
Joshua Harlow 2012-10-10 14:03:08 -07:00
commit f897623f57
6 changed files with 422 additions and 40 deletions

View File

@ -33,6 +33,8 @@ from cloudinit import log as logging
from cloudinit import ssh_util
from cloudinit import util
from cloudinit.distros import helpers
LOG = logging.getLogger(__name__)
@ -116,42 +118,43 @@ class Distro(object):
return "127.0.0.1"
def update_etc_hosts(self, hostname, fqdn):
# Format defined at
# http://unixhelp.ed.ac.uk/CGI/man-cgi?hosts
header = ''
if os.path.exists('/etc/hosts'):
eh = helpers.HostsConf(util.load_file("/etc/hosts"))
else:
eh = helpers.HostsConf('')
header = "# Added by cloud-init"
real_header = "%s on %s" % (header, util.time_rfc2822())
header = "%s on %s" % (header, util.time_rfc2822())
local_ip = self._get_localhost_ip()
hosts_line = "%s\t%s %s" % (local_ip, fqdn, hostname)
new_etchosts = StringIO()
need_write = False
prev_info = eh.get_entry(local_ip)
need_change = False
if not prev_info:
eh.add_entry(local_ip, fqdn, hostname)
need_change = True
hosts_ro_fn = self._paths.join(True, "/etc/hosts")
for line in util.load_file(hosts_ro_fn).splitlines():
if line.strip().startswith(header):
continue
if not line.strip() or line.strip().startswith("#"):
new_etchosts.write("%s\n" % (line))
continue
split_line = [s.strip() for s in line.split()]
if len(split_line) < 2:
new_etchosts.write("%s\n" % (line))
continue
(ip, hosts) = split_line[0], split_line[1:]
if ip == local_ip:
if sorted([hostname, fqdn]) == sorted(hosts):
else:
need_change = True
for entry in prev_info:
if sorted(entry) == sorted([fqdn, hostname]):
# Exists already, leave it be
need_change = False
break
if need_change:
line = "%s\n%s" % (real_header, hosts_line)
need_change = False
need_write = True
new_etchosts.write("%s\n" % (line))
# Doesn't exist, change the first
# entry to be this entry
new_entries = list(prev_info)
new_entries[0] = [fqdn, hostname]
eh.del_entries(local_ip)
for entry in new_entries:
if len(entry) == 1:
eh.add_entry(local_ip, entry[0])
elif len(entry) >= 2:
eh.add_entry(local_ip, *entry)
if need_change:
new_etchosts.write("%s\n%s\n" % (real_header, hosts_line))
need_write = True
if need_write:
contents = new_etchosts.getvalue()
util.write_file(self._paths.join(False, "/etc/hosts"),
contents, mode=0644)
contents = StringIO()
if header:
contents.write("%s\n" % (header))
contents.write("%s\n" % (eh))
util.write_file("/etc/hosts", contents.getvalue(), mode=0644)
def _bring_up_interface(self, device_name):
cmd = ['ifup', device_name]

View File

@ -0,0 +1,252 @@
# vi: ts=4 expandtab
#
# Copyright (C) 2012 Canonical Ltd.
# Copyright (C) 2012 Yahoo! Inc.
#
# Author: Scott Moser <scott.moser@canonical.com>
# Author: Joshua Harlow <harlowja@yahoo-inc.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 3, as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from StringIO import StringIO
from cloudinit import util
def _chop_comment(text, comment_chars):
comment_locations = [text.find(c) for c in comment_chars]
comment_locations = [c for c in comment_locations if c != -1]
if not comment_locations:
return (text, '')
min_comment = min(comment_locations)
before_comment = text[0:min_comment]
comment = text[min_comment:]
return (before_comment, comment)
# See: man hosts
# or http://unixhelp.ed.ac.uk/CGI/man-cgi?hosts
class HostsConf(object):
def __init__(self, text):
self._text = text
self._contents = None
def parse(self):
if self._contents is None:
self._contents = self._parse(self._text)
def get_entry(self, ip):
self.parse()
options = []
for (line_type, components) in self._contents:
if line_type == 'option':
(pieces, _tail) = components
if len(pieces) and pieces[0] == ip:
options.append(pieces[1:])
return options
def del_entries(self, ip):
self.parse()
n_entries = []
for (line_type, components) in self._contents:
if line_type != 'option':
n_entries.append((line_type, components))
continue
else:
(pieces, _tail) = components
if len(pieces) and pieces[0] == ip:
pass
elif len(pieces):
n_entries.append((line_type, list(components)))
self._contents = n_entries
def add_entry(self, ip, canonical_hostname, *aliases):
self.parse()
self._contents.append(('option',
([ip, canonical_hostname] + list(aliases), '')))
def _parse(self, contents):
entries = []
for line in contents.splitlines():
if not len(line.strip()):
entries.append(('blank', [line]))
continue
(head, tail) = _chop_comment(line.strip(), '#')
if not len(head):
entries.append(('all_comment', [line]))
continue
entries.append(('option', [head.split(None), tail]))
return entries
def __str__(self):
self.parse()
contents = StringIO()
for (line_type, components) in self._contents:
if line_type == 'blank':
contents.write("%s\n")
elif line_type == 'all_comment':
contents.write("%s\n" % (components[0]))
elif line_type == 'option':
(pieces, tail) = components
pieces = [str(p) for p in pieces]
pieces = "\t".join(pieces)
contents.write("%s%s\n" % (pieces, tail))
return contents.getvalue()
# See: man resolv.conf
class ResolvConf(object):
def __init__(self, text):
self._text = text
self._contents = None
def parse(self):
if self._contents is None:
self._contents = self._parse(self._text)
@property
def nameservers(self):
self.parse()
return self._retr_option('nameserver')
@property
def local_domain(self):
self.parse()
dm = self._retr_option('domain')
if dm:
return dm[0]
return None
@property
def search_domains(self):
self.parse()
current_sds = self._retr_option('search')
flat_sds = []
for sdlist in current_sds:
for sd in sdlist.split(None):
if sd:
flat_sds.append(sd)
return flat_sds
def __str__(self):
self.parse()
contents = StringIO()
for (line_type, components) in self._contents:
if line_type == 'blank':
contents.write("\n")
elif line_type == 'all_comment':
contents.write("%s\n" % (components[0]))
elif line_type == 'option':
(cfg_opt, cfg_value, comment_tail) = components
line = "%s %s" % (cfg_opt, cfg_value)
if len(comment_tail):
line += comment_tail
contents.write("%s\n" % (line))
return contents.getvalue()
def _retr_option(self, opt_name):
found = []
for (line_type, components) in self._contents:
if line_type == 'option':
(cfg_opt, cfg_value, _comment_tail) = components
if cfg_opt == opt_name:
found.append(cfg_value)
return found
def add_nameserver(self, ns):
self.parse()
current_ns = self._retr_option('nameserver')
new_ns = list(current_ns)
new_ns.append(str(ns))
new_ns = util.uniq_list(new_ns)
if len(new_ns) == len(current_ns):
return current_ns
if len(current_ns) >= 3:
# Hard restriction on only 3 name servers
raise ValueError(("Adding %r would go beyond the "
"'3' maximum name servers") % (ns))
self._remove_option('nameserver')
for n in new_ns:
self._contents.append(('option', ['nameserver', n, '']))
return new_ns
def _remove_option(self, opt_name):
def remove_opt(item):
line_type, components = item
if line_type != 'option':
return False
(cfg_opt, _cfg_value, _comment_tail) = components
if cfg_opt != opt_name:
return False
return True
new_contents = []
for c in self._contents:
if not remove_opt(c):
new_contents.append(c)
self._contents = new_contents
def add_search_domain(self, search_domain):
flat_sds = self.search_domains
new_sds = list(flat_sds)
new_sds.append(str(search_domain))
new_sds = util.uniq_list(new_sds)
if len(flat_sds) == len(new_sds):
return new_sds
if len(flat_sds) >= 6:
# Hard restriction on only 6 search domains
raise ValueError(("Adding %r would go beyond the "
"'6' maximum search domains") % (search_domain))
s_list = " ".join(new_sds)
if len(s_list) > 256:
# Some hard limit on 256 chars total
raise ValueError(("Adding %r would go beyond the "
"256 maximum search list character limit")
% (search_domain))
self._remove_option('search')
self._contents.append(('option', ['search', s_list, '']))
return flat_sds
@local_domain.setter
def local_domain(self, domain):
self.parse()
self._remove_option('domain')
self._contents.append(('option', ['domain', str(domain), '']))
return domain
def _parse(self, contents):
entries = []
for (i, line) in enumerate(contents.splitlines()):
sline = line.strip()
if not sline:
entries.append(('blank', [line]))
continue
(head, tail) = _chop_comment(line, ';#')
if not len(head.strip()):
entries.append(('all_comment', [line]))
continue
if not tail:
tail = ''
try:
(cfg_opt, cfg_values) = head.split(None, 1)
except (IndexError, ValueError):
raise IOError("Incorrectly formatted resolv.conf line %s"
% (i + 1))
if cfg_opt not in ['nameserver', 'domain',
'search', 'sortlist', 'options']:
raise IOError("Unexpected resolv.conf option %s" % (cfg_opt))
entries.append(("option", [cfg_opt, cfg_values, tail]))
return entries

View File

@ -23,6 +23,8 @@
import os
from cloudinit import distros
from cloudinit.distros import helpers as d_helpers
from cloudinit import helpers
from cloudinit import log as logging
from cloudinit import util
@ -81,16 +83,29 @@ class Distro(distros.Distro):
def install_packages(self, pkglist):
self.package_command('install', pkglist)
def _write_resolve(self, dns_servers, search_servers):
contents = []
def _adjust_resolve(self, dns_servers, search_servers):
r_conf = d_helpers.ResolvConf(util.load_file("/etc/resolv.conf"))
try:
r_conf.parse()
except IOError:
util.logexc(LOG,
"Failed at parsing %s reverting to an empty instance",
"/etc/resolv.conf")
r_conf = d_helpers.ResolvConf('')
r_conf.parse()
if dns_servers:
for s in dns_servers:
contents.append("nameserver %s" % (s))
try:
r_conf.add_nameserver(s)
except ValueError:
util.logexc(LOG, "Failed at adding nameserver %s", s)
if search_servers:
contents.append("search %s" % (" ".join(search_servers)))
if contents:
contents.insert(0, _make_header())
util.write_file("/etc/resolv.conf", "\n".join(contents), 0644)
for s in search_servers:
try:
r_conf.add_search_domain(s)
except ValueError:
util.logexc(LOG, "Failed at adding search domain %s", s)
util.write_file("/etc/resolv.conf", str(r_conf), 0644)
def _write_network(self, settings):
# TODO(harlowja) fix this... since this is the ubuntu format

View File

@ -983,6 +983,16 @@ def find_devs_with(criteria=None, oformat='device',
return entries
def uniq_list(in_list):
out_list = []
for i in in_list:
if i in out_list:
continue
else:
out_list.append(i)
return out_list
def load_file(fname, read_cb=None, quiet=False):
LOG.debug("Reading from %s (quiet=%s)", fname, quiet)
ofh = StringIO()

View File

@ -0,0 +1,41 @@
from mocker import MockerTestCase
from cloudinit.distros import helpers
BASE_ETC = '''
# Example
127.0.0.1 localhost
192.168.1.10 foo.mydomain.org foo
192.168.1.10 bar.mydomain.org bar
146.82.138.7 master.debian.org master
209.237.226.90 www.opensource.org
'''
BASE_ETC = BASE_ETC.strip()
class TestHostsHelper(MockerTestCase):
def test_parse(self):
eh = helpers.HostsConf(BASE_ETC)
self.assertEquals(eh.get_entry('127.0.0.1'), [['localhost']])
self.assertEquals(eh.get_entry('192.168.1.10'),
[['foo.mydomain.org', 'foo'],
['bar.mydomain.org', 'bar']])
eh = str(eh)
self.assertTrue(eh.startswith('# Example'))
def test_add(self):
eh = helpers.HostsConf(BASE_ETC)
eh.add_entry('127.0.0.0', 'blah')
self.assertEquals(eh.get_entry('127.0.0.0'), [['blah']])
eh.add_entry('127.0.0.3', 'blah', 'blah2', 'blah3')
self.assertEquals(eh.get_entry('127.0.0.3'),
[['blah', 'blah2', 'blah3']])
def test_del(self):
eh = helpers.HostsConf(BASE_ETC)
eh.add_entry('127.0.0.0', 'blah')
self.assertEquals(eh.get_entry('127.0.0.0'), [['blah']])
eh.del_entries('127.0.0.0')
self.assertEquals(eh.get_entry('127.0.0.0'), [])

View File

@ -0,0 +1,61 @@
from mocker import MockerTestCase
from cloudinit.distros import helpers
BASE_RESOLVE = '''
; generated by /sbin/dhclient-script
search blah.yahoo.com yahoo.com
nameserver 10.15.44.14
nameserver 10.15.30.92
'''
BASE_RESOLVE = BASE_RESOLVE.strip()
class TestResolvHelper(MockerTestCase):
def test_parse_same(self):
rp = helpers.ResolvConf(BASE_RESOLVE)
rp_r = str(rp).strip()
self.assertEquals(BASE_RESOLVE, rp_r)
def test_local_domain(self):
rp = helpers.ResolvConf(BASE_RESOLVE)
self.assertEquals(None, rp.local_domain)
rp.local_domain = "bob"
self.assertEquals('bob', rp.local_domain)
self.assertIn('domain bob', str(rp))
def test_nameservers(self):
rp = helpers.ResolvConf(BASE_RESOLVE)
self.assertIn('10.15.44.14', rp.nameservers)
self.assertIn('10.15.30.92', rp.nameservers)
rp.add_nameserver('10.2')
self.assertIn('10.2', rp.nameservers)
self.assertIn('nameserver 10.2', str(rp))
self.assertNotIn('10.3', rp.nameservers)
self.assertEquals(len(rp.nameservers), 3)
rp.add_nameserver('10.2')
with self.assertRaises(ValueError):
rp.add_nameserver('10.3')
self.assertNotIn('10.3', rp.nameservers)
def test_search_domains(self):
rp = helpers.ResolvConf(BASE_RESOLVE)
self.assertIn('yahoo.com', rp.search_domains)
self.assertIn('blah.yahoo.com', rp.search_domains)
rp.add_search_domain('bbb.y.com')
self.assertIn('bbb.y.com', rp.search_domains)
self.assertRegexpMatches(str(rp), r'search(.*)bbb.y.com(.*)')
self.assertIn('bbb.y.com', rp.search_domains)
rp.add_search_domain('bbb.y.com')
self.assertEquals(len(rp.search_domains), 3)
rp.add_search_domain('bbb2.y.com')
self.assertEquals(len(rp.search_domains), 4)
rp.add_search_domain('bbb3.y.com')
self.assertEquals(len(rp.search_domains), 5)
rp.add_search_domain('bbb4.y.com')
self.assertEquals(len(rp.search_domains), 6)
with self.assertRaises(ValueError):
rp.add_search_domain('bbb5.y.com')
self.assertEquals(len(rp.search_domains), 6)