149 lines
4.6 KiB
Python
149 lines
4.6 KiB
Python
# Copyright 2014 Huawei Technologies Co. Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Cluster database operations."""
|
|
|
|
import simplejson as json
|
|
|
|
from compass.db import api
|
|
from compass.db.api import database
|
|
from compass.db.api.utils import merge_dict
|
|
from compass.db.api.utils import wrap_to_dict
|
|
from compass.db.exception import InvalidParameter
|
|
from compass.db.exception import RecordNotExists
|
|
|
|
from compass.db.config_validation import default_validator
|
|
# from compass.db.config_validation import extension
|
|
|
|
from compass.db.models import Cluster
|
|
|
|
|
|
# Field names that _list_clusters passes to api.model_filter as the
# permitted filter keys.
SUPPORTED_FILTERS = ['name', 'adapter', 'owner']

# Error message templates, keyed by error name.
ERROR_MSG = {
    # %s rather than %d: the ID is formatted as-is, so an ID that arrives
    # as a string (e.g. '5') no longer raises TypeError while the
    # "not found" message itself is being built.
    'findNoCluster': 'Cannot find the Cluster, ID is %s',
}
|
|
|
|
|
|
@wrap_to_dict()
def get_cluster(cluster_id):
    """Return the cluster with the given ID converted via ``to_dict()``.

    The lookup and the ``to_dict()`` call happen inside a single database
    session; the returned value is then passed through the
    ``wrap_to_dict()`` decorator.

    :param cluster_id: ID of the cluster to fetch.
    :raises RecordNotExists: propagated from ``_get_cluster`` when no
        cluster has that ID.
    """
    with database.session() as db_session:
        found = _get_cluster(db_session, cluster_id)
        cluster_dict = found.to_dict()
    return cluster_dict
|
|
|
|
|
|
@wrap_to_dict()
def list_clusters(filters=None):
    """List all clusters, optionally filtered by some fields.

    :param filters: optional mapping of field name to value; it is handed
        to ``_list_clusters``, which passes SUPPORTED_FILTERS along to
        ``api.model_filter``. ``None`` (or any empty value) means no
        filtering.
    :returns: list with one ``Cluster.to_dict()`` result per matching
        cluster, passed through the ``wrap_to_dict()`` decorator.
    """
    filters = filters or {}
    with database.session() as session:
        clusters = _list_clusters(session, filters)
        # Convert while the session is still open so model attributes
        # can be read.
        clusters_info = [cluster.to_dict() for cluster in clusters]
    return clusters_info
|
|
|
|
|
|
@wrap_to_dict()
def get_cluster_config(cluster_id):
    """Return the configuration of the cluster with the given ID.

    Opens a database session and delegates the lookup to
    ``_get_cluster_config``; the result is passed through the
    ``wrap_to_dict()`` decorator.

    :param cluster_id: ID of the cluster whose configuration is wanted.
    """
    with database.session() as db_session:
        cluster_config = _get_cluster_config(db_session, cluster_id)
    return cluster_config
|
|
|
|
|
|
def _get_cluster_config(session, cluster_id):
    """Return the ``config`` attribute of a cluster, looked up in ``session``.

    :param session: database session to run the lookup in.
    :param cluster_id: ID of the cluster.
    :raises RecordNotExists: if no cluster has that ID.
    """
    with session.begin(subtransactions=True):
        # _get_cluster takes (session, cluster_id); the session must be
        # passed explicitly or the call raises TypeError.
        cluster = _get_cluster(session, cluster_id)
        # NOTE(review): update_cluster_config reads/writes
        # os_global_config / package_global_config; confirm that the
        # Cluster model actually exposes a ``config`` attribute.
        config = cluster.config
    return config
|
|
|
|
|
|
def _get_cluster(session, cluster_id):
    """Get the cluster by ID.

    :param session: database session to query with.
    :param cluster_id: value matched against ``Cluster.id``.
    :returns: the matching ``Cluster`` model instance.
    :raises RecordNotExists: if no cluster has that ID.
    """
    with session.begin(subtransactions=True):
        cluster = session.query(Cluster).filter_by(id=cluster_id).first()
        # first() returns None when nothing matches; test identity rather
        # than truthiness so a model defining __bool__/__len__ cannot be
        # mistaken for "not found".
        if cluster is None:
            err_msg = ERROR_MSG['findNoCluster'] % cluster_id
            raise RecordNotExists(err_msg)
    return cluster
|
|
|
|
|
|
def _list_clusters(session, filters=None):
    """Return every Cluster row in ``session`` that matches ``filters``.

    :param session: database session to query with.
    :param filters: optional mapping of field name to value; any empty
        value (including ``None``) is replaced by ``{}``. It is handed to
        ``api.model_filter`` along with SUPPORTED_FILTERS.
    :returns: list of ``Cluster`` model instances.
    """
    if not filters:
        filters = {}

    with session.begin(subtransactions=True):
        base_query = api.model_query(session, Cluster)
        filtered_query = api.model_filter(base_query, Cluster, filters,
                                          SUPPORTED_FILTERS)
        matches = filtered_query.all()
    return matches
|
|
|
|
|
|
def update_cluster_config(cluster_id, root_elem, config, patch=True):
    """Validate ``config`` and merge it into one of a cluster's configs.

    :param cluster_id: ID of the cluster to update.
    :param root_elem: which config to update. ``"os_config"`` is validated
        against the cluster's ``os_id`` and stored in
        ``Cluster.os_global_config``; ``"package_config"`` is validated
        against the cluster's ``adapter_id`` and stored in
        ``Cluster.package_global_config``.
    :param config: config values to merge into the stored config.
    :param patch: forwarded to ``default_validator.validate_config``.
    :returns: the stored config after the merge.
    :raises InvalidParameter: if ``root_elem`` is not one of the two values
        above, or the validator rejects ``config``.
    :raises RecordNotExists: if no cluster has ID ``cluster_id``.
    """
    # root_elem -> (cluster attribute holding the ID used for validation,
    #               cluster attribute holding the stored config)
    targets = {
        'os_config': ('os_id', 'os_global_config'),
        'package_config': ('adapter_id', 'package_global_config'),
    }
    if root_elem not in targets:
        raise InvalidParameter("Invalid parameter %s" % root_elem)
    id_name, config_attr = targets[root_elem]

    with database.session() as session:
        cluster = _get_cluster(session, cluster_id)
        id_value = getattr(cluster, id_name)

        # Validate config format and values.
        is_valid, message = default_validator.validate_config(
            session, config, id_name, id_value, patch)
        if not is_valid:
            raise InvalidParameter(message)

        # For additional validation, you can define functions in the
        # config_validation ``extension`` module, for example:
        #     os_name = get_os(cluster.os_id)['name']
        #     if getattr(extension, os_name):
        #         func = getattr(getattr(extension, os_name),
        #                        'validate_config')
        #         if not func(session, os_id, config, patch):
        #             return False

        # NOTE(review): ``patch`` only influences validation; the stored
        # config is always merged into, never replaced. Confirm that is
        # intended when patch=False.

        # Merge into a JSON round-trip copy of the stored config (starting
        # from {} when nothing is stored yet) and assign the copy back,
        # rather than mutating the loaded object in place. Using the
        # matching attribute for both branches also fixes the
        # package_config path, which previously copied the unassigned
        # local ``os_config``.
        merged = json.loads(json.dumps(getattr(cluster, config_attr) or {}))
        merge_dict(merged, config)
        setattr(cluster, config_attr, merged)
        result = getattr(cluster, config_attr)

    return result
|