OR-1 dataflow CPU sketch
#!/usr/bin/env bash
#
# Pre-install Python packages for the NBCU interview environment.
# Run this BEFORE the interview — torch alone is ~2.5GB.
#
# Must be run from the directory that contains .venv: the virtualenv is
# looked up relative to the current working directory.
set -euo pipefail

# Refuse to continue without a virtualenv so packages are never installed
# into whatever Python happens to be first on PATH. Diagnostics go to stderr.
if [[ ! -d ".venv" ]]; then
  echo "No .venv found. Enter the devshell first: nix develop" >&2
  exit 1
fi

# A .venv directory without an activate script would otherwise die with a
# bare "No such file or directory" from `source`; say what is actually wrong.
if [[ ! -f ".venv/bin/activate" ]]; then
  echo ".venv exists but .venv/bin/activate is missing. Re-enter the devshell: nix develop" >&2
  exit 1
fi

# shellcheck source=/dev/null
source .venv/bin/activate

# Install PyTorch + torchvision.
# ROCm 6.4 wheels are tried first. If that install fails, CPU-only wheels are
# installed instead. If the ROCm wheels install but no GPU is visible, the
# ROCm build is kept and debugging hints are printed.
echo "Installing core ML stack (trying ROCm torch first, CPU fallback)..."
echo " Attempting ROCm 6.4 wheels..."
if uv pip install \
  torch \
  torchvision \
  --index-url https://download.pytorch.org/whl/rocm6.4; then
  echo " ROCm wheels installed. Checking GPU visibility..."
  # Report GPU visibility through the exit status rather than a failing
  # assert, so the expected "no GPU" case does not dump a Python traceback.
  if python -c "import sys, torch; ok = torch.cuda.is_available(); print(f' GPU available: {ok}'); sys.exit(0 if ok else 1)"; then
    echo " ROCm torch with GPU acceleration ready."
  else
    echo ""
    echo " ROCm torch installed but GPU not detected."
    echo " To debug GPU detection later:"
    echo " - Check /dev/kfd exists and is accessible"
    echo " - Check user is in 'video' and 'render' groups"
    echo " - Run: HSA_OVERRIDE_GFX_VERSION=10.3.0 python -c \"import torch; print(torch.cuda.is_available())\""
    echo ""
  fi
else
  echo " ROCm wheel install failed. Falling back to CPU torch."
  # The default PyPI torch wheel on Linux is the CUDA build, not a CPU-only
  # one; the dedicated CPU index is what actually delivers "CPU torch".
  uv pip install \
    torch \
    torchvision \
    --index-url https://download.pytorch.org/whl/cpu
fi

# Remainder of the core stack (numerics + plotting), installed from the
# default package index.
uv pip install \
  numpy \
  scipy \
  scikit-learn \
  matplotlib \
  Pillow

echo ""
echo "Installing geospatial stack..."
# 'laspy[lazrs]' must be quoted: unquoted, the brackets form a shell glob
# that could expand to a matching filename in the current directory (or to
# nothing / an error under nullglob / failglob) instead of the extras spec.
uv pip install \
  rasterio \
  'laspy[lazrs]' \
  opencv-python-headless \
  shapely \
  geopandas \
  fiona

echo ""
echo "Installing extras (seaborn for confusion matrix plots, etc.)..."
uv pip install \
  seaborn \
  h5py

printf '%s\n' "" "Verifying critical imports..."
# Import each critical package in turn and print its version. Any import
# failure makes python exit non-zero, which aborts the script under `set -e`.
python - <<'PY'
import importlib

# (label shown to the user, importable module name), in report order.
CHECKS = [
    ('torch', 'torch'),
    ('torchvision', 'torchvision'),
    ('numpy', 'numpy'),
    ('scipy', 'scipy'),
    ('sklearn', 'sklearn'),
    ('rasterio', 'rasterio'),
    ('laspy', 'laspy'),
    ('opencv', 'cv2'),
    ('shapely', 'shapely'),
]

for label, module_name in CHECKS:
    module = importlib.import_module(module_name)
    print(f' {label} {module.__version__}')
    if module_name == 'torch':
        # Accelerator details are reported right after the torch version.
        if module.cuda.is_available():
            print(f' GPU: {module.cuda.get_device_name(0)}')
            print(f' ROCm/CUDA: {module.version.hip or module.version.cuda}')
        else:
            print(' CPU only')

print()
print('All imports OK.')
PY

printf '%s\n' "" "Smoke-testing torch inference..."
# Build DeepLabV3-ResNet50 with its default pretrained weights and push one
# random 256x256 RGB image through it on the default device.
python - <<'PY'
import torch
from torchvision.models import segmentation

pretrained = segmentation.DeepLabV3_ResNet50_Weights.DEFAULT
net = segmentation.deeplabv3_resnet50(weights=pretrained)
net.eval()

# Random uint8 image, run through the preprocessing bundled with the
# weights, then given a leading batch dimension.
raw_image = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)
batch = pretrained.transforms()(raw_image).unsqueeze(0)

with torch.no_grad():
    logits = net(batch)['out']

print(f' DeepLabV3 inference OK — output shape: {logits.shape}')
print(f' Predicted classes: {torch.unique(logits.argmax(dim=1)).tolist()}')
PY

echo ""
echo "Smoke-testing rasterio + laspy..."
# Round-trip a tiny GeoTIFF through rasterio and a small LAS file through
# laspy. Scratch files live in a private temporary directory that is removed
# even when a step fails, instead of fixed, predictable names in /tmp.
python - <<'PY'
import os
import tempfile

import numpy as np
import rasterio
from rasterio.transform import from_bounds

with tempfile.TemporaryDirectory(prefix='env_smoke_') as scratch:
    tif_path = os.path.join(scratch, 'test_rio.tif')
    las_path = os.path.join(scratch, 'test_las.las')

    # Write and read a tiny GeoTIFF
    data = np.random.randint(0, 255, (3, 64, 64), dtype=np.uint8)
    transform = from_bounds(-73.58, 45.49, -73.55, 45.52, 64, 64)

    with rasterio.open(
        tif_path, 'w', driver='GTiff',
        height=64, width=64, count=3, dtype='uint8',
        crs='EPSG:4326', transform=transform,
    ) as dst:
        dst.write(data)

    with rasterio.open(tif_path) as src:
        print(f' rasterio read/write OK — CRS: {src.crs}, shape: {src.shape}')

    # Write and read back 1000 random points as a LAS 1.2 file
    import laspy
    header = laspy.LasHeader(point_format=0, version='1.2')
    header.offsets = [0, 0, 0]
    header.scales = [0.001, 0.001, 0.001]
    las = laspy.LasData(header)
    las.x = np.random.uniform(0, 100, 1000)
    las.y = np.random.uniform(0, 100, 1000)
    las.z = np.random.uniform(0, 50, 1000)
    las.write(las_path)
    las2 = laspy.read(las_path)
    print(f' laspy read/write OK — {las2.header.point_count} points')
PY

# Closing banner: reached only if every step above succeeded (`set -e`).
printf '%s\n' \
  "" \
  "=========================================" \
  " Environment ready." \
  "=========================================" \
  ""
161echo ""