From e84325bd1da3d0f96b48c63b3eec789699dbeea7 Mon Sep 17 00:00:00 2001
From: Damien George <damien.p.george@gmail.com>
Date: Tue, 8 Dec 2015 23:34:59 +0000
Subject: [PATCH] py: Add mp_get_stream_raise to factor out check for stream
 methods.

---
 py/stream.c | 69 ++++++++++++++++++++++-------------------------------
 py/stream.h |  7 ++++++
 2 files changed, 36 insertions(+), 40 deletions(-)

diff --git a/py/stream.c b/py/stream.c
index 025d6aee5..d29db621b 100644
--- a/py/stream.c
+++ b/py/stream.c
@@ -49,12 +49,21 @@ STATIC mp_obj_t stream_readall(mp_obj_t self_in);
 
 #define STREAM_CONTENT_TYPE(stream) (((stream)->is_text) ? &mp_type_str : &mp_type_bytes)
 
-STATIC mp_obj_t stream_read(mp_uint_t n_args, const mp_obj_t *args) {
-    struct _mp_obj_base_t *o = (struct _mp_obj_base_t *)MP_OBJ_TO_PTR(args[0]);
-    if (o->type->stream_p == NULL || o->type->stream_p->read == NULL) {
+const mp_stream_p_t *mp_get_stream_raise(mp_obj_t self_in, int flags) {
+    mp_obj_base_t *o = (mp_obj_base_t*)MP_OBJ_TO_PTR(self_in);
+    const mp_stream_p_t *stream_p = o->type->stream_p;
+    if (stream_p == NULL
+        || ((flags & MP_STREAM_OP_READ) && stream_p->read == NULL)
+        || ((flags & MP_STREAM_OP_WRITE) && stream_p->write == NULL)
+        || ((flags & MP_STREAM_OP_IOCTL) && stream_p->ioctl == NULL)) {
         // CPython: io.UnsupportedOperation, OSError subclass
-        nlr_raise(mp_obj_new_exception_msg(&mp_type_OSError, "Operation not supported"));
+        nlr_raise(mp_obj_new_exception_msg(&mp_type_OSError, "stream operation not supported"));
     }
+    return stream_p;
+}
+
+STATIC mp_obj_t stream_read(mp_uint_t n_args, const mp_obj_t *args) {
+    const mp_stream_p_t *stream_p = mp_get_stream_raise(args[0], MP_STREAM_OP_READ);
 
     // What to do if sz < -1?  Python docs don't specify this case.
     // CPython does a readall, but here we silently let negatives through,
@@ -65,7 +74,7 @@ STATIC mp_obj_t stream_read(mp_uint_t n_args, const mp_obj_t *args) {
     }
 
     #if MICROPY_PY_BUILTINS_STR_UNICODE
-    if (o->type->stream_p->is_text) {
+    if (stream_p->is_text) {
         // We need to read sz number of unicode characters.  Because we don't have any
         // buffering, and because the stream API can only read bytes, we must read here
         // in units of bytes and must never over read.  If we want sz chars, then reading
@@ -85,7 +94,7 @@ STATIC mp_obj_t stream_read(mp_uint_t n_args, const mp_obj_t *args) {
                 nlr_raise(mp_obj_new_exception_msg_varg(&mp_type_MemoryError, "out of memory"));
             }
             int error;
-            mp_uint_t out_sz = o->type->stream_p->read(MP_OBJ_FROM_PTR(o), p, more_bytes, &error);
+            mp_uint_t out_sz = stream_p->read(args[0], p, more_bytes, &error);
             if (out_sz == MP_STREAM_ERROR) {
                 vstr_cut_tail_bytes(&vstr, more_bytes);
                 if (mp_is_nonblocking_error(error)) {
@@ -156,7 +165,7 @@ STATIC mp_obj_t stream_read(mp_uint_t n_args, const mp_obj_t *args) {
     vstr_t vstr;
     vstr_init_len(&vstr, sz);
     int error;
-    mp_uint_t out_sz = o->type->stream_p->read(MP_OBJ_FROM_PTR(o), vstr.buf, sz, &error);
+    mp_uint_t out_sz = stream_p->read(args[0], vstr.buf, sz, &error);
     if (out_sz == MP_STREAM_ERROR) {
         vstr_clear(&vstr);
         if (mp_is_nonblocking_error(error)) {
@@ -170,19 +179,15 @@ STATIC mp_obj_t stream_read(mp_uint_t n_args, const mp_obj_t *args) {
         nlr_raise(mp_obj_new_exception_arg1(&mp_type_OSError, MP_OBJ_NEW_SMALL_INT(error)));
     } else {
         vstr.len = out_sz;
-        return mp_obj_new_str_from_vstr(STREAM_CONTENT_TYPE(o->type->stream_p), &vstr);
+        return mp_obj_new_str_from_vstr(STREAM_CONTENT_TYPE(stream_p), &vstr);
     }
 }
 
 mp_obj_t mp_stream_write(mp_obj_t self_in, const void *buf, size_t len) {
-    struct _mp_obj_base_t *o = (struct _mp_obj_base_t *)MP_OBJ_TO_PTR(self_in);
-    if (o->type->stream_p == NULL || o->type->stream_p->write == NULL) {
-        // CPython: io.UnsupportedOperation, OSError subclass
-        nlr_raise(mp_obj_new_exception_msg(&mp_type_OSError, "Operation not supported"));
-    }
+    const mp_stream_p_t *stream_p = mp_get_stream_raise(self_in, MP_STREAM_OP_WRITE);
 
     int error;
-    mp_uint_t out_sz = o->type->stream_p->write(self_in, buf, len, &error);
+    mp_uint_t out_sz = stream_p->write(self_in, buf, len, &error);
     if (out_sz == MP_STREAM_ERROR) {
         if (mp_is_nonblocking_error(error)) {
             // http://docs.python.org/3/library/io.html#io.RawIOBase.write
@@ -210,11 +215,7 @@ STATIC mp_obj_t stream_write_method(mp_obj_t self_in, mp_obj_t arg) {
 }
 
 STATIC mp_obj_t stream_readinto(mp_uint_t n_args, const mp_obj_t *args) {
-    struct _mp_obj_base_t *o = (struct _mp_obj_base_t *)MP_OBJ_TO_PTR(args[0]);
-    if (o->type->stream_p == NULL || o->type->stream_p->read == NULL) {
-        // CPython: io.UnsupportedOperation, OSError subclass
-        nlr_raise(mp_obj_new_exception_msg(&mp_type_OSError, "Operation not supported"));
-    }
+    const mp_stream_p_t *stream_p = mp_get_stream_raise(args[0], MP_STREAM_OP_READ);
     mp_buffer_info_t bufinfo;
     mp_get_buffer_raise(args[1], &bufinfo, MP_BUFFER_WRITE);
 
@@ -230,7 +231,7 @@ STATIC mp_obj_t stream_readinto(mp_uint_t n_args, const mp_obj_t *args) {
     }
 
     int error;
-    mp_uint_t out_sz = o->type->stream_p->read(MP_OBJ_FROM_PTR(o), bufinfo.buf, len, &error);
+    mp_uint_t out_sz = stream_p->read(args[0], bufinfo.buf, len, &error);
     if (out_sz == MP_STREAM_ERROR) {
         if (mp_is_nonblocking_error(error)) {
             return mp_const_none;
@@ -242,11 +243,7 @@ STATIC mp_obj_t stream_readinto(mp_uint_t n_args, const mp_obj_t *args) {
 }
 
 STATIC mp_obj_t stream_readall(mp_obj_t self_in) {
-    struct _mp_obj_base_t *o = (struct _mp_obj_base_t *)MP_OBJ_TO_PTR(self_in);
-    if (o->type->stream_p == NULL || o->type->stream_p->read == NULL) {
-        // CPython: io.UnsupportedOperation, OSError subclass
-        nlr_raise(mp_obj_new_exception_msg(&mp_type_OSError, "Operation not supported"));
-    }
+    const mp_stream_p_t *stream_p = mp_get_stream_raise(self_in, MP_STREAM_OP_READ);
 
     mp_uint_t total_size = 0;
     vstr_t vstr;
@@ -255,7 +252,7 @@ STATIC mp_obj_t stream_readall(mp_obj_t self_in) {
     mp_uint_t current_read = DEFAULT_BUFFER_SIZE;
     while (true) {
         int error;
-        mp_uint_t out_sz = o->type->stream_p->read(self_in, p, current_read, &error);
+        mp_uint_t out_sz = stream_p->read(self_in, p, current_read, &error);
         if (out_sz == MP_STREAM_ERROR) {
             if (mp_is_nonblocking_error(error)) {
                 // With non-blocking streams, we read as much as we can.
@@ -286,16 +283,12 @@ STATIC mp_obj_t stream_readall(mp_obj_t self_in) {
     }
 
     vstr.len = total_size;
-    return mp_obj_new_str_from_vstr(STREAM_CONTENT_TYPE(o->type->stream_p), &vstr);
+    return mp_obj_new_str_from_vstr(STREAM_CONTENT_TYPE(stream_p), &vstr);
 }
 
 // Unbuffered, inefficient implementation of readline() for raw I/O files.
 STATIC mp_obj_t stream_unbuffered_readline(mp_uint_t n_args, const mp_obj_t *args) {
-    struct _mp_obj_base_t *o = (struct _mp_obj_base_t *)MP_OBJ_TO_PTR(args[0]);
-    if (o->type->stream_p == NULL || o->type->stream_p->read == NULL) {
-        // CPython: io.UnsupportedOperation, OSError subclass
-        nlr_raise(mp_obj_new_exception_msg(&mp_type_OSError, "Operation not supported"));
-    }
+    const mp_stream_p_t *stream_p = mp_get_stream_raise(args[0], MP_STREAM_OP_READ);
 
     mp_int_t max_size = -1;
     if (n_args > 1) {
@@ -316,7 +309,7 @@ STATIC mp_obj_t stream_unbuffered_readline(mp_uint_t n_args, const mp_obj_t *arg
         }
 
         int error;
-        mp_uint_t out_sz = o->type->stream_p->read(MP_OBJ_FROM_PTR(o), p, 1, &error);
+        mp_uint_t out_sz = stream_p->read(args[0], p, 1, &error);
         if (out_sz == MP_STREAM_ERROR) {
             if (mp_is_nonblocking_error(error)) {
                 if (vstr.len == 1) {
@@ -347,7 +340,7 @@ done:
         }
     }
 
-    return mp_obj_new_str_from_vstr(STREAM_CONTENT_TYPE(o->type->stream_p), &vstr);
+    return mp_obj_new_str_from_vstr(STREAM_CONTENT_TYPE(stream_p), &vstr);
 }
 
 // TODO take an optional extra argument (what does it do exactly?)
@@ -373,11 +366,7 @@ mp_obj_t mp_stream_unbuffered_iter(mp_obj_t self) {
 }
 
 STATIC mp_obj_t stream_seek(mp_uint_t n_args, const mp_obj_t *args) {
-    struct _mp_obj_base_t *o = (struct _mp_obj_base_t *)MP_OBJ_TO_PTR(args[0]);
-    if (o->type->stream_p == NULL || o->type->stream_p->ioctl == NULL) {
-        // CPython: io.UnsupportedOperation, OSError subclass
-        nlr_raise(mp_obj_new_exception_msg(&mp_type_OSError, "Operation not supported"));
-    }
+    const mp_stream_p_t *stream_p = mp_get_stream_raise(args[0], MP_STREAM_OP_IOCTL);
 
     struct mp_stream_seek_t seek_s;
     // TODO: Could be uint64
@@ -388,7 +377,7 @@ STATIC mp_obj_t stream_seek(mp_uint_t n_args, const mp_obj_t *args) {
     }
 
     int error;
-    mp_uint_t res = o->type->stream_p->ioctl(MP_OBJ_FROM_PTR(o), MP_STREAM_SEEK, (mp_uint_t)(uintptr_t)&seek_s, &error);
+    mp_uint_t res = stream_p->ioctl(args[0], MP_STREAM_SEEK, (mp_uint_t)(uintptr_t)&seek_s, &error);
     if (res == MP_STREAM_ERROR) {
         nlr_raise(mp_obj_new_exception_arg1(&mp_type_OSError, MP_OBJ_NEW_SMALL_INT(error)));
     }
diff --git a/py/stream.h b/py/stream.h
index b3958d301..0471777c7 100644
--- a/py/stream.h
+++ b/py/stream.h
@@ -37,6 +37,13 @@ MP_DECLARE_CONST_FUN_OBJ(mp_stream_write_obj);
 MP_DECLARE_CONST_FUN_OBJ(mp_stream_seek_obj);
 MP_DECLARE_CONST_FUN_OBJ(mp_stream_tell_obj);
 
+// these are for mp_get_stream_raise and can be or'd together
+#define MP_STREAM_OP_READ (1)
+#define MP_STREAM_OP_WRITE (2)
+#define MP_STREAM_OP_IOCTL (4)
+
+const mp_stream_p_t *mp_get_stream_raise(mp_obj_t self_in, int flags);
+
 // Iterator which uses mp_stream_unbuffered_readline_obj
 mp_obj_t mp_stream_unbuffered_iter(mp_obj_t self);
 
-- 
GitLab