MNIST Digit Classification
In this tutorial we will use a model trained on the MNIST dataset of handwritten digits to predict the number that the user draws.
There are several pieces to this tutorial, so please follow each step carefully. If you get lost, completed examples of each step can be found here.
If you haven't installed the PyTorch Live CLI yet, please follow this tutorial to get started.
Create a new React Native project
We will start by creating a new React Native project with the PyTorch Live (PTL) template using the CLI. Run the following command:
npx torchlive-cli init MNISTClassifier
Once that is done, let's go into our newly created project and run it!
cd MNISTClassifier
Android:
npx torchlive-cli run-android
iOS (Simulator):
npx torchlive-cli run-ios
Adding Basic UI
The aim of this tutorial is to help you become more familiar with PTL core components. We will not spend time on styling the UI; instead, we provide the layout and styles from the start.
Go ahead and start by copying the following code into the file src/demos/MyDemos.tsx:
import React, {useState} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>Let's test the MNIST model</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
/>
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
Now you should see a UI that looks like the screenshot below.
Android:
npx torchlive-cli run-android
iOS (Simulator):
npx torchlive-cli run-ios
Before we add more code, let's take a second to discuss some of what the above code does.
The PyTorch Live Canvas Component
We'll be using the PTL canvas in this tutorial to let the user draw numbers that we will try to classify.
Just like the name suggests, a canvas is a surface that we can programmatically draw on.
In order to draw things on a canvas, we use what is called the canvas context, the ctx state variable in this case.
Note that we haven't used the context to draw anything yet, so our canvas is essentially invisible.
...
export default function MNISTDemo() {
...
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
...
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
/>
...
The onLayout Prop
In our code, we use the onLayout prop on the container view to get the dimensions of the screen space we are working with.
Once we have the dimensions, we take the smaller of the width and height. We then use that value as both the width and the height of our canvas.
This makes sure that our canvas is square and fits within the bounds of our screen in both portrait and landscape.
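To make the sizing rule concrete, here is a minimal sketch of the same calculation as a plain TypeScript function. The squareCanvasSize name and the LayoutLike type are hypothetical, and the example dimensions are made up. In the component, the layout object comes from event.nativeEvent in the onLayout handler.

```typescript
// Minimal shape of the layout object from onLayout's event.nativeEvent.layout.
// Only width and height matter here; this type is an assumption for illustration.
type LayoutLike = {width?: number; height?: number};

// Same rule as the onLayout handler: the canvas side is the smaller of the
// available width and height, falling back to 0 when a value is missing.
function squareCanvasSize(layout?: LayoutLike): number {
  return Math.min(layout?.width || 0, layout?.height || 0);
}

console.log(squareCanvasSize({width: 390, height: 844})); // portrait: 390
console.log(squareCanvasSize({width: 844, height: 390})); // landscape: 390
console.log(squareCanvasSize(undefined)); // no layout yet: 0
```

The `|| 0` fallback means the canvas starts at size 0 until the first layout event arrives.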
...
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
...
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
...
Results placeholders
Note that for now we just have placeholder text where we will put our model results. Later on, after we run the model, we will update the text there to display the results.
...
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
...
Filling the Canvas
Like we mentioned in the previous section, our canvas is currently completely blank.
Let's change that and make a clear surface for users to draw on.
Here's a short summary of the changes we're introducing:
- Import useCallback and useEffect from React.
- Define a color for our canvas background (COLOR_CANVAS_BACKGROUND). We'll use a lighter purple to distinguish the canvas from the rest of the screen.
- Create a draw function that will fill in our background. We create it with useCallback so that a new draw function is created only when the canvas context or the canvas size changes. The draw function:
  - Checks that the context is not null, so we have something to draw with.
  - Sets the context's fill style to our canvas background purple (essentially choosing which marker to work with).
  - Fills a rectangle that starts at the origin (0,0) of our canvas (the top left corner) and ends at the bottom right corner, so it covers the whole canvas.
  - Calls the invalidate function on our canvas context to let the screen know that we have drawn new things for it to show.
- Trigger draw whenever it changes with the useEffect block. Remember that draw changes every time the canvas context or size changes, so essentially this useEffect runs every time the canvas changes.
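The useCallback and useEffect hooks we just imported are examples of React Hooks, as is the useState function we already had. Hooks let React function components, like our MNISTDemo component, remember values between renders and respond to changes.
You'll notice that useCallback and useEffect each take an array as their last argument: [ctx, canvasSize] for useCallback and [draw] for useEffect. This array is the hook's list of "dependencies". useCallback holds on to the function we give it and returns that same function on every render until one of the dependencies changes. At that point, it remembers the new function instead. useEffect works in a similar way: it runs its effect after the first render, and after that only on renders in which one of its dependencies has changed.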
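The sketch below is a rough illustration of what the dependency list does. It models useCallback's caching in plain TypeScript so it can run outside React. It is a toy, not React's implementation, and depsChanged and makeCallbackMemo are hypothetical names. The only behavior it borrows from React is comparing each dependency with Object.is, as the React docs describe.

```typescript
// Hypothetical helper: reports whether any dependency changed. React's docs
// describe comparing each dependency to its previous value with Object.is.
function depsChanged(
  prev: readonly unknown[] | undefined,
  next: readonly unknown[],
): boolean {
  if (prev === undefined || prev.length !== next.length) {
    return true; // first render, or the array's length changed
  }
  for (let i = 0; i < next.length; i++) {
    if (!Object.is(next[i], prev[i])) {
      return true;
    }
  }
  return false;
}

// Toy model of useCallback (not React's real implementation): keep returning
// the first function until a dependency changes, then remember the new one.
function makeCallbackMemo<F extends (...args: never[]) => unknown>() {
  let cachedFn: F | undefined;
  let cachedDeps: readonly unknown[] | undefined;
  return (fn: F, deps: readonly unknown[]): F => {
    if (cachedFn === undefined || depsChanged(cachedDeps, deps)) {
      cachedFn = fn;
      cachedDeps = deps;
      return fn;
    }
    return cachedFn;
  };
}

// Simulate three renders of a component whose draw depends on [ctx, canvasSize].
const memoizeDraw = makeCallbackMemo<() => void>();
const ctx = {}; // stands in for the canvas context object
const first = memoizeDraw(() => {}, [ctx, 390]);
const second = memoizeDraw(() => {}, [ctx, 390]); // same deps: old function kept
const third = memoizeDraw(() => {}, [ctx, 844]); // canvasSize changed: new function
console.log(first === second, second === third); // true false
```

On the second call, ctx is the same object and 390 equals 390, so the cached function is returned. Changing canvasSize produces a new function. In the tutorial, that is what makes the useEffect that depends on draw run again.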
For more information on React Hooks, head over to the React docs where you can read or watch explanations.
The changes are shown first as a diff, followed by the entire updated file:
@@ -1,8 +1,10 @@
-import React, {useState} from 'react';
+import React, {useCallback, useEffect, useState} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
+const COLOR_CANVAS_BACKGROUND = '#4F25C6';
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
@@ -10,6 +12,20 @@
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
+ const draw = useCallback(() => {
+ if (ctx != null) {
+ // fill background by drawing a rect
+ ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+ ctx.fillRect(0, 0, canvasSize, canvasSize);
+
+ ctx.invalidate();
+ }
+ }, [ctx, canvasSize]);
+
+ useEffect(() => {
+ draw();
+ }, [draw]);
+
return (
<View
style={styles.container}
import React, {useCallback, useEffect, useState} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const draw = useCallback(() => {
if (ctx != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
ctx.invalidate();
}
}, [ctx, canvasSize]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>Let's test the MNIST model</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
/>
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
Once you run your app, the My Demos screen should now look like this.
Android:
npx torchlive-cli run-android
iOS (Simulator):
npx torchlive-cli run-ios
I know that was a lot of new stuff to simply paint our canvas light purple, but it provides us with a good foundation for when we draw more on our canvas.
Drawing with Touch Input
Now that we have a clear area for the user to draw on, let's make it so they can draw!
Let's go over what we will change to make drawing possible:
- Import useRef from React.
- Define a color for the trail of the user's touch (COLOR_TRAIL_STROKE). We'll use white to make it stand out.
- Define a TrailPoint type to keep our data type-safe, error free, and easy to use.
- Create a ref to an array of TrailPoints called trailRef and initialize it to an empty array.
- Keep track of whether the user has finished drawing with the drawingDone state variable, initialized to false.
- Add support for drawing the trail to our draw function:
  - Create a variable called trail and set it to the current value of trailRef. This is purely so we don't have to write trailRef.current every time we need the trail.
  - Check to make sure the trail isn't null.
  - Draw our background to cover anything previously drawn.
  - Check to make sure our trail has at least 1 point.
  - Set the context's strokeStyle. You can think of it as picking the color of the marker we'll draw lines with.
  - Set the context's line drawing style parameters (lineWidth, lineJoin, lineCap, and miterLimit).
  - Tell the context to start a new path with beginPath and move to the first point in the trail with moveTo.
  - Loop through the remaining points of the trail, adding each one to the line with lineTo.
  - Tell the context via the stroke method to actually draw the line that we constructed.
  - Use the invalidate method to tell the screen we have updates ready to draw.
- Create functions for handling when a user touches the canvas (handleStart, handleMove, and handleEnd).
  - handleStart is called when the user first touches the canvas. It is a simple function that does the following:
    - Set the drawingDone variable to false.
    - Reset trailRef to an empty array.
  - handleMove is called each time the device detects that the touch has changed position.
    - Get the coordinates of the new touch location and store them in the position variable.
    - If there are already points in the trail, only add the new position if it's more than 5 pixels away from the last point. This avoids keeping unnecessary points that slow down the app.
    - If there are no points in the trail, add the new position.
    - Trigger the draw function to display the newly updated trail.
  - handleEnd is called when the user's touch is no longer detected on the screen. It simply sets the drawingDone state variable to true.
- Set the onTouchStart, onTouchMove, and onTouchEnd props on our <Canvas /> component to handleStart, handleMove, and handleEnd respectively.
- Show the results placeholders only once drawingDone is true, and change the second instruction line to "Let's see if the AI model will get it right".
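The point-filtering rule in handleMove can be pulled out into a small standalone function. This makes it easier to see why the code compares dx * dx + dy * dy against 25: that is the squared form of "more than 5 pixels away", and it avoids a square root on every touch event. This is a sketch only; addTrailPoint and its minDistance parameter are hypothetical names, not part of PyTorch Live.

```typescript
type TrailPoint = {x: number; y: number};

// Mirrors the logic in handleMove: always keep the first point, then keep a new
// point only if it is more than `minDistance` pixels from the last kept point.
// Comparing squared distances avoids calling Math.sqrt on every touch event.
function addTrailPoint(
  trail: TrailPoint[],
  position: TrailPoint,
  minDistance = 5,
): boolean {
  const last = trail[trail.length - 1];
  if (last !== undefined) {
    const dx = position.x - last.x;
    const dy = position.y - last.y;
    if (dx * dx + dy * dy <= minDistance * minDistance) {
      return false; // too close to the previous point; skip it
    }
  }
  trail.push(position);
  return true;
}

const trail: TrailPoint[] = [];
addTrailPoint(trail, {x: 10, y: 10}); // first point is always kept
addTrailPoint(trail, {x: 13, y: 14}); // exactly 5px away: skipped (needs > 5)
addTrailPoint(trail, {x: 16, y: 10}); // 6px away: kept
console.log(trail.length); // 2
```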
The changes are shown first as a diff, followed by the entire updated file:
@@ -1,26 +1,88 @@
-import React, {useCallback, useEffect, useState} from 'react';
+import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
+const COLOR_TRAIL_STROKE = '#FFFFFF';
+
+type TrailPoint = {
+ x: number;
+ y: number;
+};
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
+
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
+ const trailRef = useRef<TrailPoint[]>([]);
+ const [drawingDone, setDrawingDone] = useState(false);
+
const draw = useCallback(() => {
if (ctx != null) {
- // fill background by drawing a rect
- ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
- ctx.fillRect(0, 0, canvasSize, canvasSize);
+ const trail = trailRef.current;
+ if (trail != null) {
+ // fill background by drawing a rect
+ ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+ ctx.fillRect(0, 0, canvasSize, canvasSize);
+
+ // Draw the trail
+
+ if (trail.length > 0) {
+ ctx.strokeStyle = COLOR_TRAIL_STROKE;
+ ctx.lineWidth = 25;
+ ctx.lineJoin = 'round';
+ ctx.lineCap = 'round';
+ ctx.miterLimit = 1;
+ ctx.beginPath();
+ ctx.moveTo(trail[0].x, trail[0].y);
+ for (let i = 1; i < trail.length; i++) {
+ ctx.lineTo(trail[i].x, trail[i].y);
+ }
+ ctx.stroke();
+ }
- ctx.invalidate();
+ ctx.invalidate();
+ }
}
- }, [ctx, canvasSize]);
+ }, [ctx, canvasSize, trailRef]);
+
+ // handlers for touch events
+ const handleMove = useCallback(
+ async event => {
+ const position: TrailPoint = {
+ x: event.nativeEvent.locationX,
+ y: event.nativeEvent.locationY,
+ };
+ const trail = trailRef.current;
+ if (trail.length > 0) {
+ const lastPosition = trail[trail.length - 1];
+ const dx = position.x - lastPosition.x;
+ const dy = position.y - lastPosition.y;
+ // add a point to trail if distance from last point > 5
+ if (dx * dx + dy * dy > 25) {
+ trail.push(position);
+ }
+ } else {
+ trail.push(position);
+ }
+ draw();
+ },
+ [trailRef, draw],
+ );
+
+ const handleStart = useCallback(() => {
+ setDrawingDone(false);
+ trailRef.current = [];
+ }, [trailRef, setDrawingDone]);
+
+ const handleEnd = useCallback(() => {
+ setDrawingDone(true);
+ }, [setDrawingDone]);
useEffect(() => {
draw();
@@ -35,7 +97,9 @@
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
- <Text style={styles.label}>Let's test the MNIST model</Text>
+ <Text style={styles.label}>
+ Let's see if the AI model will get it right
+ </Text>
</View>
<Canvas
style={{
@@ -43,15 +107,20 @@
width: canvasSize,
}}
onContext2D={setCtx}
+ onTouchMove={handleMove}
+ onTouchStart={handleStart}
+ onTouchEnd={handleEnd}
/>
- <View style={[styles.resultView]} pointerEvents="none">
- <Text style={[styles.label, styles.secondary]}>
- Highest confidence will go here
- </Text>
- <Text style={[styles.label, styles.secondary]}>
- Second highest will go here
- </Text>
- </View>
+ {drawingDone && (
+ <View style={[styles.resultView]} pointerEvents="none">
+ <Text style={[styles.label, styles.secondary]}>
+ Highest confidence will go here
+ </Text>
+ <Text style={[styles.label, styles.secondary]}>
+ Second highest will go here
+ </Text>
+ </View>
+ )}
</View>
);
}
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const draw = useCallback(() => {
if (ctx != null) {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
if (trail.length > 0) {
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
ctx.stroke();
}
ctx.invalidate();
}
}
}, [ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
Run this code, and you should now be able to draw on the canvas, as shown in the video below.
As you may notice, the drawing glitches at times, especially as the trail gets longer. Let's fix that next.
React Refs
A ref in React stores a value across renders, much like state, but changing it does not cause the component to re-render.
You can get or set the value of a ref via the .current property.
In our code, we access the trail with trailRef.current. We set the trail in our handleStart function to an empty list with trailRef.current = [].
Avoiding Excessive Re-rendering
The glitches happen because we are asking the screen to refresh before it is ready. handleMove calls draw for every touch movement, and each call immediately redraws the whole canvas.
Mobile screens typically refresh 60 times per second (though some newer phones refresh twice as often). When we display things with React, it takes care of matching our device's refresh rate.
React renders the <Canvas /> component itself, but what we draw on the canvas is up to us. Luckily, there is a simple way to make sure we don't redraw too often.
To address this, we will make a few updates to our code, mainly in the draw function:
- Create a ref called animationHandleRef that can be a number or null, and initialize it to null. We will use this ref to check whether a redraw is already scheduled.
- Use animationHandleRef in the draw function to control how often we rerender:
  - Start the function by checking whether animationHandleRef is set to a non-null value. If it is, return early, because a redraw is already scheduled for the next frame. That redraw reads trailRef.current when it runs, so it will still include the newest points.
  - Wrap our drawing code in an inline function that we pass to requestAnimationFrame, and set animationHandleRef's value to what it returns. (Read more about this function in the note following the code.)
  - After telling our canvas we are ready for a rerender with ctx.invalidate(), clear animationHandleRef by setting its value to null.
  - Add animationHandleRef to the draw function's callback dependencies list.
The changes are shown first as a diff, followed by the entire updated file:
@@ -21,35 +21,40 @@
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
+ const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
+ if (animationHandleRef.current != null) return;
if (ctx != null) {
- const trail = trailRef.current;
- if (trail != null) {
- // fill background by drawing a rect
- ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
- ctx.fillRect(0, 0, canvasSize, canvasSize);
-
- // Draw the trail
+ animationHandleRef.current = requestAnimationFrame(() => {
+ const trail = trailRef.current;
+ if (trail != null) {
+ // fill background by drawing a rect
+ ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
+ ctx.fillRect(0, 0, canvasSize, canvasSize);
- if (trail.length > 0) {
+ // Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
- ctx.beginPath();
- ctx.moveTo(trail[0].x, trail[0].y);
- for (let i = 1; i < trail.length; i++) {
- ctx.lineTo(trail[i].x, trail[i].y);
+
+ if (trail.length > 0) {
+ ctx.beginPath();
+ ctx.moveTo(trail[0].x, trail[0].y);
+ for (let i = 1; i < trail.length; i++) {
+ ctx.lineTo(trail[i].x, trail[i].y);
+ }
}
ctx.stroke();
+ // Need to include this at the end, for now.
+ ctx.invalidate();
+ animationHandleRef.current = null;
}
-
- ctx.invalidate();
- }
+ });
}
- }, [ctx, canvasSize, trailRef]);
+ }, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
What does requestAnimationFrame do?
requestAnimationFrame is a utility function that runs code when the screen is ready for its next refresh.
Input: a callback function, which requestAnimationFrame calls right before the screen next refreshes.
Output: a number that serves as an ID for the scheduled callback. If you later decide you don't want the code to run, you can pass that number to cancelAnimationFrame to cancel the callback. (We don't need that feature here.)
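Here is a minimal sketch of the pattern from the draw function in isolation: it coalesces many redraw requests into one per frame. It takes injected scheduleFrame and cancelFrame functions so it can run outside React Native with a fake scheduler. In the app, you would pass requestAnimationFrame and cancelAnimationFrame instead. createFrameCoalescer and the fake scheduler are hypothetical. Unlike the tutorial's draw, this version clears the handle before rendering rather than after.

```typescript
// Same shapes as requestAnimationFrame and cancelAnimationFrame.
type ScheduleFrame = (callback: () => void) => number;
type CancelFrame = (handle: number) => void;

// Hypothetical helper: however many times request() is called before the next
// frame, render() runs only once, on that frame. This is the same idea as the
// animationHandleRef check in the tutorial's draw function.
function createFrameCoalescer(
  render: () => void,
  scheduleFrame: ScheduleFrame,
  cancelFrame: CancelFrame,
) {
  let handle: number | null = null;
  return {
    request(): void {
      if (handle != null) {
        return; // a frame is already pending and will draw the latest data
      }
      handle = scheduleFrame(() => {
        handle = null; // clear first so render() may request another frame
        render();
      });
    },
    cancel(): void {
      if (handle != null) {
        cancelFrame(handle); // uses the ID returned by scheduleFrame
        handle = null;
      }
    },
  };
}

// Fake scheduler so the sketch runs outside React Native: callbacks wait in a
// queue until flushFrame() is called, standing in for a screen refresh.
const pending = new Map<number, () => void>();
let nextHandle = 1;
const fakeSchedule: ScheduleFrame = callback => {
  const id = nextHandle++;
  pending.set(id, callback);
  return id;
};
const fakeCancel: CancelFrame = id => {
  pending.delete(id);
};
function flushFrame(): void {
  const callbacks = Array.from(pending.values());
  pending.clear();
  callbacks.forEach(callback => callback());
}

let renders = 0;
const drawer = createFrameCoalescer(() => {
  renders++;
}, fakeSchedule, fakeCancel);

drawer.request();
drawer.request();
drawer.request(); // three touch moves before the next frame...
flushFrame();
console.log(renders); // 1 (...but only one redraw)

drawer.request();
drawer.cancel(); // cancelled before the frame arrives
flushFrame();
console.log(renders); // still 1
```

Clearing the handle before calling render() means that if a render requests another frame, that request is not ignored. The tutorial's draw clears the handle after ctx.invalidate() instead. That works equally well there, because nothing inside its frame callback calls draw again.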
Once you have those changes in your code, go ahead and refresh the app and see how much smoother drawing is.
Android:
npx torchlive-cli run-android
iOS (Simulator):
npx torchlive-cli run-ios
With silky smooth drawing in place, we are now ready to start working with the MNIST model.
Running the Model
We'll start by creating a React hook that provides a function for running inference on an input image. We'll follow React hooks naming conventions and call ours useMNISTModel.
Let's summarize the changes we're making:
- Import Image and MobileModel from react-native-pytorch-core.
- Load the model file with the require function and call it mnistModel.
- Create a type called MNISTResult with the following properties:
  - num: a digit from 0 to 9.
  - score: the confidence the model has that the input image shows the given num.
- Define a function called useMNISTModel that does the following:
  - Creates an async callback function called processImage with useCallback. It takes an Image as a parameter and does the following:
    - Uses the MobileModel API to execute the mnistModel we loaded. The parameters tell the model how much of the image to use (crop_width and crop_height), what size to scale it to (28 by 28 pixels, via scale_width and scale_height), and what the foreground and background colors are.
    - Transforms the raw scores into MNISTResult objects.
    - Sorts the results by score, highest first.
    - Returns the sorted results.
  - Returns an object containing the processImage function we just created.
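The post-processing step in processImage does not depend on PyTorch Live, so it can be tried on its own. Below is a sketch of the same map-then-sort transformation. toSortedResults is a hypothetical name, and the scores are made-up numbers, not real model output.

```typescript
type MNISTResult = {
  num: number; // the digit, 0 to 9
  score: number; // the model's score for that digit
};

// Same transformation as in processImage: the index of each score is the digit
// it belongs to, so pair them up first and then sort by score, highest first.
function toSortedResults(scores: number[]): MNISTResult[] {
  return scores
    .map((score, index) => ({num: index, score}))
    .sort((a, b) => b.score - a.score);
}

// Made-up scores for illustration only (not real model output).
const scores = [0.1, 0.3, 2.5, 0.0, 0.2, 0.1, 0.4, 1.8, 0.3, 0.05];
const [best, secondBest] = toSortedResults(scores);
console.log(best.num, secondBest.num); // 2 7
```

Because the array index is the digit, the map has to happen before the sort. Sorting the raw scores first would lose which digit each score belonged to.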
The changes are shown first as a diff, followed by the entire updated file:
@@ -1,6 +1,11 @@
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
-import {Canvas, CanvasRenderingContext2D} from 'react-native-pytorch-core';
+import {
+ Canvas,
+ CanvasRenderingContext2D,
+ Image,
+ MobileModel,
+} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
@@ -11,6 +16,44 @@
y: number;
};
+// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
+const mnistModel = require('../../models/mnist.ptl');
+
+type MNISTResult = {
+ num: number;
+ score: number;
+};
+
+/**
+ * The React hook provides MNIST model inference on an input image.
+ */
+function useMNISTModel() {
+ const processImage = useCallback(async (image: Image) => {
+ // Runs model inference on input image
+ const {
+ result: {scores},
+ } = await MobileModel.execute<{scores: number[]}>(mnistModel, {
+ image,
+ crop_width: 1,
+ crop_height: 1,
+ scale_width: 28,
+ scale_height: 28,
+ colorBackground: COLOR_CANVAS_BACKGROUND,
+ colorForeground: COLOR_TRAIL_STROKE,
+ });
+
+ // Get the score of each number (index), and sort the array by the most likely first.
+ const sortedScore: MNISTResult[] = scores
+ .map((score, index) => ({score: score, num: index}))
+ .sort((a, b) => b.score - a.score);
+ return sortedScore;
+ }, []);
+
+ return {
+ processImage,
+ };
+}
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {
Canvas,
CanvasRenderingContext2D,
Image,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
const mnistModel = require('../../models/mnist.ptl');
type MNISTResult = {
num: number;
score: number;
};
/**
* The React hook provides MNIST model inference on an input image.
*/
function useMNISTModel() {
const processImage = useCallback(async (image: Image) => {
// Runs model inference on input image
const {
result: {scores},
} = await MobileModel.execute<{scores: number[]}>(mnistModel, {
image,
crop_width: 1,
crop_height: 1,
scale_width: 28,
scale_height: 28,
colorBackground: COLOR_CANVAS_BACKGROUND,
colorForeground: COLOR_TRAIL_STROKE,
});
// Get the score of each number (index), and sort the array by the most likely first.
const sortedScore: MNISTResult[] = scores
.map((score, index) => ({score: score, num: index}))
.sort((a, b) => b.score - a.score);
return sortedScore;
}, []);
return {
processImage,
};
}
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
An even shorter summary: the hook takes in an Image and gives back a list of results sorted from most to least likely.
But we don't have an Image. We just have a trail drawn on a canvas.
In the next section, we'll learn how to create an Image from the contents of our canvas that we can pass to the model.
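To see what processImage hands back, the ranking step at its end can be pulled out into a plain function. This is a minimal sketch: rankScores is a hypothetical name, and it assumes, as the tutorial's code does, that index i of the model's scores array holds the score for digit i.

```typescript
type MNISTResult = {num: number; score: number};

// Pairs each score with its digit (the array index) and sorts the pairs so the
// most likely digit comes first, mirroring the last step of processImage.
function rankScores(scores: number[]): MNISTResult[] {
  return scores
    .map((score, index) => ({score, num: index}))
    .sort((a, b) => b.score - a.score);
}
```

For example, rankScores([0.1, -2, 5, 3]) ranks digit 2 first, then 3, 0 and 1. Because sort runs on the new array created by map, the model's original scores array is left untouched.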
Creating an Image from our Canvas
We are going to create another hook called useMNISTCanvasInference that uses the hook we just created (useMNISTModel).
This hook will take in the canvasSize and give us back two things:
- result: a state variable that holds the sorted list of MNISTResult objects returned by the model.
- classify: a function that takes in the canvas context, extracts an image from it, processes the image, and then updates the result state variable.
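The order of operations inside classify matters, so here is a framework-free sketch of it. The ClassifySteps interface, the classifyCanvas function and the Releasable type are hypothetical stand-ins for ctx.getImageData, ImageUtil.fromImageData, processImage and PTL's Image/ImageData objects; passing the steps in as functions lets the flow run and be checked without a device.

```typescript
type MNISTResult = {num: number; score: number};

// Hypothetical stand-in for PTL objects (ImageData, Image) that hold native memory.
interface Releasable {
  release(): void;
}

// Hypothetical bundle of the three steps classify performs; in the app these
// would wrap ctx.getImageData, ImageUtil.fromImageData and processImage.
interface ClassifySteps<D extends Releasable, I extends Releasable> {
  getImageData(size: number): Promise<D>;
  toImage(data: D): Promise<I>;
  processImage(image: I): Promise<MNISTResult[]>;
}

// Mirrors classify: skip a zero-size canvas, grab the pixels, convert them,
// free the pixels, run the model, then free the image.
async function classifyCanvas<D extends Releasable, I extends Releasable>(
  canvasSize: number,
  steps: ClassifySteps<D, I>,
): Promise<MNISTResult[] | null> {
  if (canvasSize === 0) {
    return null;
  }
  const imageData = await steps.getImageData(canvasSize);
  const image = await steps.toImage(imageData);
  imageData.release();
  const result = await steps.processImage(image);
  image.release();
  return result;
}
```

Unlike the hook, this sketch returns the result instead of calling setResult, which keeps it independent of React.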
In our classify callback, we use some of the PTL core components, including the newly imported ImageUtil object.
The ImageUtil object allows us to take the imageData we pull from the canvas and turn it into an Image that can be used by our model.
You'll also see that we call release on both imageData and image as soon as we are done with them. This step is vital: it frees the memory held by images we no longer need, so the app doesn't run out of memory.
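One caveat with the classify callback: if processImage throws or its promise rejects, the image.release() call after it never runs. A common remedy is try/finally. This sketch uses a hypothetical withRelease helper and a minimal Releasable interface, which is an assumption standing in for PTL's Image and ImageData (only release is used).

```typescript
// Hypothetical stand-in for objects with a native-memory release() method.
interface Releasable {
  release(): void | Promise<void>;
}

// Runs fn with the resource, then releases the resource whether fn
// resolves or rejects; the original outcome (value or error) is preserved.
async function withRelease<R extends Releasable, T>(
  resource: R,
  fn: (resource: R) => Promise<T>,
): Promise<T> {
  try {
    return await fn(resource);
  } finally {
    await resource.release();
  }
}
```

Inside classify this might read const result = await withRelease(image, processImage);. The helper is deliberately not named with a use prefix, since React's lint rules treat use-prefixed functions as hooks.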
- Changes
- Entire File
@@ -4,6 +4,7 @@
Canvas,
CanvasRenderingContext2D,
Image,
+ ImageUtil,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
@@ -54,6 +55,48 @@
};
}
+/**
+ * The React hook provides MNIST inference using the image data extracted from
+ * a canvas.
+ *
+ * @param canvasSize The size of the square canvas
+ */
+function useMNISTCanvasInference(canvasSize: number) {
+ const [result, setResult] = useState<MNISTResult[]>();
+ const {processImage} = useMNISTModel();
+ const classify = useCallback(
+ async (ctx: CanvasRenderingContext2D) => {
+ // Return immediately if canvas is size 0
+ if (canvasSize === 0) {
+ return null;
+ }
+
+ // Get image data for the whole (square) canvas
+ const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
+
+ // Convert image data to image.
+ const image: Image = await ImageUtil.fromImageData(imageData);
+
+ // Release image data to free memory
+ imageData.release();
+
+ // Run MNIST inference on the image
+ const result = await processImage(image);
+
+ // Release image to free memory
+ image.release();
+
+ // Set result state to force re-render of component that uses this hook
+ setResult(result);
+ },
+ [canvasSize, processImage, setResult],
+ );
+ return {
+ result,
+ classify,
+ };
+}
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {
Canvas,
CanvasRenderingContext2D,
Image,
ImageUtil,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
const mnistModel = require('../../models/mnist.ptl');
type MNISTResult = {
num: number;
score: number;
};
/**
* The React hook provides MNIST model inference on an input image.
*/
function useMNISTModel() {
const processImage = useCallback(async (image: Image) => {
// Runs model inference on input image
const {
result: {scores},
} = await MobileModel.execute<{scores: number[]}>(mnistModel, {
image,
crop_width: 1,
crop_height: 1,
scale_width: 28,
scale_height: 28,
colorBackground: COLOR_CANVAS_BACKGROUND,
colorForeground: COLOR_TRAIL_STROKE,
});
// Get the score of each number (index), and sort the array by the most likely first.
const sortedScore: MNISTResult[] = scores
.map((score, index) => ({score: score, num: index}))
.sort((a, b) => b.score - a.score);
return sortedScore;
}, []);
return {
processImage,
};
}
/**
* The React hook provides MNIST inference using the image data extracted from
* a canvas.
*
* @param canvasSize The size of the square canvas
*/
function useMNISTCanvasInference(canvasSize: number) {
const [result, setResult] = useState<MNISTResult[]>();
const {processImage} = useMNISTModel();
const classify = useCallback(
async (ctx: CanvasRenderingContext2D) => {
// Return immediately if canvas is size 0
if (canvasSize === 0) {
return null;
}
// Get image data for the whole (square) canvas
const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
// Convert image data to image.
const image: Image = await ImageUtil.fromImageData(imageData);
// Release image data to free memory
imageData.release();
// Run MNIST inference on the image
const result = await processImage(image);
// Release image to free memory
image.release();
// Set result state to force re-render of component that uses this hook
setResult(result);
},
[canvasSize, processImage, setResult],
);
return {
result,
classify,
};
}
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
}, [setDrawingDone]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
Highest confidence will go here
</Text>
<Text style={[styles.label, styles.secondary]}>
Second highest will go here
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
With this second hook, we are ready to run our model on user-created drawings. Let's hook it up in the next section.
Running the Model & Displaying Results
This section adds a fair number of lines, but each change is simple.
Let's cut to the summary:
- Create a type called NumberLabelSet so we know what data we have about each number.
- Create a list of NumberLabelSet objects and call it numLabels.
- Get the classify function and result state variable by calling useMNISTCanvasInference from within our demo component.
- Update the handleEnd function to check for a canvas context and then trigger the model.
- Add classify and ctx as dependencies of the handleEnd callback.
- Change the text in the results section to show the numbers from the model output.
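The two template strings in the results section can also be factored into a small function, which makes the lookup into numLabels easy to test. In this sketch describeTopTwo is a hypothetical name, numLabels is rebuilt from two arrays holding the same labels as the tutorial's list, and the function returns null when there is no result yet or fewer than two entries (the JSX assumes at least two, which holds for MNIST's ten classes).

```typescript
type MNISTResult = {num: number; score: number};
type NumberLabelSet = {english: string; asciiSymbol: string};

// Same labels as the tutorial's numLabels, built from two parallel arrays.
const ENGLISH = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine'];
const SYMBOLS = ['🄌', '➊', '➋', '➌', '➍', '➎', '➏', '➐', '➑', '➒'];
const numLabels: NumberLabelSet[] = ENGLISH.map((english, i) => ({
  english,
  asciiSymbol: SYMBOLS[i],
}));

// Builds the two lines shown under the canvas, or null if there is nothing to show.
function describeTopTwo(result: MNISTResult[] | undefined): [string, string] | null {
  if (result == null || result.length < 2) {
    return null;
  }
  const first = numLabels[result[0].num];
  const second = numLabels[result[1].num];
  return [
    `${first.asciiSymbol} it looks like ${first.english}`,
    `${second.asciiSymbol} or it might be ${second.english}`,
  ];
}
```

Inside MNISTDemo you could call it once per render, e.g. const lines = describeTopTwo(result);, and render lines?.[0] and lines?.[1] in the two Text elements.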
- Changes
- Entire File
@@ -97,6 +97,54 @@
};
}
+type NumberLabelSet = {
+ english: string;
+ asciiSymbol: string;
+};
+
+const numLabels: NumberLabelSet[] = [
+ {
+ english: 'zero',
+ asciiSymbol: '🄌',
+ },
+ {
+ english: 'one',
+ asciiSymbol: '➊',
+ },
+ {
+ english: 'two',
+ asciiSymbol: '➋',
+ },
+ {
+ english: 'three',
+ asciiSymbol: '➌',
+ },
+ {
+ english: 'four',
+ asciiSymbol: '➍',
+ },
+ {
+ english: 'five',
+ asciiSymbol: '➎',
+ },
+ {
+ english: 'six',
+ asciiSymbol: '➏',
+ },
+ {
+ english: 'seven',
+ asciiSymbol: '➐',
+ },
+ {
+ english: 'eight',
+ asciiSymbol: '➑',
+ },
+ {
+ english: 'nine',
+ asciiSymbol: '➒',
+ },
+];
+
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
@@ -105,6 +153,8 @@
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
+ const {classify, result} = useMNISTCanvasInference(canvasSize);
+
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
@@ -173,7 +223,8 @@
const handleEnd = useCallback(() => {
setDrawingDone(true);
- }, [setDrawingDone]);
+ if (ctx != null) classify(ctx);
+ }, [setDrawingDone, classify, ctx]);
useEffect(() => {
draw();
@@ -205,10 +256,16 @@
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
- Highest confidence will go here
+ {result &&
+ `${numLabels[result[0].num].asciiSymbol} it looks like ${
+ numLabels[result[0].num].english
+ }`}
</Text>
<Text style={[styles.label, styles.secondary]}>
- Second highest will go here
+ {result &&
+ `${numLabels[result[1].num].asciiSymbol} or it might be ${
+ numLabels[result[1].num].english
+ }`}
</Text>
</View>
)}
import React, {useCallback, useEffect, useState, useRef} from 'react';
import {StyleSheet, Text, View} from 'react-native';
import {
Canvas,
CanvasRenderingContext2D,
Image,
ImageUtil,
MobileModel,
} from 'react-native-pytorch-core';
import {useSafeAreaInsets} from 'react-native-safe-area-context';
const COLOR_CANVAS_BACKGROUND = '#4F25C6';
const COLOR_TRAIL_STROKE = '#FFFFFF';
type TrailPoint = {
x: number;
y: number;
};
// This is the custom model you have trained. See the tutorial for more on preparing a PyTorch model for mobile.
const mnistModel = require('../../models/mnist.ptl');
type MNISTResult = {
num: number;
score: number;
};
/**
* The React hook provides MNIST model inference on an input image.
*/
function useMNISTModel() {
const processImage = useCallback(async (image: Image) => {
// Runs model inference on input image
const {
result: {scores},
} = await MobileModel.execute<{scores: number[]}>(mnistModel, {
image,
crop_width: 1,
crop_height: 1,
scale_width: 28,
scale_height: 28,
colorBackground: COLOR_CANVAS_BACKGROUND,
colorForeground: COLOR_TRAIL_STROKE,
});
// Get the score of each number (index), and sort the array by the most likely first.
const sortedScore: MNISTResult[] = scores
.map((score, index) => ({score: score, num: index}))
.sort((a, b) => b.score - a.score);
return sortedScore;
}, []);
return {
processImage,
};
}
/**
* The React hook provides MNIST inference using the image data extracted from
* a canvas.
*
* @param canvasSize The size of the square canvas
*/
function useMNISTCanvasInference(canvasSize: number) {
const [result, setResult] = useState<MNISTResult[]>();
const {processImage} = useMNISTModel();
const classify = useCallback(
async (ctx: CanvasRenderingContext2D) => {
// Return immediately if canvas is size 0
if (canvasSize === 0) {
return null;
}
// Get image data for the whole (square) canvas
const imageData = await ctx.getImageData(0, 0, canvasSize, canvasSize);
// Convert image data to image.
const image: Image = await ImageUtil.fromImageData(imageData);
// Release image data to free memory
imageData.release();
// Run MNIST inference on the image
const result = await processImage(image);
// Release image to free memory
image.release();
// Set result state to force re-render of component that uses this hook
setResult(result);
},
[canvasSize, processImage, setResult],
);
return {
result,
classify,
};
}
type NumberLabelSet = {
english: string;
asciiSymbol: string;
};
const numLabels: NumberLabelSet[] = [
{
english: 'zero',
asciiSymbol: '🄌',
},
{
english: 'one',
asciiSymbol: '➊',
},
{
english: 'two',
asciiSymbol: '➋',
},
{
english: 'three',
asciiSymbol: '➌',
},
{
english: 'four',
asciiSymbol: '➍',
},
{
english: 'five',
asciiSymbol: '➎',
},
{
english: 'six',
asciiSymbol: '➏',
},
{
english: 'seven',
asciiSymbol: '➐',
},
{
english: 'eight',
asciiSymbol: '➑',
},
{
english: 'nine',
asciiSymbol: '➒',
},
];
export default function MNISTDemo() {
// Get safe area insets to account for notches, etc.
const insets = useSafeAreaInsets();
const [canvasSize, setCanvasSize] = useState<number>(0);
// `ctx` is drawing context to draw shapes
const [ctx, setCtx] = useState<CanvasRenderingContext2D>();
const {classify, result} = useMNISTCanvasInference(canvasSize);
const trailRef = useRef<TrailPoint[]>([]);
const [drawingDone, setDrawingDone] = useState(false);
const animationHandleRef = useRef<number | null>(null);
const draw = useCallback(() => {
if (animationHandleRef.current != null) return;
if (ctx != null) {
animationHandleRef.current = requestAnimationFrame(() => {
const trail = trailRef.current;
if (trail != null) {
// fill background by drawing a rect
ctx.fillStyle = COLOR_CANVAS_BACKGROUND;
ctx.fillRect(0, 0, canvasSize, canvasSize);
// Draw the trail
ctx.strokeStyle = COLOR_TRAIL_STROKE;
ctx.lineWidth = 25;
ctx.lineJoin = 'round';
ctx.lineCap = 'round';
ctx.miterLimit = 1;
if (trail.length > 0) {
ctx.beginPath();
ctx.moveTo(trail[0].x, trail[0].y);
for (let i = 1; i < trail.length; i++) {
ctx.lineTo(trail[i].x, trail[i].y);
}
}
ctx.stroke();
// Need to include this at the end, for now.
ctx.invalidate();
animationHandleRef.current = null;
}
});
}
}, [animationHandleRef, ctx, canvasSize, trailRef]);
// handlers for touch events
const handleMove = useCallback(
async event => {
const position: TrailPoint = {
x: event.nativeEvent.locationX,
y: event.nativeEvent.locationY,
};
const trail = trailRef.current;
if (trail.length > 0) {
const lastPosition = trail[trail.length - 1];
const dx = position.x - lastPosition.x;
const dy = position.y - lastPosition.y;
// add a point to trail if distance from last point > 5
if (dx * dx + dy * dy > 25) {
trail.push(position);
}
} else {
trail.push(position);
}
draw();
},
[trailRef, draw],
);
const handleStart = useCallback(() => {
setDrawingDone(false);
trailRef.current = [];
}, [trailRef, setDrawingDone]);
const handleEnd = useCallback(() => {
setDrawingDone(true);
if (ctx != null) classify(ctx);
}, [setDrawingDone, classify, ctx]);
useEffect(() => {
draw();
}, [draw]);
return (
<View
style={styles.container}
onLayout={event => {
const {layout} = event.nativeEvent;
setCanvasSize(Math.min(layout?.width || 0, layout?.height || 0));
}}>
<View style={[styles.instruction, {marginTop: insets.top}]}>
<Text style={styles.label}>Write a number</Text>
<Text style={styles.label}>
Let's see if the AI model will get it right
</Text>
</View>
<Canvas
style={{
height: canvasSize,
width: canvasSize,
}}
onContext2D={setCtx}
onTouchMove={handleMove}
onTouchStart={handleStart}
onTouchEnd={handleEnd}
/>
{drawingDone && (
<View style={[styles.resultView]} pointerEvents="none">
<Text style={[styles.label, styles.secondary]}>
{result &&
`${numLabels[result[0].num].asciiSymbol} it looks like ${
numLabels[result[0].num].english
}`}
</Text>
<Text style={[styles.label, styles.secondary]}>
{result &&
`${numLabels[result[1].num].asciiSymbol} or it might be ${
numLabels[result[1].num].english
}`}
</Text>
</View>
)}
</View>
);
}
const styles = StyleSheet.create({
container: {
height: '100%',
width: '100%',
backgroundColor: '#180b3b',
justifyContent: 'center',
alignItems: 'center',
},
resultView: {
position: 'absolute',
bottom: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
instruction: {
position: 'absolute',
top: 0,
alignSelf: 'flex-start',
flexDirection: 'column',
padding: 15,
},
label: {
fontSize: 16,
color: '#ffffff',
},
secondary: {
color: '#ffffff99',
},
});
When you run the app and draw a digit, the predictions should appear in the bottom-left corner, as in the screen recording below.
- Android
- iOS (Simulator)
npx torchlive-cli run-android

npx torchlive-cli run-ios

And with that we have a working MNIST classifier!