Django API Throttling

There are cases when you do not want clients to bombard some of your APIs. Django REST Framework gives you out-of-the-box support for controlling how many times your APIs can be hit. It lets you cap the number of hits per second, per minute, per hour, or per day; once a client exceeds the limit, it receives a 429 (Too Many Requests) status. For storing the hit counts, the framework uses the default cache configured for the application:

CACHES = {
    "default": {
        "BACKEND": "redis_cache.cache.RedisCache",
        "LOCATION": "redis.cache.amazonaws.com:6379",
        "OPTIONS": {
            "DB": 0,
            "CLIENT_CLASS": "redis_cache.client.DefaultClient",
        }
    }
}
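Any cache backend Django supports will work here. For example, if you are just trying things out locally, the in-memory backend is enough (a sketch for development only; each server process keeps its own counts, so it is not suitable for multi-process deployments):

```python
# settings.py -- development-only alternative to the Redis backend above
CACHES = {
    "default": {
        "BACKEND": "django.core.cache.backends.locmem.LocMemCache",
    }
}
```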

Your MIDDLEWARE_CLASSES in settings.py looks like this:

MIDDLEWARE_CLASSES = (
    '.......'
    'custom.throttling.ThrottleMiddleWare', # the custom class to control throttling limits
)

In the REST_FRAMEWORK settings in settings.py, we need to specify the rates and the classes that handle throttling. DRF ships with default implementations, but you can write your own as well. To use the defaults, reference the built-in classes instead of the custom one:

REST_FRAMEWORK = {
    'DEFAULT_THROTTLE_CLASSES': (
        'custom.throttling.PerMinuteThrottle', # custom throttle [implemented below]
        # 'rest_framework.throttling.AnonRateThrottle',
        # 'rest_framework.throttling.UserRateThrottle'
    ),
    'DEFAULT_THROTTLE_RATES': {
        'per_minute': '256/min',
    }
}
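The rate string format is '&lt;number&gt;/&lt;period&gt;', where the first letter of the period (s, m, h, d) selects the window. A rough standalone sketch of how DRF's SimpleRateThrottle interprets it (simplified for illustration, not the exact library code):

```python
def parse_rate(rate):
    """Turn a rate string like '256/min' into (num_requests, window_seconds)."""
    if rate is None:
        return (None, None)
    num, period = rate.split('/')
    # only the first letter of the period matters: 's', 'm', 'h' or 'd'
    duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
    return (int(num), duration)

print(parse_rate('256/min'))   # (256, 60)
print(parse_rate('100/day'))   # (100, 86400)
```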

The throttle class implemented below does per-minute throttling. You can implement similar classes to fit your use case.

from rest_framework.settings import APISettings, USER_SETTINGS, DEFAULTS, IMPORT_STRINGS
from rest_framework.throttling import UserRateThrottle

api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)

class ThrottleMiddleWare(object):
    def process_response(self, request, response):
        """
        Setting the standard rate limit headers
        :param request:
        :param response:
        :return:
        """
        rate = api_settings.DEFAULT_THROTTLE_RATES.get('per_minute')
        response['X-RateLimit-Limit'] = rate or "None"
        if 'HIT_COUNT' in request.META and rate is not None:
            response['X-RateLimit-Remaining'] = self.parse_rate(rate) - request.META['HIT_COUNT']
        return response

    def parse_rate(self, rate):
        """
        Given a request rate string such as '256/min',
        return the allowed number of requests.
        """
        num_requests = 0
        try:
            if rate is None:
                return num_requests
            num, period = rate.split('/')
            num_requests = int(num)
        except Exception:
            pass
        return num_requests

REQUEST_METHOD_GET, REQUEST_METHOD_POST = 'GET', 'POST'

class PerMinuteThrottle(UserRateThrottle):
    scope = 'per_minute'

    def allow_request(self, request, view):
        """
        Custom implementation:
        Implement the check to see if the request should be throttled.
        On success calls `throttle_success`.
        On failure calls `throttle_failure`.
        """
        hit_count = 0

        try:
            if request.user.is_authenticated():
                user_id = request.user.pk
            else:
                user_id = self.get_ident(request)
            request.META['USER_ID'] = user_id

            # do not throttle writes in this example
            if str(request.method).upper() == REQUEST_METHOD_POST:
                return True

            if self.rate is None:
                return True

            self.key = self.get_cache_key(request, view)
            if self.key is None:
                return True

            self.history = self.cache.get(self.key, [])
            self.now = self.timer()

            # Drop any requests from the history which have now passed the
            # throttle duration
            duration = self.now - self.duration
            while self.history and self.history[-1] <= duration:
                self.history.pop()

            hit_count = len(self.history)
            request.META['HIT_COUNT'] = hit_count + 1
            if len(self.history) >= self.num_requests:
                request.META['HIT_COUNT'] = hit_count
                return self.throttle_failure()
            return self.throttle_success()
        except Exception:
            pass

        # in case any exception occurs - we must allow the request to go through
        request.META['HIT_COUNT'] = hit_count
        return True
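Stripped of the Django plumbing, the sliding-window bookkeeping can be exercised on its own. Here is a standalone sketch of the same prune-then-check logic (timestamps are passed in explicitly so it is easy to play with; names are mine, not DRF's):

```python
def check_window(history, now, num_requests, duration):
    """Sliding-window check: prune expired timestamps, then allow or reject.

    `history` is a list of timestamps, newest first, as DRF keeps it.
    Returns (allowed, history).
    """
    # drop entries older than the window
    while history and history[-1] <= now - duration:
        history.pop()
    if len(history) >= num_requests:
        return False, history          # throttled
    history.insert(0, now)             # record this hit, newest first
    return True, history

h = []
for t in (0, 1, 2):                    # three quick hits, limit is 3 per 60s
    allowed, h = check_window(h, t, num_requests=3, duration=60)
print(check_window(h, 3, 3, 60)[0])    # False: fourth hit inside the window
print(check_window(h, 61.5, 3, 60)[0]) # True: the hit at t=0 has expired
```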

When you hit the limit, you get a response like this:

INFO {'status': 429, 'path': '/api/order/history/', 'content': '{detail: Request was throttled. Expected available in 16 seconds.}\n', 'method': 'GET', 'user': 100}

Using psycopg2 with PostgreSQL

I had been using MySQL my whole life until recently, when I got my hands dirty with PostgreSQL in one of my projects. I must say, switching to PostgreSQL has been very easy. It has some very cool and robust features, but let's not talk about that here. When using Python, psycopg2 is one of the most widely used database adapters. It is fairly stable and has good community support. We used aiopg, a library for accessing a PostgreSQL database with asyncio. In this post, I will mention a few important things I came across.

1. DictCursor:

dict_cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)

helps in fetching data from the database as a Python dictionary, so we can easily look up columns by name. A plain cursor gives values by index, which can sometimes be painful. Say we have to fetch the row for id = 3 from the user table and we only need a couple of fields, name, age, and gender, and we do not want to use the 10 other fields. Using DictCursor we can get these as:

row.get('name'), row.get('age') and row.get('gender')

against :

row[2], row[4] and row[10], where 2, 4, and 10 are the positions of the required fields

Some code:

import psycopg2.extras

cur = conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
# a table name cannot be a bind parameter, so it is interpolated separately
query = """SELECT * FROM {} WHERE user_id = %s""".format(DBOperations.TABLE_NAME)
yield from cur.execute(query, (10,))
row = yield from cur.fetchall()
return row

2. Single insert for multiple rows:

We might want to insert multiple rows in one query:

INSERT INTO user_address VALUES ('name1', 'address1'), ('name1', 'address2'), ('name1', 'address3')

We have to construct the VALUES string and execute the query, which can be done as below:

def set_address(self, cur, user_id, address_ids: list):
	# let psycopg2 quote each tuple via mogrify; joining str(tuple) output,
	# as a naive version would, is not safe against SQL injection
	args_str = ",".join(cur.mogrify("(%s, %s)", (user_id, aid)).decode()
	                    for aid in address_ids)
	yield from cur.execute("INSERT INTO user_address VALUES " + args_str)
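If you are on plain synchronous psycopg2, `psycopg2.extras.execute_values` does this batching for you. To see what the hand-built VALUES clause amounts to, here is a tiny standalone quoting helper (a hypothetical illustration only; in real code always let the driver do the quoting, since naive string-building is how SQL injection happens):

```python
def values_clause(rows):
    """Build a quoted VALUES list for rows of ints/strings (illustration only)."""
    def quote(value):
        if isinstance(value, str):
            return "'" + value.replace("'", "''") + "'"  # double embedded quotes
        return str(value)
    return ",".join("(" + ",".join(quote(c) for c in row) + ")" for row in rows)

rows = [(10, 'addr-1'), (10, "O'Brien St")]
print("INSERT INTO user_address VALUES " + values_clause(rows))
# INSERT INTO user_address VALUES (10,'addr-1'),(10,'O''Brien St')
```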

3. Searching in a jsonb array:

One of the cool datatypes in PostgreSQL is jsonb, and Postgres has made sure that querying a jsonb array is easy. Sometimes we may need to search for a particular key across the JSON documents; say a user has many addresses in various cities and we need to find all the users who have an address in Mumbai.

def find_user_address_by_city(cls, cur, city: str):
        query = """SELECT * FROM user WHERE to_json(array(SELECT jsonb_array_elements(address) ->> 'city'))::jsonb ?|
         ARRAY[%s];"""
        # (city,) is a one-element tuple; tuple(city) would split the string
        # into individual characters
        yield from cur.execute(query, (city,))
        rows = yield from cur.fetchall()
        return rows
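One pitfall worth calling out: the parameters argument to execute must be a sequence with one element per placeholder. Wrapping the city string with `tuple(...)` explodes it into characters instead of producing a single parameter:

```python
city = 'Mumbai'
print(tuple(city))   # ('M', 'u', 'm', 'b', 'a', 'i') -- six parameters, not one
print((city,))       # ('Mumbai',) -- the single-element tuple psycopg2 expects
```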

I will try to add other things as and when I get them.

A basic atomic number implementation

A basic atomic number implementation in Python.

from datetime import datetime, date
from functools import wraps
import threading



def synchronized(function):
    # create the lock once, at decoration time; creating a new lock inside
    # the wrapper on every call would give each call its own lock and
    # provide no mutual exclusion at all
    function._lock = threading.Lock()

    @wraps(function)
    def synched_function(self, *args, **kwargs):
        with function._lock:
            return function(self, *args, **kwargs)
    return synched_function


class AtomicLong:
    def __init__(self, num):
        self._num = num

    @synchronized
    def increment_and_get(self):
        self._num += 1
        return self._num

    @synchronized
    def add_and_get(self, val):
        self._num += val
        return self._num

    @synchronized
    def set_value(self, val):
        self._num = val


class IdGenerator:
    RESET_MARKER = 101
    MAX_SEQUENCE_VALUE = 9999
    ATOMIC_LONG = AtomicLong(RESET_MARKER)

    @staticmethod
    def generate_id():
        # some business logic here
        sequence_number = IdGenerator.ATOMIC_LONG.increment_and_get()
        first_day = date(date.today().year, 1, 1)
        today = date.today()
        diff_days = (today - first_day).days
        year = date.today().year % 100
        seconds_passed_since_midnight = int(
            (datetime.now() - datetime.now().replace(hour=0, minute=0, second=0, microsecond=0)).total_seconds())

        generated_id = "{0}{1}{2}{3}{4}-{5}".format("A", "O", diff_days, year,
                                                    seconds_passed_since_midnight,
                                                    sequence_number)
        print(generated_id)
        if sequence_number > IdGenerator.MAX_SEQUENCE_VALUE:
            IdGenerator.ATOMIC_LONG.set_value(IdGenerator.RESET_MARKER)
        return generated_id



if __name__ == '__main__':
    id_gen = IdGenerator()
    thread1 = threading.Thread(target=id_gen.generate_id)
    thread2 = threading.Thread(target=id_gen.generate_id)
    thread3 = threading.Thread(target=id_gen.generate_id)
    thread1.start()
    thread2.start()
    thread3.start()
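To convince yourself the lock works, hammer the counter from several threads and check that no increment is lost. This standalone sketch re-declares a minimal AtomicLong, using a per-instance lock rather than the per-function lock above (a design choice that also lets two separate counters update independently):

```python
import threading

class AtomicLong:
    """Minimal counter guarded by one lock per instance (sketch)."""
    def __init__(self, num=0):
        self._num = num
        self._lock = threading.Lock()

    def increment_and_get(self):
        with self._lock:
            self._num += 1
            return self._num

counter = AtomicLong()

def worker():
    # 10,000 increments per thread
    for _ in range(10_000):
        counter.increment_and_get()

threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(counter._num)  # 80000: no increment was lost
```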