Secure PyTorch Models with eBPF

Avi Lumelsky · Published in InfoSec Write-ups · 10 min read · Jul 23, 2023


This article was not generated by GPT

In this blog, I will present secimport — a toolkit for creating and running sandboxed applications in Python that utilizes eBPF (bpftrace) to secure Python runtimes.

I will start with why it is needed (feel free to skip that part),
and then demonstrate how to run PyTorch models securely.


In part 1 of the series, I introduced OS and application tracing and sandboxing for Python. I wrote about a minimal working solution (an MVP) built with dtrace that secures Python runtimes all the way down to the syscall level.
For an in-depth explanation of the existing sandbox solutions, check it out!

Table Of Contents:

  • Evaluating Insecure Code
  • A word about the pickle protocol and Supply Chain Attacks
  • PyTorch sandbox example
  • Prevent PyTorch code execution with secimport
  • Conclusion

Evaluating insecure code

In today’s software development landscape, adding a new library to our codebase carries real risk. We rarely know what a package is expected to do in order to function, and an imported package can manipulate our environment without our knowledge.

If you look at HuggingFace, repositories often store the PyTorch definition of the model, which is Python code. Someone must have reviewed it, you think… Isn’t it considered secure? It has enough stars…

We rely on stars as a credibility metric, and that should change.
We star repositories to bookmark them; few of us really intend to contribute or dive into the code, and many of us use them without thinking twice. We use someone else’s code without reviewing it.
I think it is just a matter of time before major security incidents take place.

Since LLMs became a thing, 2K stars in a few days is not rare, and keeping track of all these projects, let alone reviewing their code, is impossible.
Moreover, faking stars is also easy these days: give your project a good README (written with a GPT), and someone out there will probably try your code.

Security measures should be enforced inside Python’s runtime.
One example is the pickle protocol. Many major frameworks (and Python’s multiprocessing) rely on pickle as a building block.
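To illustrate how deep pickle runs, here is a minimal sketch using only the standard library: multiprocessing has to pickle whatever crosses the process boundary, so task arguments and results travel as pickle payloads, and anything that cannot be pickled cannot even be shipped to a worker.

import pickle
from multiprocessing import Pool

def square(x):
    return x * x

if __name__ == "__main__":
    with Pool(2) as pool:
        # Task arguments and results cross the process boundary as pickle payloads.
        print(pool.map(square, [1, 2, 3]))  # [1, 4, 9]

    # This is also why unpicklable objects (such as a lambda) cannot be sent to workers:
    try:
        pickle.dumps(lambda x: x * x)
    except Exception as e:
        print("not picklable:", e)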

Why is pickle an issue?

It’s vulnerable by design, that’s why:

import pickle

class Demo:
    def __reduce__(self):
        return (eval, ("__import__('os').system('echo Exploited!')",))


In: pickle.dumps(Demo())
Out: b"\x80\x04\x95F\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x04eval\x94\x93\x94\x8c*__import__('os').system('echo Exploited!')\x94\x85\x94R\x94."

In another terminal or environment, loading someone else’s pickled code will result in — you're right — an exploit:

import pickle
pickle.loads(b"\x80\x04\x95F\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x04eval\x94\x93\x94\x8c*__import__('os').system('echo Exploited!')\x94\x85\x94R\x94.")
Exploited!
0
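pickle itself offers only a partial mitigation. The standard library documentation suggests restricting which globals an Unpickler may resolve by overriding find_class; a minimal sketch of that idea looks like this. It is an application-level allowlist that blocks the payload above, but it is easy to get wrong, which is why runtime confinement is still worth having.

import importlib
import io
import pickle

# Example allowlist: the only globals that may be resolved during unpickling.
ALLOWED = {("builtins", "range"), ("builtins", "set")}

class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        if (module, name) in ALLOWED:
            return getattr(importlib.import_module(module), name)
        raise pickle.UnpicklingError(f"global '{module}.{name}' is forbidden")

def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()

# The payload above needs builtins.eval, which is not in the allowlist,
# so restricted_loads(...) raises UnpicklingError instead of executing it.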

Let’s see secimport blocking this pickle exploit with eBPF:

Regular Python (left) vs Secimport (right).

In the figure above, secimport blocked the pickle exploit because we defined a policy in advance.
We ran the Python process using “secimport run”, which runs the process under eBPF supervision in real time.


Let’s look at PyTorch as an example. There is an official warning in PyTorch’s package documentation about loading models from untrusted sources, which I am sure many of you have missed if you ever worked with PyTorch.

PyTorch models are very easy to exploit. For example, this blog shows how any torch model can be patched with a pickle-based exploit, very similar to the example above but more sophisticated, and one that passes static code security scans.

We should strive to avoid code execution that is not expected.
You can either use the CLI to trace the logic of the pickled file you allow,
or you can use the secimport Python API to make sure “pickle” is not running odd syscalls when loading a specific file:

import secimport
pickle = secimport.secure_import("pickle")
pickle.loads(b"\x80\x04\x95F\x00\x00\x00\x00\x00\x00\x00\x8c\x08builtins\x94\x8c\x04eval\x94\x93\x94\x8c*__import__('os').system('echo Exploited!')\x94\x85\x94R\x94.")

[1] 28027 killed ipython

After running this code, a log file is automatically created, containing everything you need to know about the process and why it was killed:

$ less /tmp/.secimport/sandbox_pickle.log

@posix_spawn from /Users/avilumelsky/Downloads/Python-3.10.0/Lib/threading.py
DETECTED SHELL:
depth=8
sandboxed_depth=0
sandboxed_module=/Users/avilumelsky/Downloads/Python-3.10.0/Lib/pickle.py

TERMINATING SHELL:
libsystem_kernel.dylib`__posix_spawn+0xa
...
libsystem_kernel.dylib`__posix_spawn+0xa
libsystem_c.dylib`system+0x18b
python.exe`os_system+0xb3
KILLED
:

I hope you understand the problem: many industry projects rely on this insecure format. Either way, the behavior of a program, even an AI model, should always be known in advance.

Supply Chain Attacks in AI

There are many ways to exploit Python users with minimal effort:

  • PyTorch itself has numerous design problems.
  • Malicious payloads delivered through image files.
  • Typosquatted package names.
  • HuggingFace, AutoGPT, and similar SaaS AI companies work in this way (they rely on pickle and on frameworks that are insecure by design).

Models are usually open source, and we need them to be portable.
PyTorch nn.Module instances (models) are managed through code.
Using someone else’s model therefore means loading unreviewed, potentially unsafe code into your private environment.
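This is also why the common advice is to exchange weights rather than pickled module objects. A minimal sketch of that pattern follows; TinyModel is just a stand-in for an architecture defined in your own, reviewed code, and recent torch versions also accept weights_only=True so torch.load refuses payloads that rebuild arbitrary Python objects.

import torch

class TinyModel(torch.nn.Module):  # stand-in for any model class in your own, reviewed code
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

# Save only the tensors (the state_dict), not a pickled module object.
model = TinyModel()
torch.save(model.state_dict(), "weights.pt")

# Load: rebuild the architecture from your own code, then load the tensors.
fresh = TinyModel()
fresh.load_state_dict(torch.load("weights.pt", weights_only=True))

The file is still a pickle under the hood, but combined with weights_only=True the loader refuses anything that is not plain tensors and containers.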

PyTorch Sandbox Example

Assuming you understand by now that arbitrary code can run pretty easily in Python, let’s try to secure a given PyTorch model.

1. We will run an example from PyTorch’s documentation, with and without a sandbox.
2. We will add a malicious line to the code.
3. The sandbox will log the violation, and then block the activity before it occurs (IDS mode vs. IPS mode).

We will use an example from PyTorch’s documentation (torch==2.0.1):

# -*- coding: utf-8 -*-
import time
import torch
import math


class Polynomial3(torch.nn.Module):
    def __init__(self):
        """
        In the constructor we instantiate four parameters and assign them as
        member parameters.
        """
        super().__init__()
        self.a = torch.nn.Parameter(torch.randn(()))
        self.b = torch.nn.Parameter(torch.randn(()))
        self.c = torch.nn.Parameter(torch.randn(()))
        self.d = torch.nn.Parameter(torch.randn(()))

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        # import os; os.system('ps')
        return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

    def string(self):
        """
        Just like any class in Python, you can also define custom methods on PyTorch modules
        """
        return f'y = {self.a.item()} + {self.b.item()} x + {self.c.item()} x^2 + {self.d.item()} x^3'


start_time = time.time()

# Create Tensors to hold input and outputs.
x = torch.linspace(-math.pi, math.pi, 2000)
y = torch.sin(x)

# Construct our model by instantiating the class defined above
model = Polynomial3()

# Construct our loss function and an Optimizer. The call to model.parameters()
# in the SGD constructor will contain the learnable parameters (defined
# with torch.nn.Parameter) which are members of the model.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-6)

for t in range(2000):
    # Forward pass: Compute predicted y by passing x to the model
    y_pred = model(x)

    # Compute and print loss
    loss = criterion(y_pred, y)
    if t % 100 == 99:
        print(t, loss.item())

    # Zero gradients, perform a backward pass, and update the weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f'Result: {model.string()}')
print("--- %s seconds ---" % (time.time() - start_time))

Run the code:

root@3ecd9c9b5613:/workspace# Python-3.10.0/python -m pip install torch
root@3ecd9c9b5613:/workspace# Python-3.10.0/python pytorch_example.py
99 674.6323852539062
...
1999 9.19102668762207
Result: y = -0.013432367704808712 + 0.8425596952438354 x + 0.0023173068184405565 x^2 + -0.09131323546171188 x^3
--- 0.6940326690673828 seconds ---

The code ran just fine.

Create a tailor-made sandbox for that code.

Now, I want to create a security policy for that code, so nothing but that code can run. This is done with the secimport trace command to trace the code, and secimport build to build a sandbox from the trace.

You can read how it works here or here, but keep on reading.

root@ec15bafca930:/workspace/examples/cli/ebpf/torch_demo# secimport trace --entrypoint pytorch_example.py 
>>> secimport trace

TRACING: ['/root/.local/lib/python3.10/site-packages/secimport/profiles/trace.bt', '-c', 'bash -c "/workspace/Python-3.10.0/python pytorch_example.py"', '-o', 'trace.log']

Press CTRL+D/CTRL+C to stop the trace;

/workspace/examples/cli/ebpf/torch_demo/pytorch_example.py:36: UserWarning: Failed to initialize NumPy: No module named 'numpy' (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:84.)
x = torch.linspace(-math.pi, math.pi, 2000)
...
...
Result: y = -0.04786265641450882 + 0.8422093987464905 x + 0.008257105946540833 x^2 + -0.09126341342926025 x^3
--- 1.5200915336608887 seconds ---


TRACING DONE;

Let’s build a sandbox from the traced code.

It will create a mapping of the syscalls each module performed while the code ran under the trace. Any new syscall in a new place, or a code change that adds logic, will be treated by secimport as a “violation”.
In the following example, secimport builds a YAML/JSON policy for your code by analyzing the trace.
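Conceptually, the result is an allowlist of syscalls per Python module, and anything outside of it is a violation. A toy sketch of the idea (illustration only, not secimport’s actual policy format, which lives in the generated sandbox.yaml/sandbox.json):

# Illustration only: a toy per-module syscall allowlist.
policy = {
    "general_requirements": {"write", "writev", "mmap", "brk"},
    "site-packages/torch/nn/": {"clock_gettime", "write"},
    "pytorch_example.py": {"read", "write", "clock_gettime"},
}

def is_violation(module: str, syscall: str) -> bool:
    return syscall not in policy.get(module, set())

print(is_violation("pytorch_example.py", "clock_gettime"))  # False: seen during the trace
print(is_violation("pytorch_example.py", "execve"))          # True: never seen in the trace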

root@ec15bafca930:/workspace/examples/cli/ebpf/torch_demo# secimport build 
>>> secimport build

SECIMPORT COMPILING...

CREATED JSON TEMPLATE: sandbox.json
CREATED YAML TEMPLATE: sandbox.yaml

compiling template sandbox.yaml

[debug] adding syscall write to blocklist for module general_requirements
[debug] adding syscall writev to blocklist for module general_requirements
...
[debug] adding syscall stat to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/py.py
[debug] adding syscall clock_gettime to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/py.py
[debug] adding syscall exit_group to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/py.py
...
[debug] adding syscall mmap to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/torch/ao/
[debug] adding syscall brk to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/torch/aut
[debug] adding syscall close to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/torch/aut
...
[debug] adding syscall mmap to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/torch/jit
[debug] adding syscall brk to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/torch/lib
[debug] adding syscall clock_gettime to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/torch/nn/
[debug] adding syscall write to allowlist for module /workspace/Python-3.10.0/lib/python3.10/site-packages/torch/nn/
...
DTRACE SANDBOX: sandbox.d
BPFTRCE SANDBOX: sandbox.bt

SANDBOX READY: sandbox.bt

Now, let’s run the original code in a sandbox:


root@3ecd9c9b5613:/workspace# secimport run --entrypoint pytorch_example.py

99 3723.3251953125
...
1999 11.318828582763672
Result: y = -0.04061822220683098 + 0.8255564570426941 x + 0.007007318548858166 x^2 + -0.08889468014240265 x^3
--- 0.8806719779968262 seconds ---

SANDBOX EXITED;

Nice! It ran inside the sandbox without any errors, as expected.
Now, we all want to see the sandbox in action — let’s modify the code to do something new.

Blocking code execution with secimport

Now, let’s uncomment the “os.system” call and see if secimport recognizes that change. The call could also be obfuscated, or use the ‘subprocess’ module instead to run commands, but since we monitor at the syscall level, we don’t care: our eBPF sandbox should see everything.
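For instance, all of the following variants look different to a static scanner, yet each one spawns a child process and therefore makes the same kind of process-creation syscalls the sandbox watches (a small illustration, not part of the traced demo):

import os
import subprocess

os.system("ps")                             # the plain call used in this article
getattr(__import__("os"), "system")("ps")   # an obfuscated lookup of the same function
subprocess.run(["ps"])                      # a different Python API, same kernel-level behavior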

We will use the same sandbox from the previous step. This time, the program will execute the “ps” command in the model’s forward().

Let’s change this:

def forward(self, x):
    return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

To be:

def forward(self, x):
    import os; os.system('ps')
    return self.a + self.b * x + self.c * x ** 2 + self.d * x ** 3

Secimport can detect that the os package is being imported, although the import itself does nothing on the machine. It is the “os.system” call that secimport catches, because it results in syscalls.

In the Linux environment used here, os.system shows up as syscalls 56 and 61 (clone and wait4); on other setups, a single syscall 59 (execve) is used instead. On macOS, the same pattern appeared earlier as posix_spawn in the pickle log.
Other Python process-spawning mechanisms (fork/spawn) will call other syscalls, but the behavior is always the same for the same interpreter on the same OS.
At the end of the day, for the kernel, everything is syscalls.
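If you want to translate the syscall numbers that appear in the violation logs yourself, a tiny helper like this is enough (the numbers below are the x86-64 Linux ones relevant to this demo; other architectures use different tables):

# x86-64 Linux syscall numbers that appear in this article's logs.
SYSCALL_NAMES = {
    56: "clone",   # process/thread creation (what os.system triggers here)
    59: "execve",  # direct program execution
    61: "wait4",   # the parent waiting for its child to exit
}

def describe(syscall_number: int) -> str:
    return SYSCALL_NAMES.get(syscall_number, f"unknown ({syscall_number})")

print(describe(56), describe(61))  # clone wait4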


Let’s run the code once again after we injected an arbitrary command.
We should expect the sandbox to log the violation.

root@3ecd9c9b5613:/workspace# secimport run --entrypoint pytorch_example.py
>>> secimport run
...

[SECURITY PROFILE VIOLATED]: /workspace/examples/cli/ebpf/torch_demo/pytorch_example.py called syscall 56 at depth 561032
[SECURITY PROFILE VIOLATED]: /workspace/examples/cli/ebpf/torch_demo/pytorch_example.py called syscall 61 at depth 561032
PID TTY TIME CMD
1 pts/0 00:00:00 sh
11 pts/0 00:00:00 bash
4279 pts/0 00:00:00 python
4280 pts/0 00:00:00 sh
4281 pts/0 00:00:06 bpftrace
4285 pts/0 00:00:06 python
8289 pts/0 00:00:00 sh
8290 pts/0 00:00:00 ps
1999 9.100260734558105
Result: y = 0.017583630979061127 + 0.8593376278877258 x + -0.003033468732610345 x^2 + -0.09369975328445435 x^3

That’s awesome! secimport logged 2 violations BEFORE the ps command actually ran:

[SECURITY PROFILE VIOLATED]: /workspace/examples/cli/ebpf/torch_demo/pytorch_example.py called syscall 56 at depth 561032
[SECURITY PROFILE VIOLATED]: /workspace/examples/cli/ebpf/torch_demo/pytorch_example.py called syscall 61 at depth 561032
  1. Syscall no. 56 (clone)
  2. Syscall no. 61 (sys_wait4)

On other machines, Python used syscall no. 59 (execve) alone instead of 56 followed by 61. Which syscalls are called depends on the OS and may vary, yet on the same OS and interpreter, secimport is always consistent.

More than detection — How to prevent code execution?

In the case above, secimport only logged the policy violations.
Preventing code execution is easy with these two flags:

  • secimport run … --stop_on_violation
[SECURITY PROFILE VIOLATED]: <stdin> called syscall 56 at depth 8022
^^^ STOPPING PROCESS 85918 DUE TO SYSCALL VIOLATION ^^^
  • secimport run … --kill_on_violation
[SECURITY PROFILE VIOLATED]: <stdin> called syscall 56 at depth 8022
^^^ KILLING PROCESS 86466 DUE TO SYSCALL VIOLATION ^^^
KILLED.

By simply adding one of these flags, you can block unexpected code in your production runtime before it runs. This is the strongest kind of protection one can expect.

You would not use these flags in every project; you might want to only log by default. But this kill/stop behavior can solve many problems and gives security teams the ability to restrict 3rd-party code in Python’s runtime.

Conclusion

In this blog post, we went through secimport usage and saw how to secure PyTorch model runtimes all the way down to the kernel’s syscalls, per module, using secimport’s CLI.

Secimport enables Python users to confine different modules in their code with different privileges and rules.
I encourage you to try secimport for your use case.

Thank you for reading this far.
If you like my work, let me know.
You are welcome to contact me or comment if you have any questions or ideas, and of course, I welcome you to star it on GitHub and contribute!


By the way, I am doing this in my spare time. I also really love coffee!


