Machine Learning for Drummers

TL;DR: In this post, I use 🎉 machine learning 🎉 to build an app that classifies audio samples as kick drums, snare drums, or other drum sounds with 87% accuracy.

First and foremost, I’m a drummer. At my day job, I work on machine learning systems at Spotify that recommend music to people. But outside my 9-to-5, I’m a musician, and my journey through music started with the drums. When I’m not drumming in my spare time, I’m often creating electronic music - with a lot of percussion in it, of course.

If you’re not familiar with electronic music production, much (if not most) modern electronic music uses drum samples rather than live recordings of drummers to provide the rhythm. These drum samples are often sold commercially as sample packs, or created by musicians and shared for free online. Often, though, these samples can be hard to use, as their labeling and classification leave a lot to be desired.


Various companies have tried to tackle this problem by creating their own proprietary formats for sample packs, such as Native Instruments’ Battery or Kontakt formats. Both use explicit metadata and allow users to browse samples by a variety of tags. However, these are usually expensive software packages that require you to learn their workflows.

To better understand how to apply machine learning techniques, I decided to try them on this fairly simple problem:

Is a given audio file a sample of a kick drum, snare drum, hi-hat, other percussion, or something else?

For example, which drums do these two samples sound like?

Humans have no trouble classifying these two sounds, as we’ve likely heard them tens of thousands of times before. The human brain is great at this kind of problem - computers, however, require some training.

In machine learning, this is often called a classification problem, because the goal is to take some data and classify it (as in, choose a class for it). You might think of this as a kind of automated sorting system (although I’m using the word “sorting” here to mean “sort into groups” rather than “put in a specific ranking or order”).

If you’re unfamiliar with machine learning, you might ask:

Why not just train the computer to learn what a kick drum is (and so on) by giving it a whole bunch of data?

This is mostly correct already! (Hooray, you’re a machine learning expert!)

The trouble comes from deciding what “data” means in the above sentence. We could:

  1. Give the computer all of the data we have and let “machine learning” figure out what’s important and what’s not.
  2. Give the computer all of the data we have, but do a bit of pre-processing first to highlight the parts of the data that might be important, then have “machine learning” classify our samples for us.

Option 1 is tricky, as our data comes in many different forms - long audio files, short audio files, different formats, bit depths, sample rates, and so on - which would add a ton of complexity to our algorithm. Throwing all of this at a machine and asking it to make sense of it would require a lot of data before it could figure out what we humans already know.

Instead of making the computer do a ton of extra work, we can use option 2 as a middle ground: we choose some properties of the audio samples that we think might be relevant to the problem, give those to a machine learning algorithm, and have it do the math for us. These properties are known as features.

(If this word is confusing, think of a feature just like a feature of, say, a TV - only instead of “42-inch screen” and “HDMI input”, our features might be “4.2 seconds long” and “maximum loudness 12dB”. The word means the same thing in both contexts.)

This process of turning raw input data into features is commonly known as feature extraction. Given our input data (audio files), let’s come up with a list of features that we, as humans, might find relevant to deciding whether a file is a kick drum or a snare drum.

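For example: how long the sample is; how loud it is at its start, middle, and end; its rough pitch (fundamental frequency) and how much that pitch varies; and how loud it is, on average, in low, mid, and high frequency bands. These are just some of the many features that might be useful for solving our classification problem, but let’s start with these four and see how far we get.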
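Two of these features are simple enough to sketch in plain Python. The function names duration_seconds and loudness_db_at are hypothetical (the helpers this post's code actually uses aren't shown), and "loudness" is assumed here to mean the RMS level, in dB relative to full scale, of a short window of a mono signal whose samples lie between -1.0 and 1.0:

```python
import math


def duration_seconds(audio, rate):
    """Length of a mono signal in seconds: sample count / sample rate."""
    return len(audio) / rate


def loudness_db_at(audio, position, window=1024):
    """Hypothetical loudness measure: RMS level (in dBFS) of a short
    window around sample index `position`. Assumes `audio` is a
    non-empty sequence of floats in [-1.0, 1.0]."""
    # Clamp so that positions at (or past) the end still get a window.
    centre = min(max(position, 0), len(audio) - 1)
    start = max(0, centre - window // 2)
    chunk = audio[start:start + window]
    rms = math.sqrt(sum(x * x for x in chunk) / len(chunk))
    # Digital silence has no finite dB value.
    return 20 * math.log10(rms) if rms > 0 else float("-inf")


# A one-second constant signal at half of full scale, at 44.1 kHz.
signal = [0.5] * 44100
print(duration_seconds(signal, 44100))           # 1.0
print(round(loudness_db_at(signal, 22050), 2))   # -6.02
```

Measuring RMS over a window, rather than reading a single sample value, keeps the result from depending on where the waveform happens to cross zero.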

As with all machine learning problems, to teach the machine to do something, you have to have some sort of training data. In this case, I’m going to use a handful of samples - roughly 20-30 from each instrument - from the tens of thousands of samples I have in my sample collection. When choosing these samples, I want to find:

I put together a collection of these samples - 100 files, roughly 50 megabytes of sample data - in five separate folders: kick, snare, hat, percussion, and other. (Most of these samples are licensed under a Creative Commons Attribution License, so special thanks to waveplay, Seidhepriest, and quartertone for making their samples available for free!)

Now that we’ve got some data to train on, let’s write some code to perform the feature extraction mentioned earlier. These features aren’t super hard to calculate, but they’re not trivial either, so I’ve written the code below to extract them using librosa, a wonderful Python library for audio analysis by Brian McFee et al.

(All of the code in this blog post is available on GitHub - feel free to download it and try running it on your own machine if you’re interested.)

# from
import librosa

# load_and_trim, poorly_estimate_fundamental, average_eq_bands and
# loudness_at are helper functions that aren't shown in this post.

def features_for(file):
    # Load and trim the audio file to only the parts that aren't silent.
    audio, rate = load_and_trim(file)

    # Use poorly_estimate_fundamental to figure out what the rough
    # pitch is, along with the standard deviation - how much it varies.
    fundamental, f_stddev = poorly_estimate_fundamental(audio, rate)

    # Like an equalizer, find out how loud each "frequency band" is.
    # In this case, we're just splitting up the audio spectrum into
    # three very wide sections: low, mid, and high.
    low, mid, high = average_eq_bands(audio, 3)

    return {
        # Newer librosa versions require keyword arguments here.
        "duration":              librosa.get_duration(y=audio, sr=rate),
        # Loudness at the first, middle, and last sample positions
        # (integer indices, hence // rather than /).
        "start_loudness":        loudness_at(audio, 0),
        "mid_loudness":          loudness_at(audio, len(audio) // 2),
        "end_loudness":          loudness_at(audio, len(audio) - 1),
        "fundamental_freq":      fundamental,
        "fundamental_deviation": f_stddev,
        "average_eq_low":        low,
        "average_eq_mid":        mid,
        "average_eq_high":       high,
    }

Now we’ve got a number of features extracted from each sample. We can save these as one large JSON file for use later by our machine learning algorithm. (We haven’t done any learning yet, just figured out the data that we want to learn with.)
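Here is a minimal sketch of that saving step, assuming the five-folder layout described above (kick, snare, hat, percussion, and other) and any function, such as features_for above, that turns a file path into a dict of numbers. The function name extract_all and the JSON record layout are illustrative assumptions, not the format used by this post's own code:

```python
import json
import os

# The five class folders described in this post.
CLASSES = ["kick", "snare", "hat", "percussion", "other"]


def extract_all(root, features_for, out_path):
    """Compute features for every file under root/<class>/ and save them,
    together with their class label, as one JSON file.

    `features_for` maps a file path to a dict of numeric features.
    Returns the list of records that was written.
    """
    records = []
    for label in CLASSES:
        folder = os.path.join(root, label)
        if not os.path.isdir(folder):
            continue  # Skip classes that have no folder.
        for name in sorted(os.listdir(folder)):
            path = os.path.join(folder, name)
            if not os.path.isfile(path):
                continue  # Ignore subdirectories.
            records.append({
                "sample": path,
                "class": label,  # The folder name is the label.
                "features": features_for(path),
            })
    with open(out_path, "w") as f:
        json.dump(records, f, indent=2)
    return records
```

Using the folder name as the label means the only manual labeling work is sorting the files into folders once.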

You can think of these features as measurements we’re taking of the samples, without having to use the entire contents of the samples themselves. (That’s especially true in this case: we started with over 50 megabytes of samples, but the features themselves are only 150 kilobytes, more than 300 times smaller!)

Now, we can take these features and give them to a machine learning algorithm and have it learn from them. But hold on a sec - let’s get specific about which algorithm we’re talking about, and about what learning means in this context.

We’re going to use an algorithm called a decision tree in this post: a commonly used machine learning algorithm that doesn’t involve some of the buzzwords you may have heard, like “neural networks,” “deep learning,” or “artificial intelligence.” A decision tree splits data into categories by recursively learning thresholds: at each step, it picks one feature and a cutoff value that best separate the classes, then repeats the process on each side. (If that’s confusing, don’t worry too much about it - but check out R2D3’s amazing visual example of how decision trees work if you’re curious.)
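To make "learning thresholds" concrete, here is a sketch of the single step a decision tree repeats: try every observed value of every feature as a threshold, and keep the split that leaves the two sides purest. It scores splits with Gini impurity, which is the default criterion of scikit-learn's DecisionTreeClassifier; the function names and the toy feature values below are hypothetical:

```python
from collections import Counter


def gini(labels):
    """Gini impurity of a group of labels: 0.0 when every label is the
    same class, higher the more the classes are mixed."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def best_split(rows, labels):
    """Try every observed value of every feature as a threshold and return
    (feature_index, threshold, weighted_gini) for the purest split.
    `rows` is a list of equal-length lists of numbers."""
    best = (None, None, float("inf"))
    for f in range(len(rows[0])):
        for threshold in sorted({row[f] for row in rows}):
            left = [lab for row, lab in zip(rows, labels) if row[f] <= threshold]
            right = [lab for row, lab in zip(rows, labels) if row[f] > threshold]
            # Weight each side's impurity by how many samples it holds.
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(rows)
            if score < best[2]:
                best = (f, threshold, score)
    return best


# Hypothetical training data: [rough pitch, some other feature] per sample.
rows = [[40.0, 0.1], [55.0, 0.2], [200.0, 0.15], [220.0, 0.3]]
labels = ["kick", "kick", "snare", "snare"]
print(best_split(rows, labels))  # (0, 55.0, 0.0)
```

A full tree applies best_split again to the left and right groups, and so on, until each group is pure or too small to split further. Library implementations add many refinements, such as placing thresholds halfway between neighbouring values.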

# from

from sklearn.tree import DecisionTreeClassifier

# read_data is a helper (not shown in this post) that loads the features
# saved during feature extraction.

def train_and_evaluate_model():
    # First, let's read the features that we got from feature_extract.
    features, classes, sample_names, _, _ = read_data()

    # Let's use this percentage of the data to train, and the rest for
    # testing. Why not just train on all the data? Then we'd have no
    # unseen data left to tell whether the model is overfitted: overly
    # good at the data that it's seen, but poor at data that it hasn't.
    training_percentage = 0.75
    num_training_samples = int(len(features) * training_percentage)

    # Here we separate all of our features and classes into just the
    # ones we want to train on...
    train_features = features[:num_training_samples]
    train_classes = classes[:num_training_samples]

    # ...and we do the training, which creates our model!
    model = DecisionTreeClassifier().fit(train_features, train_classes)

    # (Evaluation on the held-out samples is omitted from this excerpt.)

In this case, DecisionTreeClassifier().fit(...) trains a model by building a decision tree, which is our model, whose split thresholds are learned statistically from the data that we pass in. Again, the specifics aren’t necessary to understand for the rest of this post, but here’s what a similar model looks like when visualized:

Each new sample is passed into this tree, and the features that we provided are evaluated from the top down. For example, if a new sample has average_eq_2_10 ≤ -56.77, as the top block in the diagram shows, the decision tree would move to the left and then check its fundamental_5 feature. It would continue to do so until it reaches the bottom of the tree, or a “leaf” (ha, tree, leaf, get it?), where it would declare that the given sample is whatever class (or colour, in this diagram) the leaf represents.
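The top-down walk described above can be sketched with a hand-built tree of nested dicts. Only the root test (average_eq_2_10 ≤ -56.77, then checking fundamental_5 on the left) comes from the diagram; the fundamental_5 threshold and all of the leaf classes in this toy tree are made up for illustration:

```python
# Each internal node tests one feature against a threshold; each leaf
# holds a class. Everything except the root test is hypothetical.
TREE = {
    "feature": "average_eq_2_10", "threshold": -56.77,
    "left": {
        "feature": "fundamental_5", "threshold": 120.0,  # made-up threshold
        "left": {"leaf": "kick"},
        "right": {"leaf": "snare"},
    },
    "right": {"leaf": "hat"},  # made-up leaf
}


def classify(tree, sample):
    """Walk from the root to a leaf: go left when the sample's feature
    value is <= the node's threshold, right otherwise."""
    node = tree
    while "leaf" not in node:
        if sample[node["feature"]] <= node["threshold"]:
            node = node["left"]
        else:
            node = node["right"]
    return node["leaf"]


print(classify(TREE, {"average_eq_2_10": -60.0, "fundamental_5": 80.0}))  # kick
print(classify(TREE, {"average_eq_2_10": -20.0, "fundamental_5": 80.0}))  # hat
```

Classifying a sample only costs one comparison per level of the tree, which is one reason decision trees are fast to use once trained.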

Now, if we run the training code, we should see two lists: one of the training accuracy (how well the model predicted the class of samples that it saw during training) and one of the test accuracy (how well the model predicted samples that it hadn’t seen before). Our training accuracy is 100%, which is not surprising - that data was used to create the model in the first place! And thanks to the features we selected, of the samples that the model hadn’t seen before, it got most guesses (~87%) correct. This is pretty good for a first try! (If you run this code on your own laptop, you should find that it takes roughly 12 seconds to train on the provided example data.)
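The training code above takes the first 75% of the samples for training. That only gives a fair test set if the samples are already in random order: if they were grouped by folder, the held-out 25% could come almost entirely from one class. Here is a minimal sketch of shuffling before splitting, plus the accuracy measure described here (the fraction of predictions that match the true class), assuming plain Python lists; the names shuffled_split and accuracy are hypothetical:

```python
import random


def shuffled_split(features, classes, training_percentage=0.75, seed=0):
    """Shuffle samples (keeping each feature row with its class), then
    split them into training and test sets. The fixed seed makes the
    split repeatable from run to run."""
    pairs = list(zip(features, classes))
    random.Random(seed).shuffle(pairs)
    n_train = int(len(pairs) * training_percentage)

    def unzip(ps):
        return [f for f, _ in ps], [c for _, c in ps]

    return unzip(pairs[:n_train]), unzip(pairs[n_train:])


def accuracy(predicted, actual):
    """Fraction of predictions that match the true classes."""
    if len(predicted) != len(actual) or not actual:
        raise ValueError("need two non-empty lists of equal length")
    return sum(p == a for p, a in zip(predicted, actual)) / len(actual)


# Eight hypothetical samples, sorted by class as they might be on disk.
features = [[i] for i in range(8)]
classes = ["kick"] * 4 + ["snare"] * 4
(train_f, train_c), (test_f, test_c) = shuffled_split(features, classes)
print(len(train_f), len(test_f))  # 6 2
```

With a trained scikit-learn model, accuracy(model.predict(test_f), test_c) gives the test accuracy, and model.score(test_f, test_c) computes the same mean accuracy directly. scikit-learn's train_test_split also shuffles by default, and its stratify argument keeps the class proportions equal in both sets.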


Our 87% performance is decent, but the gap between 100% training accuracy and 87% test accuracy suggests what’s called overfitting - our model has been trained to be overly specific and completely accurate for data that it’s seen before, but it has trouble when it sees data that’s new to it. In some sense, this is similar to how humans learn; when someone sees something new that they hadn’t seen in school or heard about before, they’re bound to make mistakes.

To avoid overfitting our model, we could take a number of approaches:

All three of these are valid approaches, and I’ll leave them for the reader to investigate. We could also try other classification methods instead of a decision tree, although, surprisingly, a naïve decision tree works pretty well for this problem.
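One common way to rein in an overfitted decision tree is to limit how deep it can grow, and to judge each setting with cross-validation rather than a single 75/25 split, since a test set of about 25 samples is small. The sketch below uses scikit-learn's max_depth parameter and cross_val_score function, but runs on random stand-in data because the post's feature file isn't reproduced here; the depths tried are arbitrary:

```python
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

# Random stand-in data: 100 samples with 9 features each, matching the
# shape of this post's data set, and five balanced classes. Real
# features would come from the extracted JSON instead.
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 9))
y = np.repeat(["kick", "snare", "hat", "percussion", "other"], 20)

for max_depth in [2, 4, 8, None]:  # None lets the tree grow fully
    model = DecisionTreeClassifier(max_depth=max_depth, random_state=0)
    # 5-fold cross-validation: train on 80 samples, test on the other
    # 20, five times over, so every sample is tested exactly once.
    scores = cross_val_score(model, X, y, cv=5)
    print(max_depth, round(scores.mean(), 3))
```

On random data the printed accuracies hover around chance and mean nothing; with real features, you would keep the depth with the best mean cross-validated score.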

So! We’ve built a machine learning classifier for drum samples. That’s kinda cool. There are a couple things to note about this system:

If you’ve got your own sample library, or want to give this problem a try with samples you’ve found online, go for it! All of the code from this blog post is available on GitHub, and you can pop in your own sample packs and have fun. Some other things to try:

Special thanks to Jamie Wong, Zameer Manji, Isaac Ezer, and Mark Koh for their proofreading and feedback on this post.