From 6b1a94b09235691a27c5046efb0dec63decfc640 Mon Sep 17 00:00:00 2001
From: Christophe de Vienne <cdevienne@gmail.com>
Date: Wed, 17 Apr 2013 17:51:44 +0200
Subject: [PATCH] wsmeext.cornice now handle errors properly

---
 tests/test_cornice.py | 23 +++++++++++++++++++++
 wsmeext/cornice.py    | 47 +++++++++++++++++++++++++++++--------------
 2 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/tests/test_cornice.py b/tests/test_cornice.py
index 797e839..7cc2b85 100644
--- a/tests/test_cornice.py
+++ b/tests/test_cornice.py
@@ -31,6 +31,14 @@ def users_create(data):
     return data
 
 
+divide = Service(name='divide', path='/divide')
+
+
+@divide.get()
+@signature(int, int, int)
+def do_divide(a, b):
+    return a / b
+
 needrequest = Service(name='needrequest', path='/needrequest')
 
 
@@ -140,3 +148,18 @@ class WSMECorniceTestCase(unittest.TestCase):
         )
         assert resp.json['authorId'] == 5
         assert resp.json['name'] == 'Author 5'
+
+    def test_server_error(self):
+        resp = self.app.get('/divide?a=1&b=0', expect_errors=True)
+        self.assertEquals(resp.json['faultcode'], 'Server')
+        self.assertEquals(resp.status_code, 500)
+
+    def test_client_error(self):
+        resp = self.app.get(
+            '/divide?a=1&c=0',
+            headers={'Accept': 'application/json'},
+            expect_errors=True
+        )
+        print resp.body
+        self.assertEquals(resp.json['faultcode'], 'Client')
+        self.assertEquals(resp.status_code, 400)
diff --git a/wsmeext/cornice.py b/wsmeext/cornice.py
index 0e8a72f..0286374 100644
--- a/wsmeext/cornice.py
+++ b/wsmeext/cornice.py
@@ -17,6 +17,7 @@ And use it::
 from __future__ import absolute_import
 
 import inspect
+import sys
 
 import wsme
 from wsme.rest import json as restjson
@@ -36,6 +37,12 @@ class WSMEJsonRenderer(object):
     def __call__(self, data, context):
         response = context['request'].response
         response.content_type = 'application/json'
+        if 'faultcode' in data:
+            if data['faultcode'] == 'Client':
+                response.status_code = 400
+            else:
+                response.status_code = 500
+            return restjson.encode_error(None, data)
         return restjson.encode_result(data['result'], data['datatype'])
 
 
@@ -45,6 +52,12 @@ class WSMEXmlRenderer(object):
 
     def __call__(self, data, context):
         response = context['request'].response
+        if 'faultcode' in data:
+            if data['faultcode'] == 'Client':
+                response.status_code = 400
+            else:
+                response.status_code = 500
+            return restxml.encode_error(None, data)
         response.content_type = 'text/xml'
         return restxml.encode_result(data['result'], data['datatype'])
 
@@ -86,22 +99,26 @@ def signature(*args, **kwargs):
                     raise ValueError("Cannot do anything with these arguments")
             else:
                 request = args[0]
-            args, kwargs = combine_args(
-                funcdef,
-                (args_from_args(funcdef, (), request.matchdict),
-                 args_from_params(funcdef, request.params),
-                 args_from_body(funcdef, request.body, request.content_type))
-            )
-            wsme.runtime.check_arguments(funcdef, args, kwargs)
             request.override_renderer = 'wsme' + get_outputformat(request)
-            if funcdef.pass_request:
-                kwargs[funcdef.pass_request] = request
-            if with_self:
-                args.insert(0, self)
-            return {
-                'datatype': funcdef.return_type,
-                'result': f(*args, **kwargs)
-            }
+            try:
+                args, kwargs = combine_args(funcdef, (
+                    args_from_args(funcdef, (), request.matchdict),
+                    args_from_params(funcdef, request.params),
+                    args_from_body(funcdef, request.body, request.content_type)
+                ))
+                wsme.runtime.check_arguments(funcdef, args, kwargs)
+                if funcdef.pass_request:
+                    kwargs[funcdef.pass_request] = request
+                if with_self:
+                    args.insert(0, self)
+
+                result = f(*args, **kwargs)
+                return {
+                    'datatype': funcdef.return_type,
+                    'result': result
+                }
+            except:
+                return wsme.api.format_exception(sys.exc_info())
 
         callfunction.wsme_func = f
         return callfunction