Source code for cervmongo.aio

#  aio.py
#
#  Copyright 2020 Anthony "antcer1213" Cervantes <anthony.cervantes@cerver.info>
#
#  MIT License
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
#
#
__all__ = ["SUPPORT_ASYNCIO_CLIENT", "SUPPORT_ASYNCIO_BUCKET", "get_async_client", "get_async_doc", "AsyncIOClient", "AsyncIODoc"]

from os import path as os_path
from pymongo import WriteConcern
from dateutil.parser import parse as dateparse
import types
from jsonschema import validate
from functools import partial
import typing
import copy
import logging

from .models import (
                GenericResponse,
                StandardResponse,
                MongoListResponse,
                MongoDictResponse,
                )
from .vars import (
                StringEnum,
                IntEnum,
                PAGINATION_SORT_FIELDS,
                ENUM,
                DOC_ID,
                )
from .utils import (
                get_file_meta_information,
                parse_string_header,
                format_string_for_id,
                silent_drop_kwarg,
                current_datetime,
                file_and_fileobj,
                detect_mimetype,
                dict_to_query,
                clean_kwargs,
                current_date,
                json_load,
                json_dump,
                logger,
                )
from .config import Config

# Optional dependency: the asyncio client is only usable when `motor` is
# installed. When it is missing, placeholder classes are defined so that the
# class statements further down this module can still be evaluated.
try:
    from motor.motor_asyncio import AsyncIOMotorClient as MongoClient
    from motor.motor_asyncio import AsyncIOMotorGridFSBucket as GridFSBucket
except ImportError:
    # Only a missing package is tolerated here; any other failure raised while
    # importing motor is a real error and is allowed to propagate.
    logger.warning("motor is not installed. needed if using asyncio")

    class MongoClient:
        pass

    class GridFSBucket:  # NOTE: in case of references
        pass

    SUPPORT_ASYNCIO_CLIENT = False #: True if motor package is installed else False
    SUPPORT_ASYNCIO_BUCKET = False #: True if motor package is installed else False
else:
    SUPPORT_ASYNCIO_CLIENT = True #: True if motor package is installed else False
    SUPPORT_ASYNCIO_BUCKET = True #: True if motor package is installed else False


class AsyncIOClient(MongoClient):
    """
    High-level AsyncIOMotorClient subclass with additional methods added for
    ease-of-use, having some automated conveniences and defaults.

    The upper-case methods (GET, POST, PUT, PATCH, DELETE, ...) all accept a
    collection name as first argument and fall back to the default collection
    of the client when that argument is falsy.
    """
    # Class-level fallback: resolved lazily so Config can be set after import.
    _MONGO_URI = lambda _: getattr(Config, "MONGO_URI", None)
    _DEFAULT_COLLECTION = None
    _KWARGS = None
    _LOGGING_COND_GET = None
    _LOGGING_COND_POST = None
    _LOGGING_COND_PUT = None
    _LOGGING_COND_PATCH = None
    _LOGGING_COND_DELETE = None

    # Keyword arguments consumed by this class and NOT forwarded to motor.
    _LOGGING_KWARGS = (
        'logging_cond_get',
        'logging_cond_post',
        'logging_cond_put',
        'logging_cond_patch',
        'logging_cond_delete',
    )

    def __init__(self, mongo_uri=None, default_collection=None, **kwargs):
        """
        Connect to MongoDB and, when the URI names a database, open a GridFS
        bucket under ``self.FILES``.

        :param mongo_uri: connection string; defaults to ``Config.MONGO_URI``
        :param default_collection: collection used when a method receives a
            falsy collection argument
        :param kwargs: ``logging_cond_*`` entries are stored on the instance
            (as ``_LOGGING_COND_*``); everything else is forwarded to the
            motor client constructor
        """
        self._MONGO_URI = mongo_uri or self._MONGO_URI
        if callable(self._MONGO_URI):
            self._MONGO_URI = self._MONGO_URI()
        self._DEFAULT_COLLECTION = default_collection or self._DEFAULT_COLLECTION
        if kwargs:
            # Kept untouched so set_database can re-run __init__ identically.
            self._KWARGS = kwargs.copy()
        # Iterate over a snapshot: popping while iterating the dict itself
        # raises "dictionary changed size during iteration".
        for kwarg in list(kwargs):
            if kwarg.lower() in self._LOGGING_KWARGS:
                setattr(self, "_{}".format(kwarg.upper()), kwargs.pop(kwarg))

        MongoClient.__init__(self, self._MONGO_URI, **kwargs)

        db = self.get_default_database()
        logger.info("db detected '{}' of type '{}'".format(db.name, type(db.name)))
        if not getattr(db, "name", None) or db.name == "None":
            logger.warning("database not provided in MONGO_URI, assign with method set_database")
            logger.warning("gridfsbucket not instantiated due to missing database")
            # Always define FILES so the GridFS helpers fail on their assert
            # instead of on a missing attribute.
            self.FILES = None
        elif SUPPORT_ASYNCIO_BUCKET:
            logger.debug("gridfsbucket instantiated under self.FILES")
            self.FILES = GridFSBucket(db)
        else:
            logger.warning("gridfsbucket not instantiated due to missing 'motor' package")
            self.FILES = None

    def __repr__(self):
        db = self.get_default_database()
        if not getattr(db, "name", None) or db.name == "None":
            return "<cervmongo.AsyncIOClient>"
        return f"<cervmongo.AsyncIOClient.{db.name}>"

    def _process_record_id_type(self, record):
        """
        Normalize an id-or-query argument.

        Returns a tuple ``(value, one)`` where ``one`` is True when ``record``
        was recognized as a single document id. String ids are turned into an
        ``{"$in": [...]}`` match that accepts both the converted id and the
        raw string.
        """
        one = False
        if isinstance(record, str):
            one = True
            if "$oid" in record:
                record = {"$in": [json_load(record), record]}
            else:
                try:
                    record = {"$in": [DOC_ID.__supertype__(record), record]}
                except Exception:
                    # Not convertible to the id type: keep the raw string.
                    pass
        elif isinstance(record, DOC_ID.__supertype__):
            one = True
        elif isinstance(record, dict):
            if "$oid" in record or "$regex" in record:
                # Round-trip through the JSON helpers to decode extended JSON.
                record = json_load(json_dump(record))
                one = True
        return (record, one)

    def _concrete_id(self, raw_id):
        """Return a single storable _id value for a raw id argument."""
        if isinstance(raw_id, str):
            if "$oid" in raw_id:
                return json_load(raw_id)
            try:
                return DOC_ID.__supertype__(raw_id)
            except Exception:
                return raw_id
        return raw_id

    def _resolve_collection(self, collection, message="collection must be of type str"):
        """Return ``collection`` or the default collection; assert one is set."""
        collection = collection or getattr(self, "_DEFAULT_COLLECTION", None)
        assert collection, message
        return collection

    def set_database(self, database):
        """Switch the default database and re-initialize the client."""
        Config.set_mongo_db(database)
        extra = self._KWARGS or {}
        AsyncIOClient.__init__(
            self,
            mongo_uri=Config.MONGO_URI,
            default_collection=self._DEFAULT_COLLECTION,
            **extra)

    def COLLECTION(self, collection:str):
        """
        Return a lightweight object whose methods are this client's methods
        pre-bound to ``collection``. Also makes ``collection`` the default
        collection of this client.
        """
        self._DEFAULT_COLLECTION = collection

        class CollectionClient:
            __parent__ = CLIENT = self
            # INFO: variables
            _DEFAULT_COLLECTION = collection
            _MONGO_URI = self._MONGO_URI
            # INFO: general methods
            GENERATE_ID = self.GENERATE_ID
            COLLECTION = self.COLLECTION
            # INFO: GridFS file operations
            UPLOAD = self.UPLOAD
            DOWNLOAD = self.DOWNLOAD
            ERASE = self.ERASE
            # INFO: truncated Collection methods
            INDEX = partial(self.INDEX, collection)
            ADD_FIELD = partial(self.ADD_FIELD, collection)
            REMOVE_FIELD = partial(self.REMOVE_FIELD, collection)
            DELETE = partial(self.DELETE, collection)
            GET = partial(self.GET, collection)
            POST = partial(self.POST, collection)
            PUT = partial(self.PUT, collection)
            PATCH = partial(self.PATCH, collection)
            REPLACE = partial(self.REPLACE, collection)
            SEARCH = partial(self.SEARCH, collection)
            PAGINATED_QUERY = partial(self.PAGINATED_QUERY, collection)

            def __repr__(s):
                return "<cervmongo.AsyncIOClient.CollectionClient>"

            def get_client(s):
                return s.CLIENT

        return CollectionClient()

    async def PAGINATED_QUERY(self, collection, limit:int=20,
            sort:PAGINATION_SORT_FIELDS=PAGINATION_SORT_FIELDS["_id"],
            after:str=None, before:str=None, page:int=None,
            endpoint:str="/", ordering:int=-1, query:dict=None, **kwargs):
        """
        Returns paginated results of collection w/ query.

        Available pagination methods:

        - **Cursor-based (default)**
            - after
            - before
            - limit (results per page, default 20)
        - **Time-based** (a datetime field must be selected)
            - sort (set to datetime field)
            - after (records after this time)
            - before (records before this time)
            - limit (results per page, default 20)
        - **Offset-based** (not recommended)
            - limit (results per page, default 20)
            - page

        Returns a dict with the records under ``"data"`` and the pagination
        state (cursors or page numbers, previous/next urls, totals) under
        ``"details"``.
        """
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection must be of type str"
        query = {} if query is None else query

        if isinstance(sort, ENUM.__supertype__):
            sort = sort.value

        total_docs = await self.GET(collection, query, count=True, empty=0)

        if not page:
            pagination_method = "cursor" if sort == "_id" else "time"
            cursor = await self.GET(collection, query,
                                    limit=limit, key=sort, before=before,
                                    after=after, sort=ordering, empty=[])
        else:
            assert page >= 1, "page must be equal to or greater than 1"
            pagination_method = "offset"
            cursor = await self.GET(collection, query,
                                    perpage=limit, key=sort, page=page,
                                    sort=ordering, empty=[])

        results = [record async for record in cursor]

        # INFO: determine 'cursor' template
        template = "_{_id}" if sort == "_id" else "{date}_{_id}"

        def as_cursor(record):
            # The date part is only meaningful for time-based pagination.
            try:
                date = record[sort].isoformat()
            except Exception:
                date = None
            return template.format(_id=record["_id"], date=date)

        new_after = None
        new_before = None
        if results:
            if len(results) == limit:
                new_after = as_cursor(results[-1])
            if any((after, before)):
                new_before = as_cursor(results[0])

            if pagination_method in ("cursor", "time"):
                # Look one page ahead in the travelled direction and drop the
                # cursor when nothing is left there.
                if before:
                    check_ahead = await self.GET(collection, query,
                                    limit=limit, key=sort, before=new_before,
                                    empty=0, count=True)
                    if not check_ahead:
                        new_before = None
                elif after:
                    check_ahead = await self.GET(collection, query,
                                    limit=limit, key=sort, after=new_after,
                                    empty=0, count=True)
                    if not check_ahead:
                        new_after = None

        details = {
            "pagination_method": pagination_method,
            "query": dict_to_query(query),
            "sort": sort,
            "unique_id": getattr(self, "_UNIQUE_ID", "_id"),
            "total": total_docs,
            "count": len(results),
            "limit": limit
            }
        response = {"data": results, "details": details}

        # TODO: Refactor
        if pagination_method in ("cursor", "time"):
            details["cursors"] = {
                "after": new_after,
                "before": new_before
                }
            before_url_template = "{endpoint}?sort={sort}&limit={limit}&before={before}"
            after_url_template = "{endpoint}?sort={sort}&limit={limit}&after={after}"
            details["previous"] = before_url_template.format(
                endpoint=endpoint, sort=sort, limit=limit,
                before=new_before) if new_before else None
            details["next"] = after_url_template.format(
                endpoint=endpoint, sort=sort, limit=limit,
                after=new_after) if new_after else None
        else:
            # INFO: pagination_method == "offset"
            prev_page = page - 1 if page > 1 else None
            # A next page exists only when documents remain AFTER this page.
            next_page = page + 1 if (page * limit) < total_docs else None
            details["cursors"] = {
                "prev_page": prev_page,
                "next_page": next_page
                }
            page_url_template = "{endpoint}?sort={sort}&limit={limit}&page={page}"
            details["previous"] = page_url_template.format(
                endpoint=endpoint, sort=sort, limit=limit,
                page=prev_page) if prev_page else None
            details["next"] = page_url_template.format(
                endpoint=endpoint, sort=sort, limit=limit,
                page=next_page) if next_page else None

        return response

    # Helper attached to the method: reduces a kwargs dict to the pagination
    # arguments. NOTE(review): assumes utils.clean_kwargs accepts ONLY=/kwargs=
    # keywords, as the original call did - confirm against utils.
    PAGINATED_QUERY.clean_kwargs = lambda kwargs: clean_kwargs(
        ONLY=("limit", "sort", "after", "before", "page", "endpoint", "query"),
        kwargs=kwargs)

    def GENERATE_ID(self, _id=None):
        """Return a document id: converted from ``_id`` or newly generated."""
        if _id:
            return DOC_ID.__supertype__(_id)
        return DOC_ID.__supertype__()

    async def UPLOAD(self, fileobj, filename:str=None, content_type:str=None, extension:str=None, **kwargs):
        """
        Store a file in GridFS and return its file id. Extra keyword
        arguments are merged into the stored metadata.
        """
        assert self.FILES, "GridFS instance not initialized, run method 'set_database' with the desired database and try again"
        fileobj = file_and_fileobj(fileobj)
        metadata = get_file_meta_information(fileobj, filename=filename, content_type=content_type, extension=extension)
        filename = metadata['filename']
        metadata.update(kwargs)
        return await self.FILES.upload_from_stream(filename, fileobj, metadata=metadata)

    async def ERASE(self, filename_or_id, revision:int=-1):
        """Delete one stored file, looked up by name (and revision) or id."""
        assert self.FILES, "GridFS instance not initialized, run method 'set_database' with the desired database and try again"
        fs_doc = await self.DOWNLOAD(filename_or_id, revision=revision)
        await self.FILES.delete(fs_doc._id)
        # close() may or may not be awaitable depending on the wrapper in
        # use; only await when there is something to await.
        closed = fs_doc.close()
        if hasattr(closed, "__await__"):
            await closed

    async def DOWNLOAD(self, filename_or_id=None, revision:int=-1, skip:int=None, limit:int=None, sort:int=-1, **query):
        """
        Open a download stream for one file, or return a cursor over files.

        With ``filename_or_id`` a stream is opened by id (when it is of the
        document id type) or by name and ``revision``. Without it, the
        remaining keyword arguments are used as a filter for ``find``.
        """
        assert self.FILES, "GridFS instance not initialized, run method 'set_database' with the desired database and try again"
        revision = int(revision)
        if filename_or_id:
            if isinstance(filename_or_id, DOC_ID.__supertype__):
                return await self.FILES.open_download_stream(filename_or_id)
            return await self.FILES.open_download_stream_by_name(filename_or_id, revision=revision)
        # find() requires integer skip/limit and a key/direction sort
        # specification; a bare direction is applied to the upload date.
        # NOTE(review): 'uploadDate' chosen as the sort field - confirm.
        if isinstance(sort, int):
            sort = [("uploadDate", sort)]
        return self.FILES.find(query, limit=limit or 0, skip=skip or 0, sort=sort, no_cursor_timeout=True)

    async def _archive_deleted(self, collection, query, one=False):
        """Copy the documents matching ``query`` into 'deleted.<collection>'."""
        if one:
            found = await self.GET(collection, query, one=True)
            documents = [found] if found else []
        else:
            cursor = await self.GET(collection, query)
            documents = [document async for document in cursor]
        for document in documents:
            await self.PUT("deleted." + collection, document)

    async def DELETE(self, collection, record, soft:bool=False, one:bool=False):
        """
        Delete documents.

        :param record: a single id, a query dict, or a list/tuple of ids
        :param soft: copy the documents to ``deleted.<collection>`` first
        :param one: with a query dict, delete only the first match
        :returns: the delete result, or a list of results for a list of ids
        """
        db = self.get_default_database()
        o_collection = self._resolve_collection(collection)
        collection = db[o_collection]

        if isinstance(record, (list, tuple)):
            queries = [{"_id": self._process_record_id_type(_id)[0]} for _id in record]
            if soft and queries:
                await self._archive_deleted(o_collection, {"$or": queries})
            return [await collection.delete_one(query) for query in queries]

        record, _one = self._process_record_id_type(record)
        one = _one if _one else one
        if _one:
            record = {"_id": record}
        elif not isinstance(record, dict):
            # Any other scalar is treated as a literal _id value.
            record = {"_id": record}
            one = True

        if soft:
            await self._archive_deleted(o_collection, record, one=one)

        if one:
            return await collection.delete_one(record)
        return await collection.delete_many(record)

    def INDEX(self, collection, key:str="_id", sort:int=1, unique:bool=False, reindex:bool=False):
        """
        Request a background index named ``<key>IndexAsc`` / ``<key>IndexDesc``.

        This method is synchronous, so the result of ``create_index`` is
        returned un-awaited; callers in a coroutine may await it to observe
        errors. ``reindex`` is accepted for compatibility and not used.

        NOTE(review): relies on motor starting the operation when the method
        is called rather than when it is awaited - confirm for the motor
        version in use.
        """
        db = self.get_default_database()
        collection = db[self._resolve_collection(collection)]
        name = "%sIndex%s" % (key, "Asc" if sort == 1 else "Desc")
        try:
            return collection.create_index([(key, sort)], name=name, background=True, unique=unique)
        except Exception:
            # Best-effort: index creation must not break client construction.
            logger.exception("could not request index '{}'".format(name))
            return None

    async def ADD_FIELD(self, collection, field:str, value:typing.Union[typing.Dict, typing.List, str, int, float, bool]='', data=False, query:dict=None):
        """
        Add ``field`` to every matching document that lacks it.

        The new value is ``value``, or, when ``data`` names another field,
        a copy of that field's value from the same document.
        """
        collection = self._resolve_collection(collection)
        # Work on a copy: neither the caller's dict nor a shared default may
        # keep the "$exists" condition after this call.
        query = dict(query or {})
        query[field] = {"$exists": False}

        projection = {data: True} if data else {"_id": True}
        cursor = await self.GET(collection, query, fields=projection, empty=[])
        # Materialize first; the documents stop matching once patched.
        records = [record async for record in cursor]

        for record in records:
            new_value = record[data] if data else value
            await self.PATCH(collection, record["_id"], {"$set": {field: new_value}})

    async def REMOVE_FIELD(self, collection, field:str, query:dict=None) -> None:
        """Remove ``field`` from every matching document that has it."""
        collection = self._resolve_collection(collection)
        query = dict(query or {})
        query[field] = {"$exists": True}
        record_ids = await self.GET(collection, query, distinct=True)

        for record_id in record_ids:
            await self.PATCH(collection, record_id, {"$unset": {field: ""}})

    @staticmethod
    def _keyset_query(query, key, after=None, before=None):
        """
        Combine ``query`` with the range condition encoded in an ``after`` or
        ``before`` pagination cursor ("<date>_<id>", date part empty when
        sorting on _id). ``after`` takes precedence over ``before``.
        """
        if not any((query, after, before)):
            return query
        combined = {"$and": [query]}
        if after:
            marker, operator = after, "$lt"
        elif before:
            marker, operator = before, "$gt"
        else:
            return combined

        sort_value, _, _id_value = marker.partition("_")
        _id_value = DOC_ID.__supertype__(_id_value)
        if key == "_id":
            alternatives = [{"_id": {operator: _id_value}}]
        else:
            sort_value = dateparse(sort_value)
            # Strictly beyond the cursor's date, or the same date with the
            # _id used as tie-breaker.
            alternatives = [
                {key: {operator: sort_value}},
                {key: sort_value, "_id": {operator: _id_value}},
                ]
        combined["$and"].append({"$or": alternatives})
        return combined

    async def GET(self, collection, id_or_query:typing.Union[DOC_ID, str, typing.Dict]=None, sort:int=1, key:str="_id", count:bool=None, search:str=None, fields:dict=None, page:int=None, perpage:int=False, limit:int=None, after:str=None, before:str=None, empty=None, distinct:str=None, one:bool=False, **kwargs):
        """
        General purpose read.

        Depending on the arguments the result per collection is: a count
        (``count``), a sorted list of distinct values (``distinct``), a
        cursor (``perpage``/``page``, ``limit`` with optional ``after`` /
        ``before``, or no option at all), or a single document (``one`` or an
        id passed as ``id_or_query``; ``empty`` is returned when nothing is
        found). ``search`` runs a ``$text`` query and is only used when no
        id or query is given.

        ``collection`` may be a list/tuple/generator of names, in which case
        a list with one result per distinct name is returned.
        """
        db = self.get_default_database()
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection not provided"

        if not isinstance(collection, (list, tuple, types.GeneratorType)):
            collection = [collection]
        # De-duplicate while keeping the order the caller supplied.
        cols = list(dict.fromkeys(collection))

        if id_or_query is None:
            id_or_query = {}
        if distinct is True:
            distinct = "_id"
        if perpage:
            page = max(int(page or 1), 1)

        id_or_query, _one = self._process_record_id_type(id_or_query)
        one = _one if _one else one
        query = {"_id": id_or_query} if _one else id_or_query
        ordering = [(key, sort)]

        results = []
        for name in cols:
            coll = db[name]
            if query or not search:
                if count and not limit:
                    if query:
                        value = await coll.count_documents(query, **kwargs)
                    else:
                        value = await coll.estimated_document_count(**kwargs)
                elif distinct:
                    value = sorted(await coll.distinct(distinct, filter=query, **kwargs))
                elif perpage:
                    offset = (page - 1) * perpage
                    value = coll.find(query, projection=fields, **kwargs).sort(ordering).skip(offset).limit(perpage)
                elif limit:
                    # Built per collection from the untouched base query.
                    ranged = self._keyset_query(query, key, after=after, before=before)
                    if count:
                        try:
                            value = await coll.count_documents(ranged, limit=limit, hint=ordering, **kwargs)
                        except Exception:
                            # e.g. no index matching the hint: count by fetching.
                            value = len(await coll.find(ranged, fields, **kwargs).sort(ordering).to_list(limit))
                    else:
                        value = coll.find(ranged, projection=fields, **kwargs).sort(ordering).limit(limit)
                elif one:
                    found = await coll.find_one(query, projection=fields, sort=ordering, **kwargs)
                    value = found if found else empty
                else:
                    value = coll.find(query, projection=fields, **kwargs).sort(ordering)
            else:
                text_query = {"$text": {"$search": search}}
                if count:
                    value = await coll.count_documents(text_query)
                elif distinct:
                    value = await coll.distinct(distinct, filter=text_query)
                else:
                    value = coll.find(text_query, projection=fields).sort(ordering)
                    if perpage:
                        value = value.skip((page - 1) * perpage).limit(perpage)
            results.append(value)

        if len(cols) == 1:
            return results[0]
        return results

    async def SEARCH(self, collection, search:str, **kwargs):
        """Run a text search; shortcut for ``GET(collection, search=...)``."""
        collection = self._resolve_collection(collection)
        return await self.GET(collection, search=search, **kwargs)

    async def POST(self, collection, record_or_records:typing.Union[typing.List, typing.Dict]):
        """Insert one document (dict) or many documents (list/tuple)."""
        db = self.get_default_database()
        collection = db[self._resolve_collection(collection)]

        if isinstance(record_or_records, (list, tuple)):
            return await collection.insert_many(record_or_records)
        if isinstance(record_or_records, dict):
            return await collection.insert_one(record_or_records)
        raise TypeError("invalid record type '{}' provided".format(type(record_or_records)))

    async def PUT(self, collection, record_or_records:typing.Union[typing.List, typing.Dict]):
        """
        creates or replaces record(s) with exact _id provided, _id is
        required with record object(s)

        returns original document, if replaced

        NOTE: a list of records is sent as an unordered ``insert_many``, so
        in that form existing documents are not replaced.
        """
        db = self.get_default_database()
        collection = db[self._resolve_collection(collection)]

        if isinstance(record_or_records, (list, tuple)):
            assert all([ record.get("_id", None) for record in record_or_records ]), "not all records provided contained an _id"
            return await collection.insert_many(record_or_records, ordered=False)
        if isinstance(record_or_records, dict):
            assert record_or_records.get("_id", None), "no _id provided"
            query = {"_id": record_or_records["_id"]}
            return await collection.find_one_and_replace(query, record_or_records, upsert=True)
        raise TypeError("invalid record type '{}' provided".format(type(record_or_records)))

    async def REPLACE(self, collection, original, replacement:dict, upsert=False):
        """Replace the document whose _id equals ``original``."""
        db = self.get_default_database()
        collection = db[self._resolve_collection(collection)]
        return await collection.replace_one({"_id": original}, replacement, upsert=upsert)

    async def _patch_one(self, collection, raw_id, updates, upsert):
        """Apply ``updates`` to the single document identified by ``raw_id``."""
        match, _ = self._process_record_id_type(raw_id)
        # Never modify the caller's update document.
        updates = dict(updates)
        if upsert:
            # Pin the _id of a document created by the upsert, keeping any
            # "$setOnInsert" values the caller supplied.
            on_insert = dict(updates.get("$setOnInsert") or {})
            on_insert.setdefault("_id", self._concrete_id(raw_id))
            updates["$setOnInsert"] = on_insert
        return await collection.update_one({"_id": match}, updates, upsert=upsert)

    async def PATCH(self, collection, id_or_query:typing.Union[DOC_ID, typing.Dict, typing.List, str], updates:typing.Union[typing.Dict, typing.List], upsert:bool=False, w:int=1):
        """
        Apply update operators to documents.

        - id (str or document id type) + dict: updates that one document
        - query dict + dict: updates every matching document
        - list/tuple of ids + list/tuple of dicts: updates each id with the
          update document at the same position; returns a list of results

        :param w: write concern; only applied when different from 1
        :raises TypeError: when ``id_or_query`` is of an unsupported type
        """
        db = self.get_default_database()
        collection = collection or self._DEFAULT_COLLECTION
        assert collection, "collection not provided"
        collection = db[collection]

        if w != 1:
            collection = collection.with_options(write_concern=WriteConcern(w=w))

        if isinstance(id_or_query, (str, DOC_ID.__supertype__)):
            assert isinstance(updates, dict), "updates must be dict"
            return await self._patch_one(collection, id_or_query, updates, upsert)
        if isinstance(id_or_query, dict):
            assert isinstance(updates, dict), "updates must be dict"
            return await collection.update_many(id_or_query, updates, upsert=upsert)
        if isinstance(id_or_query, (tuple, list)):
            assert isinstance(updates, (tuple, list)), "updates must be list or tuple"
            assert len(updates) == len(id_or_query), "updates must contain one entry per id"
            results = []
            for _id, update in zip(id_or_query, updates):
                results.append(await self._patch_one(collection, _id, update, upsert))
            return results
        raise TypeError("invalid id_or_query type '{}' provided".format(type(id_or_query)))
class AsyncIODoc(AsyncIOClient):
    """
    Custom MongoClient subclass with customizations for creating
    standardized documents and adding json schema validation.

    The document being worked on is kept in ``self.RECORD``. The constructor
    cannot await, so it always starts from the blank sample; call
    ``await doc.load(_id)`` to work on a stored document.
    """
    _MONGO_URI = lambda _: getattr(Config, "MONGO_URI", None)
    _DOC_TYPE:str = None #: MongoDB collection to use
    _DOC_ID:str = "_id"
    _DOC_SAMPLE:str = None
    _DOC_SCHEMA:str = None
    _DOC_MARSHMALLOW:str = False
    _DOC_DEFAULTS:dict = {}
    _DOC_RESTRICTED_KEYS:list = []
    _DOC_ENUMS:list = []
    _DOC_SETTINGS:str = None

    def __init__(self, _id=None, doc_type:str=None, doc_sample:typing.Union[typing.Dict, str]=None, doc_schema:typing.Union[typing.Dict, str]=None, doc_id:str=None, mongo_uri:str=None, **kwargs):
        """
        :param _id: id of a stored document; NOT loaded here, see class notes
        :param doc_type: collection name (required, here or on the class)
        :param doc_sample: template record, as dict or as file name under
            ``Config.JSON_SAMPLE_PATH``; may chain via a ``__parent__`` key
        :param doc_schema: json schema, as dict or as file name under
            ``Config.JSON_SCHEMA_PATH``
        :param doc_id: name of the unique id field
        :param kwargs: ``doc_settings``, ``doc_marshmallow``, ``doc_defaults``
            and ``doc_restricted_keys`` are stored on the instance; the rest
            is passed to ``AsyncIOClient.__init__``
        """
        self._MONGO_URI = mongo_uri or self._MONGO_URI
        if callable(self._MONGO_URI):
            self._MONGO_URI = self._MONGO_URI()
        # INFO: set default collection
        self._DOC_TYPE = doc_type or self._DOC_TYPE
        assert self._DOC_TYPE, "collection must be of type str"
        self._DEFAULT_COLLECTION = self._DOC_TYPE
        # INFO: location for sample record, used as template
        self._DOC_SAMPLE = doc_sample or self._DOC_SAMPLE
        # INFO: path to validation schema
        self._DOC_SCHEMA = doc_schema or self._DOC_SCHEMA
        # INFO: sets the unique id field for the document type, if any (cannot be _id)
        self._DOC_ID = doc_id or self._DOC_ID
        assert self._DOC_ID, "unique id field name must be of type str"

        # Iterate over a snapshot; popping from the dict being iterated raises.
        for kwarg in list(kwargs):
            if kwarg.lower() in ('doc_settings', 'doc_marshmallow', 'doc_defaults', 'doc_restricted_keys'):
                setattr(self, "_{}".format(kwarg.upper()), kwargs.pop(kwarg))

        self.sample = self._load_sample()
        self.schema = self._load_schema()

        AsyncIOClient.__init__(self, **kwargs)

        db = self.get_default_database()
        if not getattr(db, "name", None) or db.name == "None":
            raise Exception("database not provided in MongoDB URI")
        self._DOC_DB = db.name

        # Initialize enums
        if not self._DOC_ENUMS:
            pass #enums_record = self.GET(self._DOC_SETTINGS, "enums"

        # INFO: If class has a _DOC_ID assigned, create unique index
        if self._DOC_ID != "_id":
            self.INDEX(self._DOC_TYPE, key=self._DOC_ID, sort=1, unique=True)

        # load() is a coroutine and cannot run inside __init__; start from the
        # blank template so RECORD always exists.
        self.RECORD = copy.deepcopy(self.sample)
        if _id:
            logger.warning("record '{}' is not loaded until 'await load(_id)' is called".format(_id))

    def _load_sample(self):
        """Return the template record, merged over its ``__parent__`` chain."""
        if not self._DOC_SAMPLE:
            return {}
        if isinstance(self._DOC_SAMPLE, str):
            sample_full_path = os_path.join(Config.JSON_SAMPLE_PATH, self._DOC_SAMPLE)
            with open(sample_full_path) as file1:
                sample = json_load(file1.read())
        elif isinstance(self._DOC_SAMPLE, dict):
            # Copy: the "__parent__" key is popped below and the dict may be
            # a class attribute shared by every instance.
            sample = dict(self._DOC_SAMPLE)
        else:
            raise TypeError("_DOC_SAMPLE is invalid type '{}', valid types are dict and str".format(type(self._DOC_SAMPLE)))

        sample_parent_found = sample.pop("__parent__", None)
        while sample_parent_found:
            parent_sample_full_path = os_path.join(Config.JSON_SAMPLE_PATH, sample_parent_found)
            with open(parent_sample_full_path) as _file:
                parent_sample = json_load(_file.read())
            # Child values win over the parent's.
            parent_sample.update(sample)
            sample = parent_sample
            sample_parent_found = sample.pop("__parent__", None)
        return sample

    def _load_schema(self):
        """Return the validation schema, or a blank dict when none is set."""
        # INFO: Load schema else start blank dict to add manual validation entries
        if not self._DOC_SCHEMA:
            return {}
        if isinstance(self._DOC_SCHEMA, str):
            schema_full_path = os_path.join(Config.JSON_SCHEMA_PATH, self._DOC_SCHEMA)
            with open(schema_full_path) as _file:
                return json_load(_file.read())
        if isinstance(self._DOC_SCHEMA, dict):
            return self._DOC_SCHEMA
        raise TypeError("_DOC_SCHEMA is invalid type '{}', valid types are dict and str".format(type(self._DOC_SCHEMA)))

    def __repr__(self):
        if self.RECORD.get("_id", None):
            _id = self.id()
            return f"<cervmongo.AsyncIODoc.{self._DOC_DB}.{self._DOC_TYPE}.{self._DOC_ID}:{_id}>"
        return "<cervmongo.AsyncIODoc>"

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # close() is a coroutine and cannot be awaited here; perform the same
        # reset it does (back to the blank template) synchronously.
        self.RECORD = copy.deepcopy(self.sample)

    def _process_restrictions(self, record:dict=None):
        """removes restricted keys from record and return record"""
        try:
            # Only fall back to self.RECORD when no record was passed at all;
            # an empty record must stay empty.
            if record is None:
                record = self.RECORD
            else:
                assert isinstance(record, (MongoDictResponse, dict)), "Needs to be a dictionary, got {}".format(type(record))
            return {key: value for key, value in record.items() if key not in self._DOC_RESTRICTED_KEYS}
        except Exception:
            logger.exception("encountered error when cleaning self.RECORD, returning empty dict")
            return {}

    def _p_r(self, record:dict=None):
        """truncated alias for _process_restrictions"""
        return self._process_restrictions(record=record)

    def _generate_unique_id(self, template:str="{total}", **kwargs):
        """Format ``template`` with ``kwargs`` and upper-case the result."""
        return template.format(**kwargs).upper()

    def _timestamp(self, value=None):
        """Parse ``value`` as a date; use the current datetime when it is
        missing or cannot be parsed."""
        if not value:
            return current_datetime()
        try:
            return dateparse(value)
        except Exception:
            return current_datetime()

    def _guess_corresponding_fieldname(self, _type="unknown", related_field:str=""):
        """Guess the name of the field that usually accompanies
        ``related_field`` (e.g. ``created_date`` -> ``created_by``)."""
        time_fields = ("date", "datetime", "time")

        if _type not in time_fields:
            # NOTE: an unknown type field has a timestamp pairing or desc
            if related_field:
                return f"{related_field}_description"
            return "field_description"

        # NOTE: a timestamp is 'mostly' accompanied by a user or relation
        if not related_field:
            return "by"
        if "_" in related_field:
            field_parts = related_field.split("_")
        elif "-" in related_field:
            field_parts = related_field.split("-")
        else:
            field_parts = related_field.split()

        for field_part in field_parts:
            field_part = field_part.strip(" _-").lower()
            # The first part that does not itself name a time is the subject.
            if not any(x in field_part for x in time_fields):
                return f"{field_part}_by"
        return "for"

    async def _related_record(self, collection:str=None, field:str="_id", value=False, additional:dict=None):
        """
        Fetch the document of ``collection`` whose ``field`` equals ``value``
        (projected to ``field`` plus the ``additional`` fields) and tag it
        with the ``key`` and ``collection`` it came from.
        """
        projection = dict(additional or {})
        projection[field] = True
        record = await self.GET(collection, {field: value}, fields=projection, one=True, empty={})
        assert record, 'Error: No record found'
        record['key'] = field
        record['collection'] = collection
        return record

    async def load(self, _id=None):
        """Load the stored record with unique id ``_id`` into ``self.RECORD``,
        or reset to a copy of the blank template when no ``_id`` is given."""
        # If _id specified on init, load actual record versus blank template
        if _id:
            if self._DOC_ID:
                self.RECORD = await self.GET(self._DOC_TYPE, {self._DOC_ID: _id}, one=True)
            else:
                self.RECORD = await self.GET(self._DOC_TYPE, _id)
        else:
            self.RECORD = copy.deepcopy(self.sample)

        if not self.RECORD:
            self.RECORD = {}

        return StandardResponse(
            data=self._p_r(self.RECORD),
            details={
                "state": "unsaved" if not self.RECORD.get("_id", None) else "saved",
                "unique_id": self._DOC_ID
                })

    async def view(self, _id=False):
        """Return the current record, or (without loading it into
        ``self.RECORD``) the stored record with unique id ``_id``."""
        if not _id:
            return StandardResponse(data=self._p_r(self.RECORD))
        id_field = self._DOC_ID if self._DOC_ID else "_id"
        found = await self.GET(self._DOC_TYPE, {id_field: _id}, one=True, empty={})
        return StandardResponse(
            data=self._p_r(found),
            details={
                "unique_id": self._DOC_ID
                }
            )

    async def reload(self):
        """Re-read the current record from the database."""
        assert self.RECORD.get("_id")
        # Explicit query + one=True so any _id type yields a single document.
        self.RECORD = await self.GET(self._DOC_TYPE, {"_id": self.RECORD["_id"]}, one=True)
        return {
            "data": self._p_r(self.RECORD),
            "details": {
                "unique_id": self._DOC_ID
                }
            }

    def id(self):
        """Return the value of the unique id field of the current record."""
        return self.RECORD.get(self._DOC_ID, None)

    async def create(self, save:bool=False, trigger=None, template:str="{total}", query:dict=None, **kwargs):
        """
        Fill a new (unsaved) record with ``kwargs`` and assign its unique id.

        The unique id is generated from ``template``, which may use the
        keyword arguments and ``{total}`` (count of documents matching
        ``query`` plus one, zero-padded to 6 digits).
        """
        assert self.RECORD.get("_id") is None, "Cannot use create method on an existing record. Use patch method instead."
        query = {} if query is None else query

        if self._DOC_MARSHMALLOW:
            self._DOC_MARSHMALLOW().load(kwargs)
        elif self.sample:
            # INFO: removing invalid keys based on sample record
            for name in list(kwargs.keys()):
                if name not in self.sample:
                    silent_drop_kwarg(kwargs, name, reason="not in self.sample")
        self.RECORD.update(kwargs)

        existing = await self.GET(self._DOC_TYPE, query, count=True)
        kwargs['total'] = str(existing + 1).zfill(6)

        if self._DOC_ID and not self.RECORD.get(self._DOC_ID):
            self.RECORD[self._DOC_ID] = self._generate_unique_id(template=template, **kwargs)

        if save:
            await self.save(trigger=trigger)

        return {
            "data": self._p_r(self.RECORD),
            "details": {"unique_id": self._DOC_ID, "collection": self._DOC_TYPE, "_id": self.RECORD[self._DOC_ID]}
            }

    async def add_enum(self, name:str, value):
        """Placeholder: enum support is not implemented."""
        raise NotImplementedError("add_enum is not implemented")

    async def push(self, **kwargs):
        """Append each value to the array field of the same name ($push)."""
        assert self.RECORD.get("_id"), "Cannot use push method on a non-existing record. Use create method instead."
        kwargs.pop("_id", None)
        await self.PATCH(None, self.RECORD["_id"], {"$push": kwargs})
        await self.reload()
        keys = list(kwargs.keys())
        values = [ kwargs[key] for key in keys ]
        return {
            "data": self._p_r(self.RECORD),
            "details": {
                "action": "push",
                "desc": "push a value to end of array, not set",
                "field": keys,
                "value": values
                }
            }

    async def pull(self, **kwargs):
        """Remove every instance of each value from the array field of the
        same name ($pull)."""
        assert self.RECORD.get("_id"), "Cannot use pull method on a non-existing record. Use create method instead."
        kwargs.pop("_id", None)
        await self.PATCH(None, self.RECORD["_id"], {"$pull": kwargs})
        await self.reload()
        keys = list(kwargs.keys())
        values = [ kwargs[key] for key in keys ]
        return {
            "data": self._p_r(self.RECORD),
            "details": {
                "action": "pull",
                "desc": "pull all instances of a value from an array",
                "field": keys,
                "value": values
                }
            }

    async def increment(self, query:dict=None, **kwargs):
        """Increment the named integer fields by the given amounts ($inc)."""
        assert self.RECORD.get("_id"), "Cannot use increment method on a non-existing record. Use create method instead."
        kwargs.pop("_id", None)
        # Copy so the current _id never leaks into the caller's dict.
        query = dict(query or {})
        query["_id"] = self.RECORD["_id"]
        await self.PATCH(None, query, {"$inc": kwargs})
        await self.reload()
        keys = list(kwargs.keys())
        values = [ kwargs[key] for key in keys ]
        return {
            "data": self._p_r(self.RECORD),
            "details": {
                "action": "increment",
                "desc": "increment the integer fields by the amount provided",
                "field": keys,
                "increment": values
                }
            }

    async def update(self, query:dict=None, **kwargs):
        """Replace the named fields with new values ($set) in the database."""
        assert self.RECORD.get("_id"), "Cannot use update method on a non-existing record. Use create method instead."
        kwargs.pop("_id", None)
        query = dict(query or {})
        query["_id"] = self.RECORD["_id"]
        keys = list(kwargs.keys())
        old_values = [ self.RECORD.get(key, None) for key in keys ]
        await self.PATCH(None, query, {"$set": kwargs})
        await self.reload()
        new_values = [ self.RECORD.get(key, None) for key in keys ]
        return {
            "data": self._p_r(self.RECORD),
            "details": {
                "action": "update",
                "desc": "replace existing field values with new values",
                "field": keys,
                "values_old": old_values,
                "values_new": new_values
                }
            }

    async def patch(self, save:bool=False, trigger=None, **kwargs):
        """Change fields of the in-memory record; persisted only when
        ``save`` is true."""
        assert self.RECORD.get("_id"), "Cannot use patch method on a non-existing record.Use create method instead."
        kwargs.pop("_id", None)
        if self._DOC_MARSHMALLOW:
            self._DOC_MARSHMALLOW().load(kwargs, partial=True)
        elif self.sample:
            assert all([ x in self.sample for x in kwargs.keys()])
        self.RECORD.update(kwargs)

        if save:
            await self.save(trigger=trigger)

        return {
            "data": self._p_r(self.RECORD),
            "details": {
                "processed": True,
                "diff": kwargs
                }
            }

    async def save(self, trigger=None):
        """
        Validate and store the record: replaced (PUT) when it has an ``_id``,
        inserted (POST) otherwise. ``trigger`` is called after a successful
        write. Validation errors propagate to the caller.
        """
        _id = None
        if self._DOC_DEFAULTS:
            for key, value in self._DOC_DEFAULTS.items():
                if not self.RECORD.get(key, None):
                    # Copy so records never share a mutable default value.
                    self.RECORD[key] = copy.deepcopy(value)

        # The _id is kept out of the validated payload.
        if self.RECORD.get("_id", None):
            _id = self.RECORD.pop("_id", None)
        try:
            if self._DOC_MARSHMALLOW:
                self._DOC_MARSHMALLOW().load(self.RECORD)
            else:
                validate(self.RECORD, self.schema)
        finally:
            # Restore even when validation fails, so the record keeps its id.
            if _id:
                self.RECORD["_id"] = _id

        if _id:
            await self.PUT(None, self.RECORD)
        else:
            result = await self.POST(None, self.RECORD)
            self.RECORD["_id"] = result.inserted_id
        if trigger:
            trigger()

        return {
            "data": self._p_r(self.RECORD),
            "details": {}
            }

    async def close(self):
        """Reset the in-memory record to the blank template."""
        # TODO: clean closing logic
        await self.load()
        return {
            "data": self.RECORD,
            "details": {}
            }

    async def add_relation(self, key:str, *args, **kwargs):
        """Store under ``key`` the related record found by
        ``_related_record(*args, **kwargs)``."""
        self.RECORD[key] = await self._related_record(*args, **kwargs)
        return {
            "data": self._p_r(self.RECORD),
            "details": {"key": key, "relation": self.RECORD[key]}
            }

    async def add_timestamp(self, key:str, value=None, relation:dict=None):
        """
        Store a timestamp under ``key`` and, when ``relation`` is given, the
        related record next to it (under ``relation["key"]`` or a guessed
        field name).
        """
        self.RECORD[key] = self._timestamp(value)
        if relation:
            # Copy: "key" is popped and must not vanish from the caller's dict.
            relation = dict(relation)
            related_key = relation.pop("key", self._guess_corresponding_fieldname(_type="datetime", related_field=key))
            await self.add_relation(related_key, **relation)
        return {
            "data": self._p_r(self.RECORD),
            "details": {"key": key, "value": self.RECORD[key]}
            }

    async def add_object(self, field:str, object_name:str=None, key:str=None, value=None, **kwargs):
        """
        Put an object under ``field``: loaded from the ``<object_name>.json``
        template when ``object_name`` is given, otherwise an empty dict or
        ``{key: value}``. A missing ``value`` is replaced by a generated id.
        """
        if not object_name:
            if not key:
                self.RECORD[field] = {}
            else:
                self.RECORD[field] = {key: value if value else self.GENERATE_ID()}
        else:
            # NOTE(review): templates are assumed to live in
            # Config.JSON_SAMPLE_PATH like the record samples - confirm.
            template_path = os_path.join(Config.JSON_SAMPLE_PATH, object_name + ".json")
            assert os_path.exists(template_path), "path does not exist"
            with open(template_path) as file1:
                self.RECORD[field] = json_load(file1.read())
            if key:
                self.RECORD[field][key] = value if value else self.GENERATE_ID()

        return {
            "data": self._p_r(self.RECORD),
            "details": {field: self.RECORD[field], "state": "unsaved"}
            }
def get_async_client() -> typing.Type[AsyncIOClient]:
    """
    returns AsyncIOClient class

    The class itself is returned, not an instance.

    :raises Exception: when the motor package is not installed
    """
    if not SUPPORT_ASYNCIO_CLIENT:
        raise Exception("motor not installed")
    return AsyncIOClient
def get_async_doc() -> typing.Type[AsyncIODoc]:
    """
    returns AsyncIODoc class

    The class itself is returned, not an instance.

    :raises Exception: when the motor package is not installed
    """
    if not SUPPORT_ASYNCIO_CLIENT:
        raise Exception("motor not installed")
    return AsyncIODoc